diff options
Diffstat (limited to 'nerv/examples/lmptb/rnn')
-rw-r--r-- | nerv/examples/lmptb/rnn/tnn.lua | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua index 019d24c..ae9ed7a 100644 --- a/nerv/examples/lmptb/rnn/tnn.lua +++ b/nerv/examples/lmptb/rnn/tnn.lua @@ -321,6 +321,22 @@ function TNN:net_propagate() --propagate according to feeds_now end end end + + local flag_out = true + for t = 1, self.chunk_size do --check whether every output has been computed + if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0) then + for i = 1, #self.dim_out do + local ref = self.outputs_p[i].ref + if (ref.outputs_b[t][1] ~= true) then + flag_out = false + break + end + end + end + end + if (flag_out == false) then + nerv.error("some thing wrong, some labeled output is not propagated") + end end --ref: the TNN_ref of a layer @@ -421,6 +437,22 @@ function TNN:net_backpropagate(do_update) --propagate according to feeds_now end end end + + local flag_out = true + for t = 1, self.chunk_size do --check whether every output has been computed + if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then + for i = 1, #self.dim_in do + local ref = self.inputs_p[i].ref + if (ref.err_outputs_b[t][1] ~= true) then + flag_out = false + break + end + end + end + end + if (flag_out == false) then + nerv.error("some thing wrong, some input is not back_propagated") + end end --ref: the TNN_ref of a layer |