author     txh18 <[email protected]>    2015-12-06 13:33:26 +0800
committer  txh18 <[email protected]>    2015-12-06 13:33:26 +0800
commit     79c711d9c92a8e92f7ad9187a66d3e2aac239356
tree       4310e2dcd54f735780a92d43d7c485ca37abfdad
parent     ea0e37892ae70357305da3b1fbae617215a25778
small bug fix in lm training script
-rw-r--r--    nerv/examples/lmptb/lm_trainer.lua         8
-rw-r--r--    nerv/examples/lmptb/lstmlm_ptb_main.lua   10
2 files changed, 10 insertions, 8 deletions
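
Note: in lm_trainer.lua this commit initializes the reader configuration table r_conf before fields are assigned into it, and calls tnn:init(batch_size, chunk_size) before the state flush. A minimal sketch (not from the repo; the field name is illustrative) of why the r_conf initialization matters in Lua:

    -- Sketch only: indexing into an uninitialized local fails in Lua.
    local r_conf                      -- nil
    -- r_conf.some_option = true      -- would raise "attempt to index a nil value"

    local r_conf_fixed = {}           -- the fix: start from an empty table
    r_conf_fixed.some_option = true   -- field assignment now succeeds
    print(r_conf_fixed.some_option)   --> true
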
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index 6bd06bb..2cdbd4f 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -22,7 +22,7 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
p_conf = {}
end
local reader
- local r_conf
+ local r_conf = {}
local chunk_size, batch_size
if p_conf.one_sen_report == true then --report log prob one by one sentence
if do_train == true then
@@ -48,6 +48,7 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
end
global_conf.timer:flush()
+ tnn:init(batch_size, chunk_size)
tnn:flush_all() --caution: will also flush the inputs from the reader!
local next_log_wcn = global_conf.log_w_num
@@ -107,7 +108,7 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
end
if p_conf.one_sen_report == true then
for i = 1, batch_size do
- nerv.printf("LMTrainer.lm_process_file_rnn: one_sen_report, %f\n", sen_logp[i])
+ nerv.printf("LMTrainer.lm_process_file_rnn: one_sen_report_output, %f\n", sen_logp[i])
end
end
@@ -177,6 +178,7 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf)
end
global_conf.timer:flush()
+ tnn:init(batch_size, chunk_size)
tnn:flush_all() --caution: will also flush the inputs from the reader!
local next_log_wcn = global_conf.log_w_num
@@ -235,7 +237,7 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf)
end
if p_conf.one_sen_report == true then
for i = 1, batch_size do
- nerv.printf("LMTrainer.lm_process_file_birnn: one_sen_report, %f\n", sen_logp[i])
+ nerv.printf("LMTrainer.lm_process_file_birnn: one_sen_report_output, %f\n", sen_logp[i])
end
end
diff --git a/nerv/examples/lmptb/lstmlm_ptb_main.lua b/nerv/examples/lmptb/lstmlm_ptb_main.lua
index 2751ec8..804438d 100644
--- a/nerv/examples/lmptb/lstmlm_ptb_main.lua
+++ b/nerv/examples/lmptb/lstmlm_ptb_main.lua
@@ -337,9 +337,8 @@ ppl_last = 100000
commands_str = "train:test"
commands = {}
test_iter = -1
-
--for testout(question)
-local q_file = "ptb.test.txt.q10rs1_Msss.adds"
+q_file = "ptb.test.txt.q10rs1_Msss.adds"
if arg[2] ~= nil then
nerv.printf("%s applying arg[2](%s)...\n", global_conf.sche_log_pre, arg[2])
@@ -485,8 +484,8 @@ end --if commands["test"]
if commands["testout"] == 1 then
nerv.printf("===TEST OUT===\n")
nerv.printf("q_file:\t%s\n", q_file)
- local q_fn = qdata_dir .. q_file
- global_conf.sche_log_pre = "[SCHEDULER FINAL_TEST]:"
+ local q_fn = qdata_dir .. '/' .. q_file
+ global_conf.sche_log_pre = "[SCHEDULER TESTOUT]:"
if final_iter ~= -1 and test_iter == -1 then
test_iter = final_iter
end
@@ -495,7 +494,8 @@ if commands["testout"] == 1 then
end
tnn = load_net(global_conf, test_iter)
global_conf.dropout_rate = 0
- LMTrainer.lm_process_file_rnn(global_conf, q_fn, tnn, false) --false update!
+ LMTrainer.lm_process_file_rnn(global_conf, q_fn, tnn, false,
+ {["one_sen_report"] = true}) --false update!
end --if commands["testout"]