diff options
-rw-r--r-- | nerv/examples/lmptb/lm_trainer.lua | 2 | ||||
-rw-r--r-- | nerv/tnn/tnn.lua | 13 |
2 files changed, 13 insertions, 2 deletions
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua index 3b8b5c3..ecedc9f 100644 --- a/nerv/examples/lmptb/lm_trainer.lua +++ b/nerv/examples/lmptb/lm_trainer.lua @@ -196,7 +196,6 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf) if r == false then break end - for t = 1, chunk_size do tnn.err_inputs_m[t][1]:fill(1) for i = 1, batch_size do @@ -269,6 +268,7 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf) collectgarbage("collect") + tnn:flush_all() --break --debug end diff --git a/nerv/tnn/tnn.lua b/nerv/tnn/tnn.lua index cf02123..bcfeb40 100644 --- a/nerv/tnn/tnn.lua +++ b/nerv/tnn/tnn.lua @@ -64,7 +64,7 @@ function TNN.make_initial_store(st, p, dim, batch_size, chunk_size, extend_t, gl if (type(st) ~= "table") then nerv.error("st should be a table") end - for i = 1 - extend_t - 1, chunk_size + extend_t + 1 do --intentionally allocated more time + for i = 1 - extend_t - 2, chunk_size + extend_t + 2 do --intentionally allocated more time if (st[i] == nil) then st[i] = {} end @@ -339,6 +339,11 @@ function TNN:net_propagate() --propagate according to feeds_now end local feeds_now = self.feeds_now + for t = 1, self.chunk_size do --some layer maybe do not have inputs from time 1..chunk_size + for id, ref in pairs(self.layers) do + self:propagate_dfs(ref, t) + end + end for t = 1, self.chunk_size do if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then for i = 1, #self.dim_in do @@ -362,6 +367,7 @@ function TNN:net_propagate() --propagate according to feeds_now end end end + if (flag_out == false) then nerv.error("some thing wrong, some labeled output is not propagated") end @@ -458,6 +464,11 @@ function TNN:net_backpropagate(do_update) --propagate according to feeds_now end local feeds_now = self.feeds_now + for t = 1, self.chunk_size do --some layer maybe do not have outputs from time 1..chunk_size + for id, ref in pairs(self.layers) do + self:backpropagate_dfs(ref, t) + end + end for t = 1, self.chunk_size do if bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0 then for i = 1, #self.dim_out do |