require 'lmptb.lmvocab'
require 'lmptb.lmfeeder'
require 'lmptb.lmutil'
require 'lmptb.layer.init'
--require 'tnn.init'
require 'lmptb.lmseqreader'
-- Declare the nerv.LMTrainer class through nerv.class; the trainer functions
-- defined later in this file (e.g. lm_process_file_rnn) are attached to this table.
local LMTrainer = nerv.class('nerv.LMTrainer')
--local printf = nerv.printf
--The bias param update in nerv doesn't have wcost added, so it is overridden below to include it
--- Update this bias parameter from a gradient, with the weight cost included.
-- Replaces nerv.BiasParam's update_by_gradient so that gconf.wcost takes part
-- in the update.
-- @param gradient the gradient, forwarded unchanged to _update_by_gradient
function nerv.BiasParam:update_by_gradient(gradient)
    local conf = self.gconf
    -- Decay factor derived from the learning rate and the weight cost.
    local decay = 1 - conf.lrate * conf.wcost
    -- The same factor is passed for both trailing arguments; how each one is
    -- used is decided by _update_by_gradient (not visible here; verify there).
    self:_update_by_gradient(gradient, decay, decay)
end
--Returns: LMResult
function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
if p_conf == nil then
p_conf = {}
end
local reader
local r_conf = {}
local chunk_size, batch_size
if p_conf.one_sen_report == true then --report log prob one by one sentence
if do_train == true then
nerv.warning("LMTrainer.lm_process_file_rnn: warning, one_sen_report is true while do_train is also true, strange")
end
nerv.printf("lm_process_file_rnn: one_sen report mode, set batch_size to 1 and chunk_size to max_sen_len(%d)\n",
global_conf.max_sen_len)
batch_size = 1
chunk_size = global_conf.max_sen_len
r_conf["se_mode"] = true
else
batch_size = global_conf.batch_size
chunk_size = global_conf.chunk_size
end
reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, global_conf.vocab, r_conf)
reader:open_file(fn)
local result = nerv.LMResult(global_conf, global_conf.vocab)
result:init("rnn")
if global_conf.dropout_rate ~= nil then
nerv.info("LMTrainer.lm_process_file_rnn: dropout_rate is %f", global_conf.dropout_rate)
end
global_conf.timer:flush()
tnn:init(batch_size, chunk_size)
tnn:flush_all() --caution: will also flush the inputs from the reader!
local next_log_wcn = global_conf.log_w_num
local neto_bakm = global_conf.mmat_type(batch_size, 1) --space backup matrix for network output
nerv.info("LMTrainer.lm_process_file_rnn: begin processing...")
while (1) do
global_conf.timer:tic('most_out_loop_lmprocessfile')
local r, feeds
global_conf.timer:tic('tnn_beforeprocess')
r, feeds = tnn:getfeed_from_reader(reader)
if r == false then
break
end
for t = 1, chunk_size do
tnn.err_inputs_m[t][1]:fill(1)
for i = 1, batch_size do
if bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0 then
tnn.err_inputs_m[t][1][i - 1][0] = 0
end
end
end
global_conf.timer:toc('tnn_beforeprocess')
--[[
for j = 1, global_conf.chunk_size, 1 do
for i = 1, global