author     txh18 <cloudygooseg@gmail.com>  2016-01-06 19:11:44 +0800
committer  txh18 <cloudygooseg@gmail.com>  2016-01-06 19:11:44 +0800
commit     38d7b02e9f5f5568ef6c7d1481eee190cb74eebc (patch)
tree       55431a4a0062d2d611616e246bb12e14993630be
parent     464cd37d218ac7dee90d44e721655a7c79ebe961 (diff)
added word_prob feature to lm_trainer
-rw-r--r--  nerv/examples/lmptb/lm_trainer.lua      | 21
-rw-r--r--  nerv/examples/lmptb/rnnlm_ptb_main.lua  | 15
2 files changed, 32 insertions(+), 4 deletions(-)
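In short, the patch teaches LMTrainer.lm_process_file_rnn a word_prob_report mode: besides the existing per-sentence report, it collects the probability of every non-null label into a table indexed by time step and batch stream, then prints one line per word. An illustrative sketch of that layout (the values are made up; in the patch each entry is math.exp of the log-probability read back from the network output):

    -- Illustrative only: word_prob[t][i] holds the model's probability
    -- for the label at time step t in batch stream i.
    local word_prob = {
        { 0.12, 0.07 },  -- t = 1: first word of streams 1 and 2
        { 0.03, 0.25 },  -- t = 2: second word of streams 1 and 2
    }
    -- Collection below runs t outside, i inside; the report swaps the
    -- loop order so each stream's sentence is printed contiguously:
    for i = 1, 2 do
        for t = 1, 2 do
            print(string.format("%.8f", word_prob[t][i]))
        end
    end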
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index 9dfefe5..8a9744b 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -27,12 +27,11 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
r_conf.compressed_label = p_conf.compressed_label
end
local chunk_size, batch_size
- if p_conf.one_sen_report == true then --report log prob one by one sentence
+ if p_conf.one_sen_report == true or p_conf.word_prob_report == true then --report log prob sentence by sentence
if do_train == true then
- nerv.warning("LMTrainer.lm_process_file_rnn: warning, one_sen_report is true while do_train is also true, strange")
+ nerv.error("LMTrainer.lm_process_file_rnn: warning, one_sen_report(or word_prob_report) is true while do_train is also true, strange")
end
- nerv.printf("lm_process_file_rnn: one_sen report mode, set chunk_size to max_sen_len(%d)\n",
- global_conf.max_sen_len)
+ nerv.printf("lm_process_file_rnn: one_sen(or word_prob) report mode, set chunk_size to max_sen_len(%d)\n", global_conf.max_sen_len)
batch_size = global_conf.batch_size
chunk_size = global_conf.max_sen_len
r_conf["se_mode"] = true
@@ -97,12 +96,15 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
global_conf.timer:tic('tnn_afterprocess')
local sen_logp = {}
+ local word_prob = {}
for t = 1, chunk_size, 1 do
+ word_prob[t] = {}
tnn.outputs_m[t][1]:copy_toh(neto_bakm)
for i = 1, batch_size, 1 do
if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then
--result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
result:add("rnn", feeds.labels_s[t][i], math.exp(neto_bakm[i - 1][0]))
+ word_prob[t][i] = math.exp(neto_bakm[i - 1][0])
if sen_logp[i] == nil then
sen_logp[i] = 0
end
@@ -110,6 +112,17 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
end
end
end
+
+ if p_conf.word_prob_report == true then
+ for i = 1, batch_size do
+ for t = 1, chunk_size do
+ if feeds.labels_s[t][i] ~= global_conf.vocab.null_token then
+ nerv.printf("LMTrainer.lm_process_file_rnn: word_prob_report_output, %.8f %s\n", word_prob[t][i], feeds.labels_s[t][i])
+ end
+ end
+ end
+ end
+
if p_conf.one_sen_report == true then
for i = 1, batch_size do
if feeds.labels_s[1][i] ~= global_conf.vocab.null_token then
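Each line the new report prints carries the fixed marker word_prob_report_output followed by the "%.8f %s" payload from the nerv.printf above, so the log is easy to post-process. Below is a minimal, self-contained consumer sketch; reading the saved log from arg[1] and reducing the per-word probabilities to a perplexity are my additions, not part of the patch:

    -- Post-processing sketch: recover per-word probabilities from a
    -- saved log and reduce them to a corpus perplexity.
    local sum_logp, n_words = 0, 0
    for line in io.lines(arg[1]) do  -- arg[1]: path to the saved log
        -- matches the "%.8f %s" payload printed by the report above
        local prob = line:match("word_prob_report_output, ([%d%.]+) %S+")
        if prob then
            sum_logp = sum_logp + math.log(tonumber(prob))
            n_words = n_words + 1
        end
    end
    assert(n_words > 0, "no word_prob_report_output lines found")
    print(string.format("words = %d, total logp = %.4f, ppl = %.4f",
                        n_words, sum_logp, math.exp(-sum_logp / n_words)))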
diff --git a/nerv/examples/lmptb/rnnlm_ptb_main.lua b/nerv/examples/lmptb/rnnlm_ptb_main.lua
index e2ca860..dc011fb 100644
--- a/nerv/examples/lmptb/rnnlm_ptb_main.lua
+++ b/nerv/examples/lmptb/rnnlm_ptb_main.lua
@@ -190,6 +190,7 @@ global_conf = {
valid_fn = valid_fn,
test_fn = test_fn,
vocab_fn = vocab_fn,
+ max_sen_len = 90,
sche_log_pre = "[SCHEDULER]:",
log_w_num = 40000, --give a message when log_w_num words have been processed
timer = nerv.Timer(),
@@ -398,3 +399,17 @@ if commands["test"] == 1 then
LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false) --false update!
end --if commands["test"]
+if commands["wordprob"] == 1 then
+ if final_iter ~= -1 and test_iter == -1 then
+ test_iter = final_iter
+ end
+ if test_iter == -1 then
+ test_iter = "final"
+ end
+
+ printf("===FINAL TEST===\n")
+ global_conf.sche_log_pre = "[SCHEDULER FINAL_TEST]:"
+ tnn = load_net(global_conf, test_iter)
+ LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false, {["word_prob_report"] = true}) --do_train == false: no parameter update
+end --if commands["wordprob"]
+
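To round the page off, a consolidating sketch of the p_conf flags that lm_process_file_rnn understands after this patch, mirroring the call added above. The flag list is read off this diff (compressed_label, one_sen_report, word_prob_report); either report mode now raises nerv.error when combined with do_train == true, and both rely on global_conf.max_sen_len, which the hunk above sets to 90:

    -- Hedged usage sketch, mirroring the "wordprob" command added above.
    local p_conf = {
        compressed_label = false,  -- forwarded to the reader configuration
        one_sen_report   = false,  -- existing mode: per-sentence log-prob
        word_prob_report = true,   -- new mode: one "%.8f %s" line per word
    }
    -- Both report modes force chunk_size = global_conf.max_sen_len and
    -- per-sentence (se_mode) processing; do_train must stay false.
    LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn,
                                  false, p_conf)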