aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/rnn/tnn.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/rnn/tnn.lua')
-rw-r--r--nerv/examples/lmptb/rnn/tnn.lua31
1 files changed, 1 insertions, 30 deletions
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua
index fc5321d..9850fe5 100644
--- a/nerv/examples/lmptb/rnn/tnn.lua
+++ b/nerv/examples/lmptb/rnn/tnn.lua
@@ -1,5 +1,4 @@
local TNN = nerv.class("nerv.TNN", "nerv.Layer")
-local DAGLayer = TNN
local function parse_id(str)
--used to parse layerid[portid],time
@@ -91,7 +90,7 @@ end
function TNN:__init(id, global_conf, layer_conf)
local layers = {}
- local inputs_p = {} --map:port of the TDAGLayer to layer ref and port
+ local inputs_p = {} --map:port of the TNN to layer ref and port
local outputs_p = {}
local dim_in = layer_conf.dim_in
local dim_out = layer_conf.dim_out
@@ -394,7 +393,6 @@ function TNN:propagate_dfs(ref, t)
if (seq_start > 0 or seq_end > 0) then
for p, conn in pairs(ref.o_conns_p) do
if ((ref.o_conns_p[p].time > 0 and seq_end > 0) or (ref.o_conns_p[p].time < 0 and seq_start > 0)) then
- self.gconf.fz2 = self.gconf.fz2 + 1
ref.outputs_m[t][p][i - 1]:fill(self.gconf.nn_act_default)
end
end
@@ -502,7 +500,6 @@ function TNN:backpropagate_dfs(ref, t, do_update)
if (seq_start > 0 or seq_end > 0) then
for p, conn in pairs(ref.i_conns_p) do
if ((ref.i_conns_p[p].time > 0 and seq_start > 0) or (ref.i_conns_p[p].time < 0 and seq_end > 0)) then --cross-border, set to zero
- self.gconf.fz = self.gconf.fz + 1
ref.err_outputs_m[t][p][i - 1]:fill(0)
end
end
@@ -534,29 +531,3 @@ function TNN:get_params()
return nerv.ParamRepo.merge(param_repos)
end
-DAGLayer.PORT_TYPES = {
- INPUT = {},
- OUTPUT = {},
- ERR_INPUT = {},
- ERR_OUTPUT = {}
-}
-
-function DAGLayer:get_intermediate(id, port_type)
- if id == "<input>" or id == "<output>" then
- nerv.error("an actual real layer id is expected")
- end
- local layer = self.layers[id]
- if layer == nil then
- nerv.error("layer id %s not found", id)
- end
- if port_type == DAGLayer.PORT_TYPES.INPUT then
- return layer.inputs
- elseif port_type == DAGLayer.PORT_TYPES.OUTPUT then
- return layer.outputs
- elseif port_type == DAGLayer.PORT_TYPES.ERR_INPUT then
- return layer.err_inputs
- elseif port_type == DAGLayer.PORT_TYPES.ERR_OUTPUT then
- return layer.err_outputs
- end
- nerv.error("unrecognized port type")
-end