From 38d7b02e9f5f5568ef6c7d1481eee190cb74eebc Mon Sep 17 00:00:00 2001
From: txh18
Date: Wed, 6 Jan 2016 19:11:44 +0800
Subject: added word_prob feature to lm_trainer

---
 nerv/examples/lmptb/lm_trainer.lua     | 21 +++++++++++++++++----
 nerv/examples/lmptb/rnnlm_ptb_main.lua | 15 +++++++++++++++
 2 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index 9dfefe5..8a9744b 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -27,12 +27,11 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
         r_conf.compressed_label = p_conf.compressed_label
     end
     local chunk_size, batch_size
-    if p_conf.one_sen_report == true then --report log prob one by one sentence
+    if p_conf.one_sen_report == true or p_conf.word_prob_report == true then --report log prob sentence by sentence
         if do_train == true then
-            nerv.warning("LMTrainer.lm_process_file_rnn: warning, one_sen_report is true while do_train is also true, strange")
+            nerv.error("LMTrainer.lm_process_file_rnn: error, one_sen_report (or word_prob_report) is true while do_train is also true")
         end
-        nerv.printf("lm_process_file_rnn: one_sen report mode, set chunk_size to max_sen_len(%d)\n",
-            global_conf.max_sen_len)
+        nerv.printf("lm_process_file_rnn: one_sen (or word_prob) report mode, set chunk_size to max_sen_len(%d)\n", global_conf.max_sen_len)
         batch_size = global_conf.batch_size
         chunk_size = global_conf.max_sen_len
         r_conf["se_mode"] = true
@@ -97,12 +96,15 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
 
         global_conf.timer:tic('tnn_afterprocess')
         local sen_logp = {}
+        local word_prob = {}
         for t = 1, chunk_size, 1 do
+            word_prob[t] = {}
            tnn.outputs_m[t][1]:copy_toh(neto_bakm)
            for i = 1, batch_size, 1 do
                if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then
                    --result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
                    result:add("rnn", feeds.labels_s[t][i], math.exp(neto_bakm[i - 1][0]))
+                   word_prob[t][i] = math.exp(neto_bakm[i - 1][0])
                    if sen_logp[i] == nil then
                        sen_logp[i] = 0
                    end
@@ -110,6 +112,17 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
                end
            end
        end
+
+        if p_conf.word_prob_report == true then
+            for i = 1, batch_size do
+                for t = 1, chunk_size do
+                    if feeds.labels_s[t][i] ~= global_conf.vocab.null_token then
+                        nerv.printf("LMTrainer.lm_process_file_rnn: word_prob_report_output, %.8f %s\n", word_prob[t][i], feeds.labels_s[t][i])
+                    end
+                end
+            end
+        end
+
         if p_conf.one_sen_report == true then
             for i = 1, batch_size do
                 if feeds.labels_s[1][i] ~= global_conf.vocab.null_token then
diff --git a/nerv/examples/lmptb/rnnlm_ptb_main.lua b/nerv/examples/lmptb/rnnlm_ptb_main.lua
index e2ca860..dc011fb 100644
--- a/nerv/examples/lmptb/rnnlm_ptb_main.lua
+++ b/nerv/examples/lmptb/rnnlm_ptb_main.lua
@@ -190,6 +190,7 @@ global_conf = {
     valid_fn = valid_fn,
     test_fn = test_fn,
     vocab_fn = vocab_fn,
+    max_sen_len = 90,
     sche_log_pre = "[SCHEDULER]:",
     log_w_num = 40000, --give a message when log_w_num words have been processed
     timer = nerv.Timer(),
@@ -398,3 +399,17 @@ if commands["test"] == 1 then
     tnn = load_net(global_conf, test_iter)
     LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false) --false update!
 end --if commands["test"]
+if commands["wordprob"] == 1 then
+    if final_iter ~= -1 and test_iter == -1 then
+        test_iter = final_iter
+    end
+    if test_iter == -1 then
+        test_iter = "final"
+    end
+
+    printf("===WORD PROB REPORT===\n")
+    global_conf.sche_log_pre = "[SCHEDULER WORD_PROB]:"
+    tnn = load_net(global_conf, test_iter)
+    LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false, {["word_prob_report"] = true}) --false update!
+end --if commands["wordprob"]
+
--
cgit v1.2.3-70-g09d2
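
A minimal usage sketch of the new flag (not part of the patch): the "wordprob"
command above is just the "test" command with one extra p_conf entry, so the
report can be driven the same way from Lua. This assumes the environment of
rnnlm_ptb_main.lua, where global_conf, load_net and LMTrainer are already in
scope; "final" is a placeholder iteration tag.

    -- report-only pass: do_train = false, so no parameters are updated
    local tnn = load_net(global_conf, "final")
    -- word_prob_report = true forces se_mode (one sentence per chunk) and sets
    -- chunk_size = global_conf.max_sen_len, which is why the patch also adds
    -- max_sen_len = 90 to global_conf
    LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false,
                                  {word_prob_report = true})

The report then prints each non-null word on its own line in the form
"LMTrainer.lm_process_file_rnn: word_prob_report_output, <prob> <word>",
iterating over sentences in the batch (index i) and, within each sentence,
over word positions (index t).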