Diffstat (limited to 'nerv/examples/lmptb/lm_trainer.lua')
-rw-r--r-- nerv/examples/lmptb/lm_trainer.lua | 44 ++++++++++++++++++++++++++++++----------
1 file changed, 34 insertions(+), 10 deletions(-)
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index 3b8b5c3..8a9744b 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -23,14 +23,16 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
end
local reader
local r_conf = {}
+ if p_conf.compressed_label ~= nil then
+ r_conf.compressed_label = p_conf.compressed_label
+ end
local chunk_size, batch_size
- if p_conf.one_sen_report == true then --report log prob one by one sentence
+ if p_conf.one_sen_report == true or p_conf.word_prob_report == true then --report log prob one by one sentence
if do_train == true then
- nerv.warning("LMTrainer.lm_process_file_rnn: warning, one_sen_report is true while do_train is also true, strange")
+ nerv.error("LMTrainer.lm_process_file_rnn: one_sen_report (or word_prob_report) is true while do_train is also true, strange")
end
- nerv.printf("lm_process_file_rnn: one_sen report mode, set batch_size to 1 and chunk_size to max_sen_len(%d)\n",
- global_conf.max_sen_len)
- batch_size = 1
+ nerv.printf("lm_process_file_rnn: one_sen(or word_prob) report mode, set chunk_size to max_sen_len(%d)\n", global_conf.max_sen_len)
+ batch_size = global_conf.batch_size
chunk_size = global_conf.max_sen_len
r_conf["se_mode"] = true
else
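For reference, a hypothetical call site exercising the flags handled above; compressed_label, one_sen_report and word_prob_report are the p_conf fields from this patch, while the file name and config values are purely illustrative:

    -- Evaluation-only invocation: the patched code now raises nerv.error
    -- when a report mode is combined with do_train == true.
    local p_conf = {
        word_prob_report = true, -- print one probability line per word
        compressed_label = true, -- forwarded into the reader config r_conf
    }
    LMTrainer.lm_process_file_rnn(global_conf, "ptb.test.txt", tnn, false, p_conf)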
@@ -94,12 +96,15 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
global_conf.timer:tic('tnn_afterprocess')
local sen_logp = {}
+ local word_prob = {}
for t = 1, chunk_size, 1 do
+ word_prob[t] = {}
tnn.outputs_m[t][1]:copy_toh(neto_bakm)
for i = 1, batch_size, 1 do
if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then
--result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
result:add("rnn", feeds.labels_s[t][i], math.exp(neto_bakm[i - 1][0]))
+ word_prob[t][i] = math.exp(neto_bakm[i - 1][0])
if sen_logp[i] == nil then
sen_logp[i] = 0
end
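A note on the new word_prob table: the network outputs are per-word log-probabilities, copied from device to host via copy_toh into neto_bakm before being read, and nerv matrices are 0-indexed, hence neto_bakm[i - 1][0] for the 1-based loop variable i. As a sanity check, the sentence score accumulated into sen_logp can be recovered from word_prob alone; a minimal sketch (the helper name is made up):

    -- Illustrative only: reconstruct one stream's sentence log-probability
    -- from the per-word probabilities collected in the loop above.
    local function sentence_logp(word_prob, i, chunk_size)
        local logp = 0
        for t = 1, chunk_size do
            if word_prob[t][i] ~= nil then -- nil where the label was null_token
                logp = logp + math.log(word_prob[t][i])
            end
        end
        return logp
    end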
@@ -107,9 +112,22 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
end
end
end
+
+ if p_conf.word_prob_report == true then
+ for i = 1, batch_size do
+ for t = 1, chunk_size do
+ if feeds.labels_s[t][i] ~= global_conf.vocab.null_token then
+ nerv.printf("LMTrainer.lm_process_file_rnn: word_prob_report_output, %.8f %s\n", word_prob[t][i], feeds.labels_s[t][i])
+ end
+ end
+ end
+ end
+
if p_conf.one_sen_report == true then
for i = 1, batch_size do
- nerv.printf("LMTrainer.lm_process_file_rnn: one_sen_report_output, %f\n", sen_logp[i])
+ if feeds.labels_s[1][i] ~= global_conf.vocab.null_token then
+ nerv.printf("LMTrainer.lm_process_file_rnn: one_sen_report_output, %f\n", sen_logp[i])
+ end
end
end
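With batch_size no longer forced to 1, the one_sen_report branch above checks feeds.labels_s[1][i] against null_token so that streams carrying no sentence in this chunk are skipped. The word_prob_report branch prints one "%.8f %s" line per word; a hypothetical consumer that parses those lines back out of a captured log (the file name and pattern are illustrative, the format string comes from the patch):

    -- Illustrative only: recover (word, probability) pairs from a saved log.
    for line in io.lines("eval.log") do
        local prob, word = line:match("word_prob_report_output, ([%d%.]+) (.+)$")
        if prob ~= nil then
            print(word, tonumber(prob))
        end
    end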
@@ -156,13 +174,16 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf)
local reader
local chunk_size, batch_size
local r_conf = {["se_mode"] = true}
+ if p_conf.compressed_label ~= nil then
+ r_conf.compressed_label = p_conf.compressed_label
+ end
if p_conf.one_sen_report == true then --report log prob one by one sentence
if do_train == true then
nerv.warning("LMTrainer.lm_process_file_birnn: warning, one_sen_report is true while do_train is also true, strange")
end
- nerv.printf("lm_process_file_birnn: one_sen report mode, set batch_size to 1 and chunk_size to max_sen_len(%d)\n",
+ nerv.printf("lm_process_file_birnn: one_sen report mode, set chunk_size to max_sen_len(%d)\n",
global_conf.max_sen_len)
- batch_size = 1
+ batch_size = global_conf.batch_size
chunk_size = global_conf.max_sen_len
else
batch_size = global_conf.batch_size
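As in the rnn path, report mode now keeps the full global_conf.batch_size, and compressed_label is forwarded to the reader. A sketch of the reader config this produces for the birnn path (field names from this diff; assigning nil to a table key is a no-op in Lua, which is why the ~= nil guard above is equivalent to this literal):

    -- r_conf as effectively built above for the bidirectional model:
    local r_conf = {
        se_mode = true, -- sentence-ended mode is always on for birnn
        compressed_label = p_conf.compressed_label, -- nil unless the caller set it
    }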
@@ -196,7 +217,6 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf)
if r == false then
break
end
-
for t = 1, chunk_size do
tnn.err_inputs_m[t][1]:fill(1)
for i = 1, batch_size do
@@ -240,13 +260,17 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf)
end
if p_conf.one_sen_report == true then
for i = 1, batch_size do
- nerv.printf("LMTrainer.lm_process_file_birnn: one_sen_report_output, %f\n", sen_logp[i])
+ if sen_logp[i] ~= nil then
+ nerv.printf("LMTrainer.lm_process_file_birnn: one_sen_report_output, %f\n", sen_logp[i])
+ end
end
end
--tnn:move_right_to_nextmb({0}) --do not need history for bi directional model
global_conf.timer:toc('tnn_afterprocess')
+ --tnn:flush_all() --you need this for bilstmlm_ptb_v2, because it has connections across 2 time steps
+
global_conf.timer:toc('most_out_loop_lmprocessfile')
--print log
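On the commented-out tnn:flush_all() above: for a model such as bilstmlm_ptb_v2, whose connections span two time steps, activations retained across mini-batches must be cleared so consecutive chunks stay independent. A sketch of how that could be wired up (the config flag is hypothetical; only the tnn:flush_all() call comes from this diff):

    -- After each mini-batch, for models with cross-chunk connections:
    if global_conf.flush_each_mb then -- hypothetical switch
        tnn:flush_all() -- drop state carried over from the previous chunk
    end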