diff options
author | txh18 <cloudygooseg@gmail.com> | 2015-12-05 17:47:30 +0800 |
---|---|---|
committer | txh18 <cloudygooseg@gmail.com> | 2015-12-05 17:47:30 +0800 |
commit | 902e547326311bab9d6494cebbc1e5e2f14a018b (patch) | |
tree | e54546cce6efa4df89658a2bd3f65992f3b261a7 /nerv/examples/lmptb/lm_trainer.lua | |
parent | 2daed79a165015f164a46117dd7d8aa9cbfe5587 (diff) |
added twitter, added bilstmlm script, todo: test bilstmlm
Diffstat (limited to 'nerv/examples/lmptb/lm_trainer.lua')
-rw-r--r-- | nerv/examples/lmptb/lm_trainer.lua | 132 |
1 file changed, 131 insertions(+), 1 deletion(-)
--- Process a corpus file with a bidirectional RNN language model.
-- Mirrors LMTrainer.lm_process_file_rnn: feeds batches from an LMSeqReader
-- through the network, accumulates per-word log-probabilities into an
-- LMResult under the "birnn" slot, and optionally backpropagates/updates.
-- Unlike the unidirectional trainer it never carries hidden state across
-- minibatches (no move_right_to_nextmb), and the reader must always run in
-- sentence-end mode ("se_mode") so each sentence is seen whole.
-- @param global_conf table carrying vocab, batch_size, chunk_size, timer,
--        mmat_type, log_w_num, sche_log_pre, etc. (project configuration)
-- @param fn path of the text file to process
-- @param tnn the (already initialised) network
-- @param do_train when true, run both backpropagation passes and update
-- @param p_conf optional table; p_conf.one_sen_report == true switches to
--        batch_size 1 / chunk_size max_sen_len and prints per-sentence logp
-- @return nerv.LMResult with the accumulated "birnn" statistics
function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf)
    if p_conf == nil then
        p_conf = {}
    end
    local reader
    local chunk_size, batch_size
    -- A bidirectional model needs whole sentences, so the reader always runs
    -- in sentence-end mode.
    local r_conf = {["se_mode"] = true}
    if p_conf.one_sen_report == true then --report log prob one by one sentence
        if do_train == true then
            nerv.warning("LMTrainer.lm_process_file_birnn: warning, one_sen_report is true while do_train is also true, strange")
        end
        nerv.printf("lm_process_file_birnn: one_sen report mode, set batch_size to 1 and chunk_size to max_sen_len(%d)\n",
            global_conf.max_sen_len)
        batch_size = 1
        chunk_size = global_conf.max_sen_len
    else
        batch_size = global_conf.batch_size
        chunk_size = global_conf.chunk_size
    end

    -- BUG FIX: r_conf was constructed above but never handed to the reader,
    -- so se_mode was silently ignored. Pass it through, matching the call in
    -- LMTrainer.lm_process_file_rnn.
    reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, global_conf.vocab, r_conf)
    reader:open_file(fn)

    local result = nerv.LMResult(global_conf, global_conf.vocab)
    result:init("birnn")
    if global_conf.dropout_rate ~= nil then
        nerv.info("LMTrainer.lm_process_file_birnn: dropout_rate is %f", global_conf.dropout_rate)
    end

    global_conf.timer:flush()
    tnn:flush_all() --caution: will also flush the inputs from the reader!

    local next_log_wcn = global_conf.log_w_num
    local neto_bakm = global_conf.mmat_type(batch_size, 1) --space backup matrix for network output

    while (1) do
        global_conf.timer:tic('most_out_loop_lmprocessfile')

        local r, feeds
        global_conf.timer:tic('tnn_beforeprocess')
        r, feeds = tnn:getfeed_from_reader(reader)
        if r == false then
            break
        end

        -- Zero the error signal for slots that carry no label, so padding
        -- positions do not contribute gradient.
        for t = 1, chunk_size do
            tnn.err_inputs_m[t][1]:fill(1)
            for i = 1, batch_size do
                if bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0 then
                    tnn.err_inputs_m[t][1][i - 1][0] = 0
                end
            end
        end
        global_conf.timer:toc('tnn_beforeprocess')

        tnn:net_propagate()

        if do_train == true then
            tnn:net_backpropagate(false)
            tnn:net_backpropagate(true)
        end

        global_conf.timer:tic('tnn_afterprocess')
        local sen_logp = {}
        for t = 1, chunk_size, 1 do
            -- Copy the device output to host once per timestep, then read it
            -- per batch row (0-based matrix indexing).
            tnn.outputs_m[t][1]:copy_toh(neto_bakm)
            for i = 1, batch_size, 1 do
                if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then
                    result:add("birnn", feeds.labels_s[t][i], math.exp(neto_bakm[i - 1][0]))
                    if sen_logp[i] == nil then
                        sen_logp[i] = 0
                    end
                    sen_logp[i] = sen_logp[i] + neto_bakm[i - 1][0]
                end
            end
        end
        if p_conf.one_sen_report == true then
            for i = 1, batch_size do
                nerv.printf("LMTrainer.lm_process_file_birnn: one_sen_report, %f\n", sen_logp[i])
            end
        end

        --tnn:move_right_to_nextmb({0}) --do not need history for bi directional model
        global_conf.timer:toc('tnn_afterprocess')

        global_conf.timer:toc('most_out_loop_lmprocessfile')

        --print progress every global_conf.log_w_num processed words
        if result["birnn"].cn_w > next_log_wcn then
            next_log_wcn = next_log_wcn + global_conf.log_w_num
            nerv.printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["birnn"].cn_w, os.date())
            nerv.printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("birnn"))
            for key, value in pairs(global_conf.timer.rec) do
                nerv.printf("\t [global_conf.timer]: time spent on %s:%.5f clock time\n", key, value)
            end
            global_conf.timer:flush()
            nerv.LMUtil.wait(0.1)
        end

        collectgarbage("collect")
    end

    nerv.printf("%s Displaying result:\n", global_conf.sche_log_pre)
    nerv.printf("%s %s\n", global_conf.sche_log_pre, result:status("birnn"))
    nerv.printf("%s Doing on %s end.\n", global_conf.sche_log_pre, fn)

    return result
end