about summary refs log tree commit diff
path: root/nerv/examples/lmptb/lm_trainer.lua
diff options
context:
space:
mode:
author: txh18 <cloudygooseg@gmail.com> 2015-11-17 13:20:43 +0800
committer: txh18 <cloudygooseg@gmail.com> 2015-11-17 13:20:43 +0800
commit: 317ff51cae8dcfaff26855c42ce99656b4d293b5 (patch)
tree: 67f779bcffa9bace5e080329ae427796e785302d /nerv/examples/lmptb/lm_trainer.lua
parent: bd563c1ebcd676059e0384532ab192d98b3eabf2 (diff)
added small opt: use mmatrix in lm_trainer and reader
Diffstat (limited to 'nerv/examples/lmptb/lm_trainer.lua')
-rw-r--r--  nerv/examples/lmptb/lm_trainer.lua  17
1 file changed, 12 insertions(+), 5 deletions(-)
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index 7dd70e2..62d8b50 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -15,17 +15,18 @@ function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
reader:open_file(fn)
local result = nerv.LMResult(global_conf, global_conf.vocab)
result:init("rnn")
-
+
global_conf.timer:flush()
tnn:flush_all() --caution: will also flush the inputs from the reader!
local next_log_wcn = global_conf.log_w_num
+ local neto_bakm = global_conf.mmat_type(global_conf.batch_size, 1) --space backup matrix for network output
while (1) do
global_conf.timer:tic('most_out_loop_lmprocessfile')
local r, feeds
-
+ global_conf.timer:tic('tnn_beforeprocess')
r, feeds = tnn:getfeed_from_reader(reader)
if r == false then
break
@@ -39,6 +40,7 @@ function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
end
end
end
+ global_conf.timer:toc('tnn_beforeprocess')
--[[
for j = 1, global_conf.chunk_size, 1 do
@@ -56,15 +58,20 @@ function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
tnn:net_backpropagate(false)
tnn:net_backpropagate(true)
end
-
+
+ global_conf.timer:tic('tnn_afterprocess')
for t = 1, global_conf.chunk_size, 1 do
+ tnn.outputs_m[t][1]:copy_toh(neto_bakm)
for i = 1, global_conf.batch_size, 1 do
if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then
- result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
+ --result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
+ result:add("rnn", feeds.labels_s[t][i], math.exp(neto_bakm[i - 1][0]))
end
end
end
- tnn:move_right_to_nextmb()
+ tnn:move_right_to_nextmb({0}) --only copy for time 0
+ global_conf.timer:toc('tnn_afterprocess')
+
global_conf.timer:toc('most_out_loop_lmprocessfile')
--print log