require 'lmptb.lmvocab'
require 'lmptb.lmfeeder'
require 'lmptb.lmutil'
require 'lmptb.layer.init'
--require 'tnn.init'
require 'lmptb.lmseqreader'

local LMTrainer = nerv.class('nerv.LMTrainer')
--local printf = nerv.printf

--The default bias param update in nerv does not apply wcost, so override it here
function nerv.BiasParam:update_by_gradient(gradient)
    local gconf = self.gconf
    local l2 = 1 - gconf.lrate * gconf.wcost
    self:_update_by_gradient(gradient, l2, l2)
end

--Returns: LMResult
function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
    if p_conf == nil then
        p_conf = {}
    end
    local reader
    local r_conf = {}
    if p_conf.compressed_label ~= nil then
        r_conf.compressed_label = p_conf.compressed_label
    end
    local chunk_size, batch_size
    if p_conf.one_sen_report == true then --report log prob sentence by sentence
        if do_train == true then
            nerv.warning("LMTrainer.lm_process_file_rnn: one_sen_report is true while do_train is also true, which is strange")
        end
        nerv.printf("lm_process_file_rnn: one_sen_report mode, setting chunk_size to max_sen_len(%d)\n",
                    global_conf.max_sen_len)
        batch_size = global_conf.batch_size
        chunk_size = global_conf.max_sen_len
        r_conf["se_mode"] = true
    else
        batch_size = global_conf.batch_size
        chunk_size = global_conf.chunk_size
    end

    reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, global_conf.vocab, r_conf)
    reader:open_file(fn)

    local result = nerv.LMResult(global_conf, global_conf.vocab)
    result:init("rnn")
    if global_conf.dropout_rate ~= nil then
        nerv.info("LMTrainer.lm_process_file_rnn: dropout_rate is %f", global_conf.dropout_rate)
    end

    global_conf.timer:flush()
    tnn:init(batch_size, chunk_size)
    tnn:flush_all() --caution: will also flush the inputs from the reader!

    local next_log_wcn = global_conf.log_w_num
    local neto_bakm = global_conf.mmat_type(batch_size, 1) --host-side backup matrix for the network output

    nerv.info("LMTrainer.lm_process_file_rnn: begin processing...")

    while (1) do
        global_conf.timer:tic('most_out_loop_lmprocessfile')

        local r, feeds
        global_conf.timer:tic('tnn_beforeprocess')
        r, feeds = tnn:getfeed_from_reader(reader)
        if r == false then
            break
        end

        --mask out error signals for positions that carry no label
        for t = 1, chunk_size do
            tnn.err_inputs_m[t][1]:fill(1)
            for i = 1, batch_size do
                if bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0 then
                    tnn.err_inputs_m[t][1][i - 1][0] = 0
                end
            end
        end
        global_conf.timer:toc('tnn_beforeprocess')

        --[[
        for j = 1, global_conf.chunk_size, 1 do
            for i = 1, global_conf.batch_size, 1 do
                printf("%s[L(%s)] ", feeds.inputs_s[j][i], feeds.labels_s[j][i]) --vocab:get_word_str(input[i][j]).id
            end
            printf("\n")
        end
        printf("\n")
        ]]--

        tnn:net_propagate()

        if do_train == true then
            tnn:net_backpropagate(false)
            tnn:net_backpropagate(true)
        end

        global_conf.timer:tic('tnn_afterprocess')
        local sen_logp = {}
        for t = 1, chunk_size, 1 do
            tnn.outputs_m[t][1]:copy_toh(neto_bakm)
            for i = 1, batch_size, 1 do
                if feeds.labels_s[t][i] ~= global_conf.vocab.null_token then
                    --result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
                    result:add("rnn", feeds.labels_s[t][i], math.exp(neto_bakm[i - 1][0]))
                    if sen_logp[i] == nil then
                        sen_logp[i] = 0
                    end
                    sen_logp[i] = sen_logp[i] + neto_bakm[i - 1][0]
                end
            end
        end
        if p_conf.one_sen_report == true then
            for i = 1, batch_size do
                if feeds.labels_s[1][i] ~= global_conf.vocab.null_token then
                    nerv.printf("LMTrainer.lm_process_file_rnn: one_sen_report_output, %f\n", sen_logp[i])
                end
            end
        end
        tnn:move_right_to_nextmb({0}) --only copy the hidden state for time 0
        global_conf.timer:toc('tnn_afterprocess')

        global_conf.timer:toc('most_out_loop_lmprocessfile')

        --print log
result["rnn"].cn_w > next_log_wcn then next_log_wcn = next_log_wcn + global_conf.log_w_num nerv.printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date()) nerv.printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn")) for key, value in pairs(global_conf.timer.rec) do nerv.printf("\t [global_conf.timer]: time spent on %s:%.5f clock time\n", key, value) end global_conf.timer:flush() nerv.LMUtil.wait(0.1) end --[[ for t = 1, global_conf.chunk_size do print(tnn.outputs_m[t][1]) end ]]-- collectgarbage("collect") --break --debug end nerv.printf("%s Displaying result:\n", global_conf.sche_log_pre) nerv.printf("%s %s\n", global_conf.sche_log_pre, result:status("rnn")) nerv.printf("%s Doing on %s end.\n", global_conf.sche_log_pre, fn) return result end --Returns: LMResult function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf) if p_conf == nil then p_conf = {} end local reader local chunk_size, batch_size local r_conf = {["se_mode"] = true} if p_conf.compressed_label ~= nil then r_conf.compressed_label = p_conf.compressed_label end if p_conf.one_sen_report == true then --report log prob one by one sentence if do_train == true then nerv.warning("LMTrainer.lm_process_file_birnn: warning, one_sen_report is true while do_train is also true, strange") end nerv.printf("lm_process_file_birnn: one_sen report mode, set chunk_size to max_sen_len(%d)\n", global_conf.max_sen_len) batch_size = global_conf.batch_size chunk_size = global_conf.max_sen_len else batch_size = global_conf.batch_size chunk_size = global_conf.chunk_size end reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, global_conf.vocab, r_conf) reader:open_file(fn) local result = nerv.LMResult(global_conf, global_conf.vocab) result:init("birnn") if global_conf.dropout_rate ~= nil then nerv.info("LMTrainer.lm_process_file_birnn: dropout_rate is %f", global_conf.dropout_rate) end global_conf.timer:flush() tnn:init(batch_size, chunk_size) tnn:flush_all() --caution: will also flush the inputs from the reader! 
    local next_log_wcn = global_conf.log_w_num
    local neto_bakm = global_conf.mmat_type(batch_size, 1) --host-side backup matrix for the network output

    nerv.info("LMTrainer.lm_process_file_birnn: begin processing...")

    while (1) do
        global_conf.timer:tic('most_out_loop_lmprocessfile')

        local r, feeds
        global_conf.timer:tic('tnn_beforeprocess')
        r, feeds = tnn:getfeed_from_reader(reader)
        if r == false then
            break
        end

        --mask out error signals for positions that carry no label
        for t = 1, chunk_size do
            tnn.err_inputs_m[t][1]:fill(1)
            for i = 1, batch_size do
                if bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0 then
                    tnn.err_inputs_m[t][1][i - 1][0] = 0
                end
            end
        end
        global_conf.timer:toc('tnn_beforeprocess')

        --[[
        for j = 1, global_conf.chunk_size, 1 do
            for i = 1, global_conf.batch_size, 1 do
                printf("%s[L(%s)] ", feeds.inputs_s[j][i], feeds.labels_s[j][i]) --vocab:get_word_str(input[i][j]).id
            end
            printf("\n")
        end
        printf("\n")
        ]]--

        tnn:net_propagate()

        if do_train == true then
            tnn:net_backpropagate(false)
            tnn:net_backpropagate(true)
        end

        global_conf.timer:tic('tnn_afterprocess')
        local sen_logp = {}
        for t = 1, chunk_size, 1 do
            tnn.outputs_m[t][1]:copy_toh(neto_bakm)
            for i = 1, batch_size, 1 do
                if feeds.labels_s[t][i] ~= global_conf.vocab.null_token then
                    result:add("birnn", feeds.labels_s[t][i], math.exp(neto_bakm[i - 1][0]))
                    if sen_logp[i] == nil then
                        sen_logp[i] = 0
                    end
                    sen_logp[i] = sen_logp[i] + neto_bakm[i - 1][0]
                end
            end
        end
        if p_conf.one_sen_report == true then
            for i = 1, batch_size do
                if sen_logp[i] ~= nil then
                    nerv.printf("LMTrainer.lm_process_file_birnn: one_sen_report_output, %f\n", sen_logp[i])
                end
            end
        end
        --tnn:move_right_to_nextmb({0}) --no history needed for the bidirectional model
        global_conf.timer:toc('tnn_afterprocess')
        --tnn:flush_all() --needed for bilstmlm_ptb_v2, because it has connections across two time steps

        global_conf.timer:toc('most_out_loop_lmprocessfile')

        --print log
        if result["birnn"].cn_w > next_log_wcn then
            next_log_wcn = next_log_wcn + global_conf.log_w_num
            nerv.printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["birnn"].cn_w, os.date())
            nerv.printf("\t%s log prob per sample: %f.\n", global_conf.sche_log_pre, result:logp_sample("birnn"))
            for key, value in pairs(global_conf.timer.rec) do
                nerv.printf("\t [global_conf.timer]: time spent on %s: %.5f clock time\n", key, value)
            end
            global_conf.timer:flush()
            nerv.LMUtil.wait(0.1)
        end

        --[[
        for t = 1, global_conf.chunk_size do
            print(tnn.outputs_m[t][1])
        end
        ]]--

        collectgarbage("collect")

        --break --debug
    end

    nerv.printf("%s Displaying result:\n", global_conf.sche_log_pre)
    nerv.printf("%s %s\n", global_conf.sche_log_pre, result:status("birnn"))
    nerv.printf("%s Processing of %s ended.\n", global_conf.sche_log_pre, fn)

    return result
end
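
--[[
Hypothetical usage sketch (an assumption, not part of this file's original
contents): how a training script might drive the two entry points above. The
file names are placeholders, and the way gconf and tnn are constructed is
elided; only the fields actually accessed above (vocab, timer, mmat_type,
batch_size, chunk_size, ...) are assumed to exist on gconf.

    local gconf = ...   --a fully populated global_conf table
    local tnn = ...     --an initialized nerv.TNN for the language model

    --training epoch: propagate + backpropagate, parameters updated
    local train_res = nerv.LMTrainer.lm_process_file_rnn(gconf, "ptb.train.txt", tnn, true)

    --evaluation pass: no updates, log prob reported sentence by sentence
    local valid_res = nerv.LMTrainer.lm_process_file_rnn(gconf, "ptb.valid.txt", tnn, false,
                                                         {one_sen_report = true})
    nerv.printf("validation result: %s\n", valid_res:status("rnn"))
]]--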