From 7ee8988f21075246106a4d990190d0ef25fa82a8 Mon Sep 17 00:00:00 2001 From: txh18 Date: Fri, 4 Dec 2015 21:03:06 +0800 Subject: added one_sen_report to lm_process_file --- nerv/examples/lmptb/lm_trainer.lua | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua index 58d5bfc..3c7078e 100644 --- a/nerv/examples/lmptb/lm_trainer.lua +++ b/nerv/examples/lmptb/lm_trainer.lua @@ -22,15 +22,23 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf) p_conf = {} end local reader + local chunk_size, batch_size if p_conf.one_sen_report == true then --report log prob one by one sentence if do_train == true then nerv.warning("LMTrainer.lm_process_file_rnn: warning, one_sen_report is true while do_train is also true, strange") end - reader = nerv.LMSeqReader(global_conf, 1, global_conf.max_sen_len, global_conf.vocab) + nerv.printf("lm_process_file_rnn: one_sen report mode, set batch_size to 1 and chunk_size to max_sen_len(%d)\n", + global_conf.max_sen_len) + batch_size = 1 + chunk_size = global_conf.max_sen_len else - reader = nerv.LMSeqReader(global_conf, global_conf.batch_size, global_conf.chunk_size, global_conf.vocab) + batch_size = global_conf.batch_size + chunk_size = global_conf.chunk_size end + + reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, global_conf.vocab) reader:open_file(fn) + local result = nerv.LMResult(global_conf, global_conf.vocab) result:init("rnn") if global_conf.dropout_rate ~= nil then @@ -41,7 +49,7 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf) tnn:flush_all() --caution: will also flush the inputs from the reader! local next_log_wcn = global_conf.log_w_num - local neto_bakm = global_conf.mmat_type(global_conf.batch_size, 1) --space backup matrix for network output + local neto_bakm = global_conf.mmat_type(batch_size, 1) --space backup matrix for network output while (1) do global_conf.timer:tic('most_out_loop_lmprocessfile') @@ -53,9 +61,9 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf) break end - for t = 1, global_conf.chunk_size do + for t = 1, chunk_size do tnn.err_inputs_m[t][1]:fill(1) - for i = 1, global_conf.batch_size do + for i = 1, batch_size do if bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0 then tnn.err_inputs_m[t][1][i - 1][0] = 0 end @@ -81,15 +89,26 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf) end global_conf.timer:tic('tnn_afterprocess') - for t = 1, global_conf.chunk_size, 1 do + local sen_logp = {} + for t = 1, chunk_size, 1 do tnn.outputs_m[t][1]:copy_toh(neto_bakm) - for i = 1, global_conf.batch_size, 1 do + for i = 1, batch_size, 1 do if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then --result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0])) result:add("rnn", feeds.labels_s[t][i], math.exp(neto_bakm[i - 1][0])) + if sen_logp[i] == nil then + sen_logp[i] = 0 + end + sen_logp[i] = sen_logp[i] + neto_bakm[i - 1][0] end end end + if p_conf.one_sen_report == true then + for i = 1, batch_size do + nerv.printf("LMTrainer.lm_process_file_rnn: one_sen_report, %f\n", sen_logp[i]) + end + end + tnn:move_right_to_nextmb({0}) --only copy for time 0 global_conf.timer:toc('tnn_afterprocess') @@ -113,7 +132,6 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf) end ]]-- - collectgarbage("collect") --break --debug -- cgit v1.2.3-70-g09d2