aboutsummaryrefslogtreecommitdiff
path: root/nerv/nn
diff options
context:
space:
mode:
authorDeterminant <ted.sybil@gmail.com>2015-08-26 14:26:54 +0800
committerDeterminant <ted.sybil@gmail.com>2015-08-26 14:26:54 +0800
commite81e9832ec4f2ad031fd42b5018cea134e8cda7e (patch)
treeed49289619399a99c80f47398ccc4de9ae4cedf6 /nerv/nn
parented2a4148dbb9c18f428571b3e2970d7b2adfb058 (diff)
move global_transf to asr_trainer.lua
Diffstat (limited to 'nerv/nn')
-rw-r--r--nerv/nn/layer_dag.lua27
1 files changed, 27 insertions, 0 deletions
diff --git a/nerv/nn/layer_dag.lua b/nerv/nn/layer_dag.lua
index e9d4d86..25297c2 100644
--- a/nerv/nn/layer_dag.lua
+++ b/nerv/nn/layer_dag.lua
@@ -254,3 +254,30 @@ function DAGLayer:get_params()
end
return nerv.ParamRepo.merge(param_repos)
end
+
+DAGLayer.PORT_TYPES = {
+ INPUT = {},
+ OUTPUT = {},
+ ERR_INPUT = {},
+ ERR_OUTPUT = {}
+}
+
+function DAGLayer:get_intermediate(id, port_type)
+ if id == "<input>" or id == "<output>" then
+ nerv.error("an actual real layer id is expected")
+ end
+ local layer = layers[id]
+ if layer == nil then
+ nerv.error("layer id %s not found", id)
+ end
+ if port_type == DAGLayer.PORT_TYPES.INPUT then
+ return layer.inputs
+ elseif port_type == DAGLayer.PORT_TYPES.OUTPUT then
+ return layer.outputs
+ elseif port_type == DAGLayer.PORT_TYPES.ERR_INPUT then
+ return layer.err_inputs
+ elseif port_type == DAGLayer.PORT_TYPES.ERR_OUTPUT then
+ return layer.err_outputs
+ end
+ nerv.error("unrecognized port type")
+end