aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lm_trainer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/lm_trainer.lua')
-rw-r--r--nerv/examples/lmptb/lm_trainer.lua8
1 files changed, 5 insertions, 3 deletions
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index 6bd06bb..2cdbd4f 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -22,7 +22,7 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
p_conf = {}
end
local reader
- local r_conf
+ local r_conf = {}
local chunk_size, batch_size
if p_conf.one_sen_report == true then --report log prob one by one sentence
if do_train == true then
@@ -48,6 +48,7 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
end
global_conf.timer:flush()
+ tnn:init(batch_size, chunk_size)
tnn:flush_all() --caution: will also flush the inputs from the reader!
local next_log_wcn = global_conf.log_w_num
@@ -107,7 +108,7 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
end
if p_conf.one_sen_report == true then
for i = 1, batch_size do
- nerv.printf("LMTrainer.lm_process_file_rnn: one_sen_report, %f\n", sen_logp[i])
+ nerv.printf("LMTrainer.lm_process_file_rnn: one_sen_report_output, %f\n", sen_logp[i])
end
end
@@ -177,6 +178,7 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf)
end
global_conf.timer:flush()
+ tnn:init(batch_size, chunk_size)
tnn:flush_all() --caution: will also flush the inputs from the reader!
local next_log_wcn = global_conf.log_w_num
@@ -235,7 +237,7 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf)
end
if p_conf.one_sen_report == true then
for i = 1, batch_size do
- nerv.printf("LMTrainer.lm_process_file_birnn: one_sen_report, %f\n", sen_logp[i])
+ nerv.printf("LMTrainer.lm_process_file_birnn: one_sen_report_output, %f\n", sen_logp[i])
end
end