diff options
author | txh18 <cloudygooseg@gmail.com> | 2015-11-07 13:50:33 +0800 |
---|---|---|
committer | txh18 <cloudygooseg@gmail.com> | 2015-11-07 13:50:33 +0800 |
commit | 7a46eeb6ce9189b3f7baa166db92234d85a2e828 (patch) | |
tree | 1abbee6a4251c7e933234bcb7d7b194f0a07f686 /nerv/examples/lmptb/rnn/tnn.lua | |
parent | 5a5a84173c2caee7e6a528f2312057b9acee8216 (diff) |
ready to test on ptb
Diffstat (limited to 'nerv/examples/lmptb/rnn/tnn.lua')
-rw-r--r-- | nerv/examples/lmptb/rnn/tnn.lua | 32 |
1 file changed, 32 insertions(+), 0 deletions(-)
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua index 019d24c..ae9ed7a 100644 --- a/nerv/examples/lmptb/rnn/tnn.lua +++ b/nerv/examples/lmptb/rnn/tnn.lua @@ -321,6 +321,22 @@ function TNN:net_propagate() --propagate according to feeds_now end end end + + local flag_out = true + for t = 1, self.chunk_size do --check whether every output has been computed + if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0) then + for i = 1, #self.dim_out do + local ref = self.outputs_p[i].ref + if (ref.outputs_b[t][1] ~= true) then + flag_out = false + break + end + end + end + end + if (flag_out == false) then + nerv.error("some thing wrong, some labeled output is not propagated") + end end --ref: the TNN_ref of a layer @@ -421,6 +437,22 @@ function TNN:net_backpropagate(do_update) --propagate according to feeds_now end end end + + local flag_out = true + for t = 1, self.chunk_size do --check whether every output has been computed + if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then + for i = 1, #self.dim_in do + local ref = self.inputs_p[i].ref + if (ref.err_outputs_b[t][1] ~= true) then + flag_out = false + break + end + end + end + end + if (flag_out == false) then + nerv.error("some thing wrong, some input is not back_propagated") + end end --ref: the TNN_ref of a layer |