aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lm_trainer.lua
diff options
context:
space:
mode:
authortxh18 <cloudygooseg@gmail.com>2015-12-05 17:47:30 +0800
committertxh18 <cloudygooseg@gmail.com>2015-12-05 17:47:30 +0800
commit902e547326311bab9d6494cebbc1e5e2f14a018b (patch)
treee54546cce6efa4df89658a2bd3f65992f3b261a7 /nerv/examples/lmptb/lm_trainer.lua
parent2daed79a165015f164a46117dd7d8aa9cbfe5587 (diff)
added twitter, added bilstmlm script, todo: test bilstmlm
Diffstat (limited to 'nerv/examples/lmptb/lm_trainer.lua')
-rw-r--r--nerv/examples/lmptb/lm_trainer.lua132
1 files changed, 131 insertions, 1 deletions
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index 3c7078e..6bd06bb 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -22,6 +22,7 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
p_conf = {}
end
local reader
+ local r_conf
local chunk_size, batch_size
if p_conf.one_sen_report == true then --report log prob one by one sentence
if do_train == true then
@@ -31,12 +32,13 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
global_conf.max_sen_len)
batch_size = 1
chunk_size = global_conf.max_sen_len
+ r_conf["se_mode"] = true
else
batch_size = global_conf.batch_size
chunk_size = global_conf.chunk_size
end
- reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, global_conf.vocab)
+ reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, global_conf.vocab, r_conf)
reader:open_file(fn)
local result = nerv.LMResult(global_conf, global_conf.vocab)
@@ -144,4 +146,132 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
return result
end
+--- Runs a bidirectional RNN language model over one text file and collects
+--  per-word statistics under the "birnn" key of an LMResult.
+--
+--  @param global_conf table of run settings. Fields read here: batch_size,
+--         chunk_size, max_sen_len, vocab, dropout_rate, timer, log_w_num,
+--         mmat_type, sche_log_pre.
+--  @param fn file name handed to the reader's open_file and echoed in the
+--         final log line.
+--  @param tnn network object. This function uses flush_all,
+--         getfeed_from_reader, net_propagate, net_backpropagate and the
+--         err_inputs_m / outputs_m matrices.
+--  @param do_train when true, the two net_backpropagate calls are issued
+--         after every forward pass.
+--  @param p_conf optional table; only p_conf.one_sen_report is consulted.
+--         When it is true, batch_size is forced to 1, chunk_size to
+--         global_conf.max_sen_len, and a summed score is printed per feed.
+--  Returns: LMResult
+function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf)
+    if p_conf == nil then
+        p_conf = {}
+    end
+    local reader
+    local chunk_size, batch_size
+    -- Reader options; se_mode is requested for the bidirectional model.
+    -- NOTE(review): presumably "sentence-end" mode, confirm in nerv.LMSeqReader.
+    local r_conf = {["se_mode"] = true}
+    if p_conf.one_sen_report == true then --report log prob one by one sentence
+        if do_train == true then
+            nerv.warning("LMTrainer.lm_process_file_birnn: warning, one_sen_report is true while do_train is also true, strange")
+        end
+        nerv.printf("lm_process_file_birnn: one_sen report mode, set batch_size to 1 and chunk_size to max_sen_len(%d)\n",
+                global_conf.max_sen_len)
+        batch_size = 1
+        chunk_size = global_conf.max_sen_len
+    else
+        batch_size = global_conf.batch_size
+        chunk_size = global_conf.chunk_size
+    end
+
+    -- Fix: r_conf used to be built but never handed to the reader, so the
+    -- se_mode setting was silently dropped. lm_process_file_rnn passes it
+    -- as the fifth constructor argument; do the same here.
+    reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, global_conf.vocab, r_conf)
+    reader:open_file(fn)
+
+    local result = nerv.LMResult(global_conf, global_conf.vocab)
+    result:init("birnn")
+    if global_conf.dropout_rate ~= nil then
+        nerv.info("LMTrainer.lm_process_file_birnn: dropout_rate is %f", global_conf.dropout_rate)
+    end
+
+    global_conf.timer:flush()
+    tnn:flush_all() --caution: will also flush the inputs from the reader!
+
+    local next_log_wcn = global_conf.log_w_num
+    local neto_bakm = global_conf.mmat_type(batch_size, 1) --space backup matrix for network output
+
+    while true do
+        global_conf.timer:tic('most_out_loop_lmprocessfile')
+
+        local r, feeds
+        global_conf.timer:tic('tnn_beforeprocess')
+        r, feeds = tnn:getfeed_from_reader(reader)
+        if r == false then
+            break
+        end
+
+        -- Error input is 1 everywhere, then zeroed for positions whose flag
+        -- lacks HAS_LABEL. Matrix rows are addressed 0-based (i - 1).
+        for t = 1, chunk_size do
+            tnn.err_inputs_m[t][1]:fill(1)
+            for i = 1, batch_size do
+                if bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0 then
+                    tnn.err_inputs_m[t][1][i - 1][0] = 0
+                end
+            end
+        end
+        global_conf.timer:toc('tnn_beforeprocess')
+
+        --[[
+        for j = 1, global_conf.chunk_size, 1 do
+            for i = 1, global_conf.batch_size, 1 do
+                printf("%s[L(%s)] ", feeds.inputs_s[j][i], feeds.labels_s[j][i]) --vocab:get_word_str(input[i][j]).id
+            end
+            printf("\n")
+        end
+        printf("\n")
+        ]]--
+
+        tnn:net_propagate()
+
+        if do_train == true then
+            -- NOTE(review): presumably gradient pass followed by parameter
+            -- update; confirm the flag's meaning in nerv.TNN.
+            tnn:net_backpropagate(false)
+            tnn:net_backpropagate(true)
+        end
+
+        global_conf.timer:tic('tnn_afterprocess')
+        -- sen_logp[i] sums the raw network output over labelled positions of
+        -- batch slot i; exp() of the same value is what goes into result.
+        local sen_logp = {}
+        for t = 1, chunk_size, 1 do
+            tnn.outputs_m[t][1]:copy_toh(neto_bakm)
+            for i = 1, batch_size, 1 do
+                if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then
+                    result:add("birnn", feeds.labels_s[t][i], math.exp(neto_bakm[i - 1][0]))
+                    if sen_logp[i] == nil then
+                        sen_logp[i] = 0
+                    end
+                    sen_logp[i] = sen_logp[i] + neto_bakm[i - 1][0]
+                end
+            end
+        end
+        if p_conf.one_sen_report == true then
+            for i = 1, batch_size do
+                -- A slot with no labelled token has no entry; formatting nil
+                -- with %f would raise, so skip it.
+                if sen_logp[i] ~= nil then
+                    nerv.printf("LMTrainer.lm_process_file_birnn: one_sen_report, %f\n", sen_logp[i])
+                end
+            end
+        end
+
+        --tnn:move_right_to_nextmb({0}) --do not need history for bi directional model
+        global_conf.timer:toc('tnn_afterprocess')
+
+        global_conf.timer:toc('most_out_loop_lmprocessfile')
+
+        --print log
+        if result["birnn"].cn_w > next_log_wcn then
+            next_log_wcn = next_log_wcn + global_conf.log_w_num
+            nerv.printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["birnn"].cn_w, os.date())
+            nerv.printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("birnn"))
+            for key, value in pairs(global_conf.timer.rec) do
+                nerv.printf("\t [global_conf.timer]: time spent on %s:%.5f clock time\n", key, value)
+            end
+            global_conf.timer:flush()
+            nerv.LMUtil.wait(0.1)
+        end
+
+        --[[
+        for t = 1, global_conf.chunk_size do
+            print(tnn.outputs_m[t][1])
+        end
+        ]]--
+
+        collectgarbage("collect")
+
+        --break --debug
+    end
+
+    nerv.printf("%s Displaying result:\n", global_conf.sche_log_pre)
+    nerv.printf("%s %s\n", global_conf.sche_log_pre, result:status("birnn"))
+    nerv.printf("%s Doing on %s end.\n", global_conf.sche_log_pre, fn)
+
+    return result
+end
+