Diffstat (limited to 'nerv/examples/lmptb/lm_trainer.lua')
-rw-r--r-- | nerv/examples/lmptb/lm_trainer.lua | 56 |
1 files changed, 37 insertions, 19 deletions
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index 44862dc..9ef4794 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -2,41 +2,55 @@ require 'lmptb.lmvocab'
 require 'lmptb.lmfeeder'
 require 'lmptb.lmutil'
 require 'lmptb.layer.init'
-require 'rnn.init'
+--require 'tnn.init'
 require 'lmptb.lmseqreader'
 
 local LMTrainer = nerv.class('nerv.LMTrainer')
-local printf = nerv.printf
+--local printf = nerv.printf
+
+--The bias param update in nerv don't have wcost added
+function nerv.BiasParam:update_by_gradient(gradient)
+    local gconf = self.gconf
+    local l2 = 1 - gconf.lrate * gconf.wcost
+    self:_update_by_gradient(gradient, l2, l2)
+end
 
 --Returns: LMResult
-function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
+function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train)
     local reader = nerv.LMSeqReader(global_conf, global_conf.batch_size,
         global_conf.chunk_size, global_conf.vocab)
     reader:open_file(fn)
     local result = nerv.LMResult(global_conf, global_conf.vocab)
     result:init("rnn")
-
+    if global_conf.dropout_rate ~= nil then
+        nerv.info("LMTrainer.lm_process_file_rnn: dropout_rate is %f", global_conf.dropout_rate)
+    end
+
     global_conf.timer:flush()
     tnn:flush_all() --caution: will also flush the inputs from the reader!
 
     local next_log_wcn = global_conf.log_w_num
+    local neto_bakm = global_conf.mmat_type(global_conf.batch_size, 1) --space backup matrix for network output
 
     while (1) do
         global_conf.timer:tic('most_out_loop_lmprocessfile')
 
         local r, feeds
-
-        r, feeds = tnn:getFeedFromReader(reader)
-        if (r == false) then break end
+        global_conf.timer:tic('tnn_beforeprocess')
+        r, feeds = tnn:getfeed_from_reader(reader)
+        if r == false then
+            break
+        end
 
         for t = 1, global_conf.chunk_size do
             tnn.err_inputs_m[t][1]:fill(1)
             for i = 1, global_conf.batch_size do
-                if (bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0) then
+                if bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0 then
                     tnn.err_inputs_m[t][1][i - 1][0] = 0
                 end
             end
         end
+        global_conf.timer:toc('tnn_beforeprocess')
 
         --[[
         for j = 1, global_conf.chunk_size, 1 do
@@ -50,29 +64,33 @@ function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
 
         tnn:net_propagate()
 
-        if (do_train == true) then
+        if do_train == true then
             tnn:net_backpropagate(false)
             tnn:net_backpropagate(true)
         end
-
+
+        global_conf.timer:tic('tnn_afterprocess')
         for t = 1, global_conf.chunk_size, 1 do
+            tnn.outputs_m[t][1]:copy_toh(neto_bakm)
             for i = 1, global_conf.batch_size, 1 do
                 if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then
-                    result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
+                    --result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
+                    result:add("rnn", feeds.labels_s[t][i], math.exp(neto_bakm[i - 1][0]))
                 end
             end
         end
+        tnn:move_right_to_nextmb({0}) --only copy for time 0
+        global_conf.timer:toc('tnn_afterprocess')
-        tnn:moveRightToNextMB()
 
         global_conf.timer:toc('most_out_loop_lmprocessfile')
 
         --print log
-        if (result["rnn"].cn_w > next_log_wcn) then
+        if result["rnn"].cn_w > next_log_wcn then
             next_log_wcn = next_log_wcn + global_conf.log_w_num
-            printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date())
-            printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn"))
+            nerv.printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date())
+            nerv.printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn"))
             for key, value in pairs(global_conf.timer.rec) do
-                printf("\t [global_conf.timer]: time spent on %s:%.5f clock time\n", key, value)
+                nerv.printf("\t [global_conf.timer]: time spent on %s:%.5f clock time\n", key, value)
             end
             global_conf.timer:flush()
             nerv.LMUtil.wait(0.1)
@@ -90,9 +108,9 @@ function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
         --break --debug
     end
 
-    printf("%s Displaying result:\n", global_conf.sche_log_pre)
-    printf("%s %s\n", global_conf.sche_log_pre, result:status("rnn"))
-    printf("%s Doing on %s end.\n", global_conf.sche_log_pre, fn)
+    nerv.printf("%s Displaying result:\n", global_conf.sche_log_pre)
+    nerv.printf("%s %s\n", global_conf.sche_log_pre, result:status("rnn"))
+    nerv.printf("%s Doing on %s end.\n", global_conf.sche_log_pre, fn)
 
     return result
 end
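
The main functional change above is the nerv.BiasParam:update_by_gradient override: per the patch comment, the stock bias update applies no weight cost, so the override computes l2 = 1 - lrate * wcost and passes it through to _update_by_gradient, folding L2 regularization (weight decay) into the step. Below is a minimal plain-Lua sketch of that update rule, ignoring momentum and batch-size scaling; sgd_step_with_wcost is a hypothetical helper for illustration, not part of the NERV API.

    -- Hypothetical sketch, not NERV API: the decayed update rule implied by
    -- l2 = 1 - lrate * wcost. Each step shrinks the parameter before applying
    -- the gradient, i.e. plain SGD with L2 weight decay:
    --     w <- (1 - lrate * wcost) * w - lrate * g
    local function sgd_step_with_wcost(w, grad, lrate, wcost)
        local l2 = 1 - lrate * wcost
        for i = 1, #w do
            w[i] = l2 * w[i] - lrate * grad[i]
        end
        return w
    end

    -- One step on a toy 3-element bias vector:
    local b = sgd_step_with_wcost({0.5, -0.2, 0.1}, {0.01, 0.02, -0.03}, 0.1, 1e-4)
    print(b[1], b[2], b[3]) --> roughly 0.4990, -0.2020, 0.1030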
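The other change worth flagging is the neto_bakm backup matrix: instead of indexing tnn.outputs_m[t][1][i - 1][0] element by element, which costs one device-to-host read per batch row when the outputs live on the GPU, the loop now performs a single copy_toh per timestep and reads the host-side copy. Here is a plain-Lua sketch of the before/after access pattern; device_read and bulk_copy_to_host are hypothetical stand-ins for an element read and a copy_toh-style bulk transfer.

    -- Old pattern: one device->host round trip per element.
    local function logp_per_element(device_read, batch_size)
        local probs = {}
        for i = 1, batch_size do
            probs[i] = math.exp(device_read(i))
        end
        return probs
    end

    -- New pattern: one bulk copy per timestep (cf. copy_toh into neto_bakm),
    -- then cheap host-side indexing.
    local function logp_via_host_backup(bulk_copy_to_host, batch_size)
        local host = bulk_copy_to_host()
        local probs = {}
        for i = 1, batch_size do
            probs[i] = math.exp(host[i])
        end
        return probs
    end

    -- Toy usage with a Lua table standing in for device memory:
    local dev = {0.0, -1.0, -2.0}
    print(logp_via_host_backup(function() return dev end, #dev)[2]) --> ~0.3679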