aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortxh18 <cloudygooseg@gmail.com>2015-12-10 13:28:13 +0800
committertxh18 <cloudygooseg@gmail.com>2015-12-10 13:28:13 +0800
commit91075c34160fa24e484148b26c1178e05c2212a4 (patch)
tree9bfab2962f6f8b6c28b41c56793fec3e48d94412
parent62169f73b935dd6df8fe0c5628beed58820d186e (diff)
bug fix for recent changes in tnn
-rw-r--r--nerv/examples/lmptb/lmptb/layer/select_linear.lua2
-rw-r--r--nerv/tnn/tnn.lua5
2 files changed, 5 insertions, 2 deletions
diff --git a/nerv/examples/lmptb/lmptb/layer/select_linear.lua b/nerv/examples/lmptb/lmptb/layer/select_linear.lua
index 580b9c5..431ef3a 100644
--- a/nerv/examples/lmptb/lmptb/layer/select_linear.lua
+++ b/nerv/examples/lmptb/lmptb/layer/select_linear.lua
@@ -30,7 +30,7 @@ function SL:init(batch_size)
end
function SL:update(bp_err, input, output)
- --use this to produce reproducable result
+ --use this to produce reproducable result, don't forget to set the dropout to zero!
--for i = 1, input[1]:nrow(), 1 do
-- local word_vec = self.ltp.trans[input[1][i - 1][0]]
-- word_vec:add(word_vec, bp_err[1][i - 1], 1, - self.gconf.lrate / self.gconf.batch_size)
diff --git a/nerv/tnn/tnn.lua b/nerv/tnn/tnn.lua
index bcfeb40..7ae3172 100644
--- a/nerv/tnn/tnn.lua
+++ b/nerv/tnn/tnn.lua
@@ -466,7 +466,7 @@ function TNN:net_backpropagate(do_update) --propagate according to feeds_now
local feeds_now = self.feeds_now
for t = 1, self.chunk_size do --some layer maybe do not have outputs from time 1..chunk_size
for id, ref in pairs(self.layers) do
- self:backpropagate_dfs(ref, t)
+ self:backpropagate_dfs(ref, t, do_update)
end
end
for t = 1, self.chunk_size do
@@ -500,6 +500,9 @@ end
--ref: the TNN_ref of a layer
--t: the current time to propagate
function TNN:backpropagate_dfs(ref, t, do_update)
+ if do_update == nil then
+ nerv.error("got a nil do_update")
+ end
if self:out_of_feedrange(t) then
return
end