diff options
Diffstat (limited to 'nerv/examples/lmptb/rnn/tnn.lua')
-rw-r--r-- | nerv/examples/lmptb/rnn/tnn.lua | 31 |
1 files changed, 1 insertions, 30 deletions
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua index fc5321d..9850fe5 100644 --- a/nerv/examples/lmptb/rnn/tnn.lua +++ b/nerv/examples/lmptb/rnn/tnn.lua @@ -1,5 +1,4 @@ local TNN = nerv.class("nerv.TNN", "nerv.Layer") -local DAGLayer = TNN local function parse_id(str) --used to parse layerid[portid],time @@ -91,7 +90,7 @@ end function TNN:__init(id, global_conf, layer_conf) local layers = {} - local inputs_p = {} --map:port of the TDAGLayer to layer ref and port + local inputs_p = {} --map:port of the TNN to layer ref and port local outputs_p = {} local dim_in = layer_conf.dim_in local dim_out = layer_conf.dim_out @@ -394,7 +393,6 @@ function TNN:propagate_dfs(ref, t) if (seq_start > 0 or seq_end > 0) then for p, conn in pairs(ref.o_conns_p) do if ((ref.o_conns_p[p].time > 0 and seq_end > 0) or (ref.o_conns_p[p].time < 0 and seq_start > 0)) then - self.gconf.fz2 = self.gconf.fz2 + 1 ref.outputs_m[t][p][i - 1]:fill(self.gconf.nn_act_default) end end @@ -502,7 +500,6 @@ function TNN:backpropagate_dfs(ref, t, do_update) if (seq_start > 0 or seq_end > 0) then for p, conn in pairs(ref.i_conns_p) do if ((ref.i_conns_p[p].time > 0 and seq_start > 0) or (ref.i_conns_p[p].time < 0 and seq_end > 0)) then --cross-border, set to zero - self.gconf.fz = self.gconf.fz + 1 ref.err_outputs_m[t][p][i - 1]:fill(0) end end @@ -534,29 +531,3 @@ function TNN:get_params() return nerv.ParamRepo.merge(param_repos) end -DAGLayer.PORT_TYPES = { - INPUT = {}, - OUTPUT = {}, - ERR_INPUT = {}, - ERR_OUTPUT = {} -} - -function DAGLayer:get_intermediate(id, port_type) - if id == "<input>" or id == "<output>" then - nerv.error("an actual real layer id is expected") - end - local layer = self.layers[id] - if layer == nil then - nerv.error("layer id %s not found", id) - end - if port_type == DAGLayer.PORT_TYPES.INPUT then - return layer.inputs - elseif port_type == DAGLayer.PORT_TYPES.OUTPUT then - return layer.outputs - elseif port_type == DAGLayer.PORT_TYPES.ERR_INPUT then - return layer.err_inputs - elseif port_type == DAGLayer.PORT_TYPES.ERR_OUTPUT then - return layer.err_outputs - end - nerv.error("unrecognized port type") -end |