Diffstat (limited to 'nerv/examples/lmptb/lm_trainer.lua')
-rw-r--r--  nerv/examples/lmptb/lm_trainer.lua | 12
1 file changed, 10 insertions(+), 2 deletions(-)
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index eab6e2d..06c1a4c 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -23,6 +23,9 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
     end
     local reader
     local r_conf = {}
+    if p_conf.compressed_label ~= nil then
+        r_conf.compressed_label = p_conf.compressed_label
+    end
     local chunk_size, batch_size
     if p_conf.one_sen_report == true then --report log prob one by one sentence
         if do_train == true then
@@ -156,13 +159,16 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf)
     local reader
     local chunk_size, batch_size
     local r_conf = {["se_mode"] = true}
+    if p_conf.compressed_label ~= nil then
+        r_conf.compressed_label = p_conf.compressed_label
+    end
     if p_conf.one_sen_report == true then --report log prob one by one sentence
         if do_train == true then
             nerv.warning("LMTrainer.lm_process_file_birnn: warning, one_sen_report is true while do_train is also true, strange")
         end
nerv.printf("lm_process_file_birnn: one_sen report mode, set batch_size to 1 and chunk_size to max_sen_len(%d)\n",
global_conf.max_sen_len)
-        batch_size = 1
+        batch_size = global_conf.batch_size
         chunk_size = global_conf.max_sen_len
     else
         batch_size = global_conf.batch_size
@@ -239,7 +245,9 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf)
         end
         if p_conf.one_sen_report == true then
             for i = 1, batch_size do
-                nerv.printf("LMTrainer.lm_process_file_birnn: one_sen_report_output, %f\n", sen_logp[i])
+                if sen_logp[i] ~= nil then
+                    nerv.printf("LMTrainer.lm_process_file_birnn: one_sen_report_output, %f\n", sen_logp[i])
+                end
             end
         end
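
The first two hunks let a caller opt in to compressed labels with a single flag in `p_conf`: both `lm_process_file_rnn` and `lm_process_file_birnn` now copy `compressed_label` into the reader configuration `r_conf`. A minimal caller-side sketch, assuming the usual lmptb training-script surroundings (`LMTrainer`, `global_conf`, `tnn`, and a hypothetical `global_conf.test_fn` naming the file to process are all assumed to exist there):

-- Sketch of a caller inside an lmptb training script.
local p_conf = {
    compressed_label = true,  -- new optional flag; forwarded into r_conf above
    one_sen_report = false,
}
-- do_train = false: process the file for evaluation only.
LMTrainer.lm_process_file_birnn(global_conf, global_conf.test_fn, tnn, false, p_conf)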
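
The third hunk follows from the `batch_size` change above it: with several sentences per batch in one-sentence-report mode, not every slot of `sen_logp` is guaranteed to hold a finished sentence's log probability, so the table can have holes. In Lua, formatting `nil` with `%f` raises an error, hence the guard. A standalone illustration of the failure mode (the numbers are made up):

-- sen_logp with a hole at index 2, as can happen when a batch slot
-- carried no sentence.
local sen_logp = { [1] = -45.2, [3] = -12.7 }
local batch_size = 3

for i = 1, batch_size do
    if sen_logp[i] ~= nil then
        print(string.format("one_sen_report_output, %f", sen_logp[i]))
    end
    -- Without the check, string.format("%f", nil) would raise
    -- "bad argument #2 to 'format' (number expected, got nil)".
end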