diff options
author | txh18 <cloudygooseg@gmail.com> | 2015-12-06 13:33:26 +0800 |
---|---|---|
committer | txh18 <cloudygooseg@gmail.com> | 2015-12-06 13:33:26 +0800 |
commit | 79c711d9c92a8e92f7ad9187a66d3e2aac239356 (patch) | |
tree | 4310e2dcd54f735780a92d43d7c485ca37abfdad /nerv/examples/lmptb/lm_trainer.lua | |
parent | ea0e37892ae70357305da3b1fbae617215a25778 (diff) |
small bug fix in lm training script
Diffstat (limited to 'nerv/examples/lmptb/lm_trainer.lua')
-rw-r--r-- | nerv/examples/lmptb/lm_trainer.lua | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua index 6bd06bb..2cdbd4f 100644 --- a/nerv/examples/lmptb/lm_trainer.lua +++ b/nerv/examples/lmptb/lm_trainer.lua @@ -22,7 +22,7 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf) p_conf = {} end local reader - local r_conf + local r_conf = {} local chunk_size, batch_size if p_conf.one_sen_report == true then --report log prob one by one sentence if do_train == true then @@ -48,6 +48,7 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf) end global_conf.timer:flush() + tnn:init(batch_size, chunk_size) tnn:flush_all() --caution: will also flush the inputs from the reader! local next_log_wcn = global_conf.log_w_num @@ -107,7 +108,7 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf) end if p_conf.one_sen_report == true then for i = 1, batch_size do - nerv.printf("LMTrainer.lm_process_file_rnn: one_sen_report, %f\n", sen_logp[i]) + nerv.printf("LMTrainer.lm_process_file_rnn: one_sen_report_output, %f\n", sen_logp[i]) end end @@ -177,6 +178,7 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf) end global_conf.timer:flush() + tnn:init(batch_size, chunk_size) tnn:flush_all() --caution: will also flush the inputs from the reader! local next_log_wcn = global_conf.log_w_num @@ -235,7 +237,7 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf) end if p_conf.one_sen_report == true then for i = 1, batch_size do - nerv.printf("LMTrainer.lm_process_file_birnn: one_sen_report, %f\n", sen_logp[i]) + nerv.printf("LMTrainer.lm_process_file_birnn: one_sen_report_output, %f\n", sen_logp[i]) end end |