summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/examples/lmptb/m-tests/tnn_test.lua4
-rw-r--r--nerv/examples/lmptb/rnn/tnn.lua28
2 files changed, 16 insertions, 16 deletions
diff --git a/nerv/examples/lmptb/m-tests/tnn_test.lua b/nerv/examples/lmptb/m-tests/tnn_test.lua
index 7a8519e..276ced4 100644
--- a/nerv/examples/lmptb/m-tests/tnn_test.lua
+++ b/nerv/examples/lmptb/m-tests/tnn_test.lua
@@ -153,10 +153,6 @@ function lm_process_file(global_conf, fn, tnn, do_train)
tnn:flush_all() --caution: will also flush the inputs from the reader!
- for t = 1, global_conf.chunk_size do
- tnn.err_inputs_m[t][1]:fill(1)
- end
-
local next_log_wcn = global_conf.log_w_num
while (1) do
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua
index f470190..dfcef63 100644
--- a/nerv/examples/lmptb/rnn/tnn.lua
+++ b/nerv/examples/lmptb/rnn/tnn.lua
@@ -135,6 +135,10 @@ function TNN:__init(id, global_conf, layer_conf)
end
end
+ for id, ref in pairs(layers) do
+ print(id, "#dim_in:", #ref.dim_in, "#dim_out:", #ref.dim_out, "#i_conns_p:", #ref.i_conns_p, "#o_conns_p", #ref.o_conns_p)
+ end
+
self.layers = layers
self.inputs_p = inputs_p
self.outputs_p = outputs_p
@@ -365,12 +369,12 @@ function TNN:propagate_dfs(ref, t)
--ok, do propagate
--print("debug ok, propagating");
- if (bit.bor(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --flush cross-border history
+ if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --flush cross-border history
for i = 1, self.batch_size do
- local seq_start = bit.bor(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
- local seq_end = bit.bor(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
+ local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
+ local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
if (seq_start > 0 or seq_end > 0) then
- for p = 1, #ref.i_conns_p do
+ for p, conn in pairs(ref.i_conns_p) do
if ((ref.i_conns_p[p].time > 0 and seq_start > 0) or (ref.i_conns_p[p].time < 0 and seq_end > 0)) then --cross-border, set to default
ref.inputs_matbak_p[p][i - 1]:copy_fromd(ref.inputs_m[t][p][i - 1])
ref.inputs_m[t][p][i - 1]:fill(self.gconf.nn_act_default)
@@ -380,12 +384,12 @@ function TNN:propagate_dfs(ref, t)
end
end
ref.layer:propagate(ref.inputs_m[t], ref.outputs_m[t], t) --propagate!
- if (bit.bor(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --restore cross-border history
+ if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --restore cross-border history
for i = 1, self.batch_size do
- local seq_start = bit.bor(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
- local seq_end = bit.bor(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
+ local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
+ local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
if (seq_start > 0 or seq_end > 0) then
- for p = 1, #ref.i_conns_p do
+ for p, conn in pairs(ref.i_conns_p) do
if ((ref.i_conns_p[p].time > 0 and seq_start > 0) or (ref.i_conns_p[p].time < 0 and seq_end > 0)) then
ref.inputs_m[t][p][i - 1]:copy_fromd(ref.inputs_matbak_p[p][i - 1])
end
@@ -481,12 +485,12 @@ function TNN:backpropagate_dfs(ref, t, do_update)
--ok, do back_propagate
--print("debug ok, back-propagating(or updating)")
- if (bit.bor(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --flush cross-border errors
+ if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --flush cross-border errors
for i = 1, self.batch_size do
- local seq_start = bit.bor(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
- local seq_end = bit.bor(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
+ local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
+ local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
if (seq_start > 0 or seq_end > 0) then
- for p = 1, #ref.o_conns_p do
+ for p, conn in pairs(ref.o_conns_p) do
if ((ref.o_conns_p[p].time > 0 and seq_end > 0) or (ref.o_conns_p[p].time < 0 and seq_start > 0)) then --cross-border, set to zero
ref.err_inputs_m[t][p][i - 1]:fill(0)
end