about summary refs log tree commit diff
path: root/nerv/examples/lmptb/rnn/tnn.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/rnn/tnn.lua')
-rw-r--r--  nerv/examples/lmptb/rnn/tnn.lua  32
1 file changed, 32 insertions, 0 deletions
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua
index 019d24c..ae9ed7a 100644
--- a/nerv/examples/lmptb/rnn/tnn.lua
+++ b/nerv/examples/lmptb/rnn/tnn.lua
@@ -321,6 +321,22 @@ function TNN:net_propagate() --propagate according to feeds_now
end
end
end
+
+ local flag_out = true
+ for t = 1, self.chunk_size do --check whether every output has been computed
+ if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0) then
+ for i = 1, #self.dim_out do
+ local ref = self.outputs_p[i].ref
+ if (ref.outputs_b[t][1] ~= true) then
+ flag_out = false
+ break
+ end
+ end
+ end
+ end
+ if (flag_out == false) then
+ nerv.error("some thing wrong, some labeled output is not propagated")
+ end
end
--ref: the TNN_ref of a layer
@@ -421,6 +437,22 @@ function TNN:net_backpropagate(do_update) --propagate according to feeds_now
end
end
end
+
+ local flag_out = true
+ for t = 1, self.chunk_size do --check whether every output has been computed
+ if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then
+ for i = 1, #self.dim_in do
+ local ref = self.inputs_p[i].ref
+ if (ref.err_outputs_b[t][1] ~= true) then
+ flag_out = false
+ break
+ end
+ end
+ end
+ end
+ if (flag_out == false) then
+ nerv.error("some thing wrong, some input is not back_propagated")
+ end
end
--ref: the TNN_ref of a layer