diff options
author | txh18 <cloudygooseg@gmail.com> | 2015-11-17 13:20:43 +0800 |
---|---|---|
committer | txh18 <cloudygooseg@gmail.com> | 2015-11-17 13:20:43 +0800 |
commit | 317ff51cae8dcfaff26855c42ce99656b4d293b5 (patch) | |
tree | 67f779bcffa9bace5e080329ae427796e785302d /nerv/examples/lmptb/lm_trainer.lua | |
parent | bd563c1ebcd676059e0384532ab192d98b3eabf2 (diff) |
added small opt: use mmatrix in lm_trainer and reader
Diffstat (limited to 'nerv/examples/lmptb/lm_trainer.lua')
-rw-r--r-- | nerv/examples/lmptb/lm_trainer.lua | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua index 7dd70e2..62d8b50 100644 --- a/nerv/examples/lmptb/lm_trainer.lua +++ b/nerv/examples/lmptb/lm_trainer.lua @@ -15,17 +15,18 @@ function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train) reader:open_file(fn) local result = nerv.LMResult(global_conf, global_conf.vocab) result:init("rnn") - + global_conf.timer:flush() tnn:flush_all() --caution: will also flush the inputs from the reader! local next_log_wcn = global_conf.log_w_num + local neto_bakm = global_conf.mmat_type(global_conf.batch_size, 1) --space backup matrix for network output while (1) do global_conf.timer:tic('most_out_loop_lmprocessfile') local r, feeds - + global_conf.timer:tic('tnn_beforeprocess') r, feeds = tnn:getfeed_from_reader(reader) if r == false then break @@ -39,6 +40,7 @@ function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train) end end end + global_conf.timer:toc('tnn_beforeprocess') --[[ for j = 1, global_conf.chunk_size, 1 do @@ -56,15 +58,20 @@ function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train) tnn:net_backpropagate(false) tnn:net_backpropagate(true) end - + + global_conf.timer:tic('tnn_afterprocess') for t = 1, global_conf.chunk_size, 1 do + tnn.outputs_m[t][1]:copy_toh(neto_bakm) for i = 1, global_conf.batch_size, 1 do if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then - result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0])) + --result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0])) + result:add("rnn", feeds.labels_s[t][i], math.exp(neto_bakm[i - 1][0])) end end end - tnn:move_right_to_nextmb() + tnn:move_right_to_nextmb({0}) --only copy for time 0 + global_conf.timer:toc('tnn_afterprocess') + global_conf.timer:toc('most_out_loop_lmprocessfile') --print log |