Diffstat (limited to 'nerv/examples/lmptb/lm_trainer.lua')
-rw-r--r-- | nerv/examples/lmptb/lm_trainer.lua | 89
1 file changed, 89 insertions, 0 deletions
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
new file mode 100644
index 0000000..d34634c
--- /dev/null
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -0,0 +1,89 @@
+require 'lmptb.lmvocab'
+require 'lmptb.lmfeeder'
+require 'lmptb.lmutil'
+require 'lmptb.layer.init'
+require 'rnn.init'
+require 'lmptb.lmseqreader'
+
+local LMTrainer = nerv.class('nerv.LMTrainer')
+
+local printf = nerv.printf
+
+--Returns: LMResult
+function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
+    local reader = nerv.LMSeqReader(global_conf, global_conf.batch_size, global_conf.chunk_size, global_conf.vocab)
+    reader:open_file(fn)
+    local result = nerv.LMResult(global_conf, global_conf.vocab)
+    result:init("rnn")
+
+    tnn:flush_all() --caution: will also flush the inputs from the reader!
+
+    local next_log_wcn = global_conf.log_w_num
+
+    while (1) do
+        local r, feeds
+
+        r, feeds = tnn:getFeedFromReader(reader)
+        if (r == false) then break end
+
+        for t = 1, global_conf.chunk_size do
+            tnn.err_inputs_m[t][1]:fill(1)
+            for i = 1, global_conf.batch_size do
+                if (bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0) then
+                    tnn.err_inputs_m[t][1][i - 1][0] = 0
+                end
+            end
+        end
+
+        --[[
+        for j = 1, global_conf.chunk_size, 1 do
+            for i = 1, global_conf.batch_size, 1 do
+                printf("%s[L(%s)] ", feeds.inputs_s[j][i], feeds.labels_s[j][i]) --vocab:get_word_str(input[i][j]).id
+            end
+            printf("\n")
+        end
+        printf("\n")
+        ]]--
+
+        tnn:net_propagate()
+
+        if (do_train == true) then
+            tnn:net_backpropagate(false)
+            tnn:net_backpropagate(true)
+        end
+
+        for t = 1, global_conf.chunk_size, 1 do
+            for i = 1, global_conf.batch_size, 1 do
+                if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then
+                    result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
+                end
+            end
+        end
+        if (result["rnn"].cn_w > next_log_wcn) then
+            next_log_wcn = next_log_wcn + global_conf.log_w_num
+            printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date())
+            printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn"))
+            nerv.LMUtil.wait(0.1)
+        end
+
+        --[[
+        for t = 1, global_conf.chunk_size do
+            print(tnn.outputs_m[t][1])
+        end
+        ]]--
+
+        tnn:moveRightToNextMB()
+
+        collectgarbage("collect")
+
+        --break --debug
+    end
+
+    printf("%s Displaying result:\n", global_conf.sche_log_pre)
+    printf("%s %s\n", global_conf.sche_log_pre, result:status("rnn"))
+    printf("%s Doing on %s end.\n", global_conf.sche_log_pre, fn)
+
+    return result
+end
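A minimal sketch of how this entry point might be driven from an experiment script. The global_conf fields shown are the ones lm_process_file actually reads (batch_size, chunk_size, vocab, log_w_num, sche_log_pre); the file names, the build_tnn helper, and the epoch count are hypothetical and only illustrate the intended call pattern: one pass with do_train = true over the training text, then a forward-only pass over held-out text.

    require 'lmptb.lm_trainer'  -- assumed module path for this new file

    -- Hypothetical driver; only the field names below are taken from
    -- what lm_process_file reads in the diff above.
    local global_conf = {
        batch_size   = 10,          -- parallel streams fed by LMSeqReader
        chunk_size   = 15,          -- truncation length of each feed
        vocab        = nerv.LMVocab(),
        log_w_num    = 100000,      -- log progress every this many words
        sche_log_pre = "[SCHEDULER]",
    }

    local tnn = build_tnn(global_conf) -- assumed helper, not defined here,
                                       -- that assembles the temporal network

    for epoch = 1, 10 do
        -- training pass: propagates and back-propagates each minibatch
        nerv.LMTrainer.lm_process_file(global_conf, "ptb.train.txt", tnn, true)
        -- evaluation pass: propagation only, accumulates the LMResult
        local result = nerv.LMTrainer.lm_process_file(global_conf, "ptb.valid.txt", tnn, false)
        nerv.printf("epoch %d: %s\n", epoch, result:status("rnn"))
    end

The do_train flag is the only difference between the two passes: when it is true, the function calls tnn:net_backpropagate twice (once for gradient propagation, once for the parameter update), after masking the error inputs of any position whose feed lacks the HAS_LABEL flag.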