Diffstat (limited to 'nerv/examples/lmptb/lm_trainer.lua')
-rw-r--r--  nerv/examples/lmptb/lm_trainer.lua | 92
1 file changed, 92 insertions(+), 0 deletions(-)
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
new file mode 100644
index 0000000..d34634c
--- /dev/null
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -0,0 +1,92 @@
+require 'lmptb.lmvocab'
+require 'lmptb.lmfeeder'
+require 'lmptb.lmutil'
+require 'lmptb.layer.init'
+require 'rnn.init'
+require 'lmptb.lmseqreader'
+
+local LMTrainer = nerv.class('nerv.LMTrainer')
+
+local printf = nerv.printf
+
+-- Run the file fn through tnn, training on it when do_train is true.
+-- Returns: nerv.LMResult
+function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
+    local reader = nerv.LMSeqReader(global_conf, global_conf.batch_size, global_conf.chunk_size, global_conf.vocab)
+    reader:open_file(fn)
+    local result = nerv.LMResult(global_conf, global_conf.vocab)
+    result:init("rnn")
+
+    tnn:flush_all() -- caution: will also flush the inputs from the reader!
+
+    local next_log_wcn = global_conf.log_w_num
+
+    while true do
+        local r, feeds = tnn:getFeedFromReader(reader)
+        if r == false then break end
+
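+        -- build the error-input mask: start at 1 everywhere, then zero out
+        -- every (t, i) position that carries no label, so it adds no gradient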
+        for t = 1, global_conf.chunk_size do
+            tnn.err_inputs_m[t][1]:fill(1)
+            for i = 1, global_conf.batch_size do
+                if bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0 then
+                    tnn.err_inputs_m[t][1][i - 1][0] = 0
+                end
+            end
+        end
+
+        --[[
+        for j = 1, global_conf.chunk_size, 1 do
+            for i = 1, global_conf.batch_size, 1 do
+                printf("%s[L(%s)] ", feeds.inputs_s[j][i], feeds.labels_s[j][i]) --vocab:get_word_str(input[i][j]).id
+            end
+            printf("\n")
+        end
+        printf("\n")
+        ]]--
+
+        tnn:net_propagate()
+
+        if do_train == true then
+            tnn:net_backpropagate(false)
+            tnn:net_backpropagate(true)
+        end
+
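+        -- accumulate the probability assigned to each labeled word;
+        -- outputs_m holds log-probabilities, hence the exp()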
+        for t = 1, global_conf.chunk_size do
+            for i = 1, global_conf.batch_size do
+                if feeds.labels_s[t][i] ~= global_conf.vocab.null_token then
+                    result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
+                end
+            end
+        end
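+        -- periodic progress report, every global_conf.log_w_num words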
+        if result["rnn"].cn_w > next_log_wcn then
+            next_log_wcn = next_log_wcn + global_conf.log_w_num
+            printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date())
+            printf("\t%s log prob per sample: %f.\n", global_conf.sche_log_pre, result:logp_sample("rnn"))
+            nerv.LMUtil.wait(0.1)
+        end
+
+        --[[
+        for t = 1, global_conf.chunk_size do
+            print(tnn.outputs_m[t][1])
+        end
+        ]]--
+
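+        -- advance the TNN's internal state to the next minibatch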
+        tnn:moveRightToNextMB()
+
+        collectgarbage("collect")
+
+        --break --debug
+    end
+
+ printf("%s Displaying result:\n", global_conf.sche_log_pre)
+ printf("%s %s\n", global_conf.sche_log_pre, result:status("rnn"))
+ printf("%s Doing on %s end.\n", global_conf.sche_log_pre, fn)
+
+    return result
+end
+
+
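
For context, a minimal sketch of how this entry point might be driven is shown
below. The global_conf fields are inferred from their uses in the function
above; my_vocab and my_tnn stand for the vocabulary and unrolled network that
lmptb builds elsewhere, and the numeric values are placeholders, not the
repository's actual settings.

require 'lmptb.lm_trainer'

local global_conf = {
    batch_size   = 10,       -- parallel sequences per minibatch (placeholder)
    chunk_size   = 15,       -- unrolled time steps per feed (placeholder)
    log_w_num    = 40000,    -- progress log interval, in words (placeholder)
    sche_log_pre = "[SCHEDULER]:",
    vocab        = my_vocab, -- an nerv.LMVocab built beforehand (hypothetical)
    -- ...plus whatever fields the reader and TNN construction require
}

-- one pass over the training text with updates, then an evaluation pass
local train_res = nerv.LMTrainer.lm_process_file(global_conf, "ptb.train.txt", my_tnn, true)
local valid_res = nerv.LMTrainer.lm_process_file(global_conf, "ptb.valid.txt", my_tnn, false)
nerv.printf("valid: %s\n", valid_res:status("rnn"))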