author    txh18 <[email protected]>  2015-12-10 00:15:38 +0800
committer txh18 <[email protected]>  2015-12-10 00:15:38 +0800
commit    00c3f11361967a0f78fd770d20a2af3e9e7c1f50
tree      16b2a49c84c28546838182f5eb83ca2e4eab3b51
parent    ebb8ba41886f5860f27157343bcb022eb672143c
bilstm_v2 did not run well
-rw-r--r--  nerv/examples/lmptb/lm_trainer.lua   2
-rw-r--r--  nerv/tnn/tnn.lua                    13
2 files changed, 13 insertions, 2 deletions
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index 3b8b5c3..ecedc9f 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -196,7 +196,6 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf)
if r == false then
break
end
-
for t = 1, chunk_size do
tnn.err_inputs_m[t][1]:fill(1)
for i = 1, batch_size do
@@ -269,6 +268,7 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf)
collectgarbage("collect")
+ tnn:flush_all()
--break --debug
end
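For context, a minimal sketch (not the repository's actual code) of where the new tnn:flush_all() call sits in the per-chunk loop of LMTrainer.lm_process_file_birnn. The read_chunk helper is a hypothetical stand-in for the real feed/reader logic; net_propagate, net_backpropagate, collectgarbage and flush_all are the calls visible in or added by this patch:

-- Minimal sketch, assuming a hypothetical read_chunk() that fills the TNN's
-- input feeds for one chunk and returns false when the file is exhausted.
while true do
    local r = read_chunk(tnn)        -- hypothetical helper: load the next chunk, false at EOF
    if r == false then
        break
    end
    tnn:net_propagate()              -- forward pass over every timestep of the chunk
    if do_train == true then
        tnn:net_backpropagate(true)  -- backward pass; do_update as in TNN:net_backpropagate(do_update)
    end
    collectgarbage("collect")
    tnn:flush_all()                  -- added by this commit: reset stored activations/errors
end                                  -- so the bidirectional net starts the next chunk clean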
diff --git a/nerv/tnn/tnn.lua b/nerv/tnn/tnn.lua
index cf02123..bcfeb40 100644
--- a/nerv/tnn/tnn.lua
+++ b/nerv/tnn/tnn.lua
@@ -64,7 +64,7 @@ function TNN.make_initial_store(st, p, dim, batch_size, chunk_size, extend_t, gl
if (type(st) ~= "table") then
nerv.error("st should be a table")
end
- for i = 1 - extend_t - 1, chunk_size + extend_t + 1 do --intentionally allocated more time
+ for i = 1 - extend_t - 2, chunk_size + extend_t + 2 do --intentionally allocated more time
if (st[i] == nil) then
st[i] = {}
end
@@ -339,6 +339,11 @@ function TNN:net_propagate() --propagate according to feeds_now
end
local feeds_now = self.feeds_now
+ for t = 1, self.chunk_size do --some layers may not have inputs at every t in 1..chunk_size
+ for id, ref in pairs(self.layers) do
+ self:propagate_dfs(ref, t)
+ end
+ end
for t = 1, self.chunk_size do
if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then
for i = 1, #self.dim_in do
@@ -362,6 +367,7 @@ function TNN:net_propagate() --propagate according to feeds_now
end
end
end
+
if (flag_out == false) then
nerv.error("some thing wrong, some labeled output is not propagated")
end
@@ -458,6 +464,11 @@ function TNN:net_backpropagate(do_update) --propagate according to feeds_now
end
local feeds_now = self.feeds_now
+ for t = 1, self.chunk_size do --some layers may not have outputs at every t in 1..chunk_size
+ for id, ref in pairs(self.layers) do
+ self:backpropagate_dfs(ref, t)
+ end
+ end
for t = 1, self.chunk_size do
if bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0 then
for i = 1, #self.dim_out do