aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortxh18 <[email protected]>2015-12-04 21:03:06 +0800
committertxh18 <[email protected]>2015-12-04 21:03:06 +0800
commit7ee8988f21075246106a4d990190d0ef25fa82a8 (patch)
treeaa4898a406ebbb4249247e04d83b92b2d67ea4bf
parent6b674ff54cf61dbf583032455ed2c4c33dd0443c (diff)
added one_sen_report to lm_process_file
-rw-r--r--nerv/examples/lmptb/lm_trainer.lua34
1 files changed, 26 insertions, 8 deletions
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index 58d5bfc..3c7078e 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -22,15 +22,23 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
p_conf = {}
end
local reader
+ local chunk_size, batch_size
if p_conf.one_sen_report == true then --report log prob one by one sentence
if do_train == true then
nerv.warning("LMTrainer.lm_process_file_rnn: warning, one_sen_report is true while do_train is also true, strange")
end
- reader = nerv.LMSeqReader(global_conf, 1, global_conf.max_sen_len, global_conf.vocab)
+ nerv.printf("lm_process_file_rnn: one_sen report mode, set batch_size to 1 and chunk_size to max_sen_len(%d)\n",
+ global_conf.max_sen_len)
+ batch_size = 1
+ chunk_size = global_conf.max_sen_len
else
- reader = nerv.LMSeqReader(global_conf, global_conf.batch_size, global_conf.chunk_size, global_conf.vocab)
+ batch_size = global_conf.batch_size
+ chunk_size = global_conf.chunk_size
end
+
+ reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, global_conf.vocab)
reader:open_file(fn)
+
local result = nerv.LMResult(global_conf, global_conf.vocab)
result:init("rnn")
if global_conf.dropout_rate ~= nil then
@@ -41,7 +49,7 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
tnn:flush_all() --caution: will also flush the inputs from the reader!
local next_log_wcn = global_conf.log_w_num
- local neto_bakm = global_conf.mmat_type(global_conf.batch_size, 1) --space backup matrix for network output
+ local neto_bakm = global_conf.mmat_type(batch_size, 1) --space backup matrix for network output
while (1) do
global_conf.timer:tic('most_out_loop_lmprocessfile')
@@ -53,9 +61,9 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
break
end
- for t = 1, global_conf.chunk_size do
+ for t = 1, chunk_size do
tnn.err_inputs_m[t][1]:fill(1)
- for i = 1, global_conf.batch_size do
+ for i = 1, batch_size do
if bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0 then
tnn.err_inputs_m[t][1][i - 1][0] = 0
end
@@ -81,15 +89,26 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
end
global_conf.timer:tic('tnn_afterprocess')
- for t = 1, global_conf.chunk_size, 1 do
+ local sen_logp = {}
+ for t = 1, chunk_size, 1 do
tnn.outputs_m[t][1]:copy_toh(neto_bakm)
- for i = 1, global_conf.batch_size, 1 do
+ for i = 1, batch_size, 1 do
if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then
--result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
result:add("rnn", feeds.labels_s[t][i], math.exp(neto_bakm[i - 1][0]))
+ if sen_logp[i] == nil then
+ sen_logp[i] = 0
+ end
+ sen_logp[i] = sen_logp[i] + neto_bakm[i - 1][0]
end
end
end
+ if p_conf.one_sen_report == true then
+ for i = 1, batch_size do
+ nerv.printf("LMTrainer.lm_process_file_rnn: one_sen_report, %f\n", sen_logp[i])
+ end
+ end
+
tnn:move_right_to_nextmb({0}) --only copy for time 0
global_conf.timer:toc('tnn_afterprocess')
@@ -113,7 +132,6 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
end
]]--
-
collectgarbage("collect")
--break --debug