aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/examples/lmptb/rnn/tnn.lua70
1 files changed, 56 insertions, 14 deletions
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua
index 460fcc4..019d24c 100644
--- a/nerv/examples/lmptb/rnn/tnn.lua
+++ b/nerv/examples/lmptb/rnn/tnn.lua
@@ -34,16 +34,16 @@ local function discover(id, layers, layer_repo)
layer = layer,
inputs_m = {}, --storage for computation, inputs_m[time][port]
inputs_b = {}, --inputs_g[time][port], whether this input can been computed
- inputs_p_matbak = {}, --which is a back-up space to handle some cross-border computation, inputs_p_matbak[port]
+ inputs_matbak_p = {}, --back-up space to handle some cross-border computation, inputs_matbak_p[port]
outputs_m = {},
outputs_b = {},
err_inputs_m = {},
- err_inputs_p_matbak = {}, --which is a back-up space to handle some cross-border computation
+ err_inputs_matbak_p = {}, --which is a back-up space to handle some cross-border computation
err_inputs_b = {},
err_outputs_m = {},
err_outputs_b = {},
- conns_i = {}, --list of inputing connections
- conns_o = {}, --list of outputing connections
+ i_conns_p = {}, --incoming connections, indexed by destination port
+ o_conns_p = {}, --outgoing connections, indexed by source port
dim_in = dim_in, --list of dimensions of ports
dim_out = dim_out,
}
@@ -130,8 +130,8 @@ function TNN:__init(id, global_conf, layer_conf)
nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3])
end
table.insert(parsed_conns, conn_now)
- table.insert(ref_to.conns_i, conn_now)
- table.insert(ref_from.conns_o, conn_now)
+ ref_to.i_conns_p[conn_now.dst.port] = conn_now
+ ref_from.o_conns_p[conn_now.src.port] = conn_now
end
end
@@ -161,9 +161,9 @@ function TNN:init(batch_size, chunk_size)
end
print("TNN initing storage", ref_from.layer.id, "->", ref_to.layer.id)
- ref_to.inputs_p_matbak[port_to] = self.gconf.cumat_type(batch_size, dim)
+ ref_to.inputs_matbak_p[port_to] = self.gconf.cumat_type(batch_size, dim)
self.makeInitialStore(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.inputs_m, port_to, time)
- ref_from.err_inputs_p_matbak[port_from] = self.gconf.cumat_type(batch_size, dim)
+ ref_from.err_inputs_matbak_p[port_from] = self.gconf.cumat_type(batch_size, dim)
self.makeInitialStore(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.err_outputs_m, port_to, time)
end
@@ -289,7 +289,7 @@ function TNN:getFeedFromReader(reader)
end
function TNN:moveRightToNextMB() --move output history activations of 1..chunk_size to 1-chunk_size..0
- for t = self.chunk_size, 1, -1 do
+ for t = 1, self.chunk_size, 1 do
for id, ref in pairs(self.layers) do
for p = 1, #ref.dim_out do
ref.outputs_m[t - self.chunk_size][p]:copy_fromd(ref.outputs_m[t][p])
@@ -336,7 +336,7 @@ function TNN:propagate_dfs(ref, t)
--print("debug dfs", ref.layer.id, t)
local flag = true --whether have all inputs
- for _, conn in pairs(ref.conns_i) do
+ for _, conn in pairs(ref.i_conns_p) do
local p = conn.dst.port
if (not (ref.inputs_b[t][p] or self:outOfFeedRange(t - conn.time))) then
flag = false
@@ -349,7 +349,36 @@ function TNN:propagate_dfs(ref, t)
--ok, do propagate
--print("debug ok, propagating");
- ref.layer:propagate(ref.inputs_m[t], ref.outputs_m[t])
+ if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --flush cross-border history; band tests the flag bits (bor with a non-zero mask is always non-zero)
+ for i = 1, self.batch_size do
+ local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START) --non-zero iff SEQ_START is set for stream i at time t
+ local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END) --non-zero iff SEQ_END is set for stream i at time t
+ if (seq_start > 0 or seq_end > 0) then
+ for p, conn in pairs(ref.i_conns_p) do --table is keyed by port; pairs tolerates unconnected ports, unlike the # operator
+ if ((conn.time > 0 and seq_start > 0) or (conn.time < 0 and seq_end > 0)) then --cross-border, set to default
+ ref.inputs_matbak_p[p][i - 1]:copy_fromd(ref.inputs_m[t][p][i - 1]) --back up the row so it can be restored after propagate
+ ref.inputs_m[t][p][i - 1]:fill(self.gconf.nn_act_default)
+ end
+ end
+ end
+ end
+ end
+ ref.layer:propagate(ref.inputs_m[t], ref.outputs_m[t]) --propagate!
+ if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --restore cross-border history; same flag test as the flush above
+ for i = 1, self.batch_size do
+ local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
+ local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
+ if (seq_start > 0 or seq_end > 0) then
+ for p, conn in pairs(ref.i_conns_p) do --must select exactly the rows that were backed up before propagate
+ if ((conn.time > 0 and seq_start > 0) or (conn.time < 0 and seq_end > 0)) then
+ ref.inputs_m[t][p][i - 1]:copy_fromd(ref.inputs_matbak_p[p][i - 1])
+ end
+ end
+ end
+ end
+ end
+
+ --set input flag for future layers
for i = 1, #ref.dim_out do
if (ref.outputs_b[t][i] == true) then
nerv.error("this time's outputs_b should be false")
@@ -358,7 +387,7 @@ function TNN:propagate_dfs(ref, t)
end
--try dfs for further layers
- for _, conn in pairs(ref.conns_o) do
+ for _, conn in pairs(ref.o_conns_p) do
--print("debug dfs-searching", conn.dst.ref.layer.id)
conn.dst.ref.inputs_b[t + conn.time][conn.dst.port] = true
self:propagate_dfs(conn.dst.ref, t + conn.time)
@@ -407,7 +436,7 @@ function TNN:backpropagate_dfs(ref, t, do_update)
--print("debug dfs", ref.layer.id, t)
local flag = true --whether have all inputs
- for _, conn in pairs(ref.conns_o) do
+ for _, conn in pairs(ref.o_conns_p) do
local p = conn.src.port
if (not (ref.err_inputs_b[t][p] or self:outOfFeedRange(t + conn.time))) then
flag = false
@@ -420,6 +449,19 @@ function TNN:backpropagate_dfs(ref, t, do_update)
--ok, do back_propagate
--print("debug ok, back-propagating(or updating)")
+ if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --flush cross-border errors; band tests the flag bits (bor with a non-zero mask is always non-zero)
+ for i = 1, self.batch_size do
+ local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START) --non-zero iff SEQ_START is set for stream i at time t
+ local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END) --non-zero iff SEQ_END is set for stream i at time t
+ if (seq_start > 0 or seq_end > 0) then
+ for p, conn in pairs(ref.o_conns_p) do --table is keyed by port; pairs tolerates unconnected ports, unlike the # operator
+ if ((conn.time > 0 and seq_end > 0) or (conn.time < 0 and seq_start > 0)) then --cross-border, set to zero
+ ref.err_inputs_m[t][p][i - 1]:fill(0)
+ end
+ end
+ end
+ end
+ end
if (do_update == false) then
ref.layer:back_propagate(ref.err_inputs_m[t], ref.err_outputs_m[t], ref.inputs_m[t], ref.outputs_m[t])
else
@@ -434,7 +476,7 @@ function TNN:backpropagate_dfs(ref, t, do_update)
end
--try dfs for further layers
- for _, conn in pairs(ref.conns_i) do
+ for _, conn in pairs(ref.i_conns_p) do
--print("debug dfs-searching", conn.src.ref.layer.id)
conn.src.ref.err_inputs_b[t - conn.time][conn.src.port] = true
self:backpropagate_dfs(conn.src.ref, t - conn.time, do_update)