aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/rnn
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/rnn')
-rw-r--r--nerv/examples/lmptb/rnn/tnn.lua8
1 files changed, 7 insertions, 1 deletions
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua
index 9850fe5..d6bf42e 100644
--- a/nerv/examples/lmptb/rnn/tnn.lua
+++ b/nerv/examples/lmptb/rnn/tnn.lua
@@ -384,8 +384,10 @@ function TNN:propagate_dfs(ref, t)
end
end
]]--
+ self.gconf.timer:tic("tnn_actual_layer_propagate")
ref.layer:propagate(ref.inputs_m[t], ref.outputs_m[t], t) --propagate!
-
+ self.gconf.timer:toc("tnn_actual_layer_propagate")
+
if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --restore cross-border history
for i = 1, self.batch_size do
local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
@@ -487,10 +489,14 @@ function TNN:backpropagate_dfs(ref, t, do_update)
--ok, do back_propagate
--print("debug ok, back-propagating(or updating)")
if (do_update == false) then
+ self.gconf.timer:tic("tnn_actual_layer_backpropagate")
ref.layer:back_propagate(ref.err_inputs_m[t], ref.err_outputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t)
+ self.gconf.timer:toc("tnn_actual_layer_backpropagate")
else
--print(ref.err_inputs_m[t][1])
+ self.gconf.timer:tic("tnn_actual_layer_update")
ref.layer:update(ref.err_inputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t)
+ self.gconf.timer:toc("tnn_actual_layer_update")
end
if (do_update == false and bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --flush cross-border errors