require 'lmptb.lmvocab'
require 'lmptb.lmfeeder'
require 'lmptb.lmutil'
require 'lmptb.layer.init'
require 'rnn.init'
require 'lmptb.lmseqreader'

local LMTrainer = nerv.class('nerv.LMTrainer')

local printf = nerv.printf

--Returns: LMResult
function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
    local reader = nerv.LMSeqReader(global_conf, global_conf.batch_size,
                                    global_conf.chunk_size, global_conf.vocab)
    reader:open_file(fn)
    local result = nerv.LMResult(global_conf, global_conf.vocab)
    result:init("rnn")
    global_conf.timer:flush()
    tnn:flush_all() --caution: will also flush the inputs from the reader!

    local next_log_wcn = global_conf.log_w_num --word count at which the next progress log is due

    while (1) do
        global_conf.timer:tic('most_out_loop_lmprocessfile')

        local r, feeds
        r, feeds = tnn:getFeedFromReader(reader)
        if (r == false) then break end

        --set the error input of every frame to 1, then zero the entries whose
        --frame carries no label so they do not contribute to the loss
        for t = 1, global_conf.chunk_size do
            tnn.err_inputs_m[t][1]:fill(1)
            for i = 1, global_conf.batch_size do
                if (bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0) then
                    tnn.err_inputs_m[t][1][i - 1][0] = 0
                end
            end
        end

        --[[
        for j = 1, global_conf.chunk_size, 1 do
            for i = 1, global_conf.batch_size, 1 do
                printf("%s[L(%s)] ", feeds.inputs_s[j][i], feeds.labels_s[j][i]) --vocab:get_word_str(input[i][j]).id
            end
            printf("\n")
        end
        printf("\n")
        ]]--

        tnn:net_propagate()

        --when training, run the two backward passes of the TNN
        if (do_train == true) then
            tnn:net_backpropagate(false)
            tnn:net_backpropagate(true)
        end

        --accumulate the probability the network assigns to each labelled
        --target word (the output is a log-probability, hence math.exp)
        for t = 1, global_conf.chunk_size, 1 do
            for i = 1, global_conf.batch_size, 1 do
                if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then
                    result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
                end
            end
        end

        tnn:moveRightToNextMB() --advance the TNN to the next chunk position

        global_conf.timer:toc('most_out_loop_lmprocessfile')

        --print log
        if (result["rnn"].cn_w > next_log_wcn) then
            next_log_wcn = next_log_wcn + global_conf.log_w_num
            printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date())
            printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn"))
            for key, value in pairs(global_conf.timer.rec) do
                printf("\t [global_conf.timer]: time spent on %s:%.5f clock time\n", key, value)
            end
            global_conf.timer:flush()
            nerv.LMUtil.wait(0.1)
        end
        --[[
        for t = 1, global_conf.chunk_size do
            print(tnn.outputs_m[t][1])
        end
        ]]--

        collectgarbage("collect")

        --break --debug
    end

    printf("%s Displaying result:\n", global_conf.sche_log_pre)
    printf("%s %s\n", global_conf.sche_log_pre, result:status("rnn"))
    printf("%s Doing on %s end.\n", global_conf.sche_log_pre, fn)

    return result
end
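
--[[
Hypothetical usage sketch (an assumption, not part of the original file):
given a driver script that has already built a `global_conf` table (with the
batch_size, chunk_size, vocab, timer, log_w_num and sche_log_pre fields read
by lm_process_file above) and a TNN `tnn`, one training pass followed by one
evaluation pass could look like the following; `train_fn` and `valid_fn` are
assumed corpus file names.

    local res_train = nerv.LMTrainer.lm_process_file(global_conf, train_fn, tnn, true)  --do_train = true
    local res_valid = nerv.LMTrainer.lm_process_file(global_conf, valid_fn, tnn, false) --evaluation only
    nerv.printf("%s %s\n", global_conf.sche_log_pre, res_valid:status("rnn"))
]]--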