-- Base class for training a nerv.Network.
-- Subclasses must implement make_layer_repo, get_network, get_readers and
-- get_input_order, and must provide get_error (its result is returned by
-- process). The *_preprocess / *_middleprocess / *_afterprocess hooks are
-- optional and are no-ops by default.
local trainer = nerv.class('nerv.Trainer')
--- Build the trainer.
-- Loads the initial parameters, builds the layer repo and network through
-- the subclass hooks, and allocates the per-time-step matrix tables that are
-- handed to the network on every mini-batch.
-- @param gconf global configuration table. Fields read here: use_cpu,
--   mmat_type, cumat_type, initialized_param, clip, batch_size, chunk_size,
--   and mask (only when some network output has dimension 1).
function trainer:__init(gconf)
    self.gconf = gconf
    local batch_size, chunk_size = gconf.batch_size, gconf.chunk_size
    local loc_types = nerv.ParamRepo.LOC_TYPES

    -- Parameters are imported from / exported to the host location.
    self.src_loc_type = loc_types.ON_HOST
    -- Matrix type and parameter location used while training.
    local mat_type
    if gconf.use_cpu then
        mat_type = gconf.mmat_type
        self.train_loc_type = loc_types.ON_HOST
    else
        mat_type = gconf.cumat_type
        self.train_loc_type = loc_types.ON_DEVICE
    end

    -- Load the initial parameters on the host, then copy them to the
    -- training location.
    local host_param_repo = nerv.ParamRepo()
    host_param_repo:import(gconf.initialized_param, gconf)
    local param_repo = host_param_repo:copy(self.train_loc_type, gconf)
    self.layer_repo = self:make_layer_repo(param_repo)
    local graph = self:get_network(self.layer_repo)
    self.input_order = self:get_input_order()
    local network = nerv.Network('network', gconf, {network = graph, clip = gconf.clip})
    self.network = network
    network:init(batch_size, chunk_size)

    local dim_in, dim_out = network.dim_in, network.dim_out

    -- err_output[i][t]: one matrix per network input, shared by every time
    -- step. NOTE(review): sharing presumably works because these error
    -- signals are never read back; confirm.
    local err_output = {}
    self.err_output = err_output
    for i = 1, #dim_in do
        local shared = mat_type(batch_size, dim_in[i])
        err_output[i] = {}
        for t = 1, chunk_size do
            err_output[i][t] = shared
        end
    end

    -- output[i][t]: a distinct matrix per output and time step.
    -- err_input[i][t]: gconf.mask[t] for 1-dimensional outputs, otherwise a
    -- single zero-filled matrix shared across time steps.
    local output, err_input = {}, {}
    self.output = output
    self.err_input = err_input
    for i = 1, #dim_out do
        output[i] = {}
        for t = 1, chunk_size do
            output[i][t] = mat_type(batch_size, dim_out[i])
        end
        err_input[i] = {}
        if dim_out[i] == 1 then
            for t = 1, chunk_size do
                err_input[i][t] = gconf.mask[t]
            end
        else
            -- Only allocate the zero matrix when it is actually used.
            local zeros = mat_type(batch_size, dim_out[i])
            zeros:fill(0)
            for t = 1, chunk_size do
                err_input[i][t] = zeros
            end
        end
    end
end
--- Create the buffer that feeds mini-batches to `process`.
-- Uses a frame-level nerv.FrmBuffer when gconf.chunk_size == 1 and a
-- nerv.SeqBuffer otherwise.
-- @param readers data readers (as returned by get_readers)
-- @return the buffer object
function trainer:make_buffer(readers)
    local gconf = self.gconf
    if gconf.chunk_size == 1 then
        return nerv.FrmBuffer(gconf, {
            buffer_size = gconf.buffer_size,
            batch_size = gconf.batch_size,
            chunk_size = gconf.chunk_size,
            randomize = gconf.randomize,
            readers = readers,
            -- Follow the training device chosen in __init. This used to be
            -- hard-coded to true, which requested a GPU buffer even with
            -- gconf.use_cpu set. NOTE(review): FrmBuffer appears not to
            -- support GPU buffer + CPU training; confirm.
            use_gpu = not gconf.use_cpu,
        })
    end
    return nerv.SeqBuffer(gconf, {
        batch_size = gconf.batch_size,
        chunk_size = gconf.chunk_size,
        readers = readers,
    })
end
--- Run one pass over `dataset`, training the network when `do_train` is set.
-- For every mini-batch obtained from the buffer: mini_batch_preprocess, then
-- a forward pass, then mini_batch_middleprocess, then (when training)
-- back-propagation and update, then mini_batch_afterprocess. The epoch hooks
-- wrap the whole pass.
-- @param dataset passed to get_readers and to the epoch hooks
-- @param do_train truthy to back-propagate and update after each forward pass
-- @return the value of self:get_error()
function trainer:process(dataset, do_train)
    self:epoch_preprocess(dataset, do_train)
    local buffer = self:make_buffer(self:get_readers(dataset))
    local batch_no = 0
    local net, order = self.network, self.input_order
    local out, e_in, e_out = self.output, self.err_input, self.err_output
    net:epoch_init()

    local batch = buffer:get_data()
    while batch ~= nil do
        batch_no = batch_no + 1
        local info = {
            input = {},
            output = out,
            err_input = e_in,
            err_output = e_out,
            do_train = do_train,
            seq_length = batch.seq_length,
            new_seq = batch.new_seq,
        }
        -- Reorder the buffer's data into the network's input order.
        local inputs = info.input
        for k = 1, #net.dim_in do
            inputs[k] = batch.data[order[k]]
        end

        self:mini_batch_preprocess(batch_no, info)
        net:mini_batch_init(info)
        net:propagate()
        self:mini_batch_middleprocess(batch_no, info)
        if do_train then
            net:back_propagate()
            net:update()
        end
        self:mini_batch_afterprocess(batch_no, info)
        -- Full collection after every mini-batch. NOTE(review): presumably
        -- to release matrix userdata promptly; confirm before removing.
        collectgarbage('collect')

        batch = buffer:get_data()
    end

    self:epoch_afterprocess(dataset, do_train)
    return self:get_error()
end
--- Save the current parameters and decide whether to keep them.
-- The network parameters are copied to the host location and exported to a
-- file in gconf.working_dir whose name records the date, iteration, learning
-- rate and errors.
-- If cv_err is below gconf.best_cv, that file becomes the new
-- gconf.initialized_param.
-- Otherwise the file gets a '.rejected' suffix, the parameters from
-- gconf.initialized_param are reloaded into the layer repo, and gconf.lrate
-- is halved.
-- @param train_err training error of this iteration
-- @param cv_err cross-validation error of this iteration
function trainer:halving(train_err, cv_err)
    local gconf = self.gconf
    local fname = path.join(gconf.working_dir,
        string.format('%s_iter_%d_lr%f_tr%.3f_cv%.3f.nerv',
            os.date(gconf.date_pattern), gconf.cur_iter, gconf.lrate,
            train_err, cv_err))

    local saved = self.network:get_params():copy(self.src_loc_type, gconf)
    saved:export(fname)

    if cv_err < gconf.best_cv then
        nerv.info("accepting the trained params")
        gconf.best_cv = cv_err
        gconf.initialized_param = {fname}
        return
    end

    nerv.info("rejecting the trained params, rollback to the previous one")
    file.move(fname, fname .. '.rejected')
    local previous = nerv.ParamRepo()
    previous:import(gconf.initialized_param, gconf)
    self.layer_repo:rebind(previous:copy(self.train_loc_type, gconf))
    gconf.lrate = gconf.lrate * 0.5
end
-- Optional hooks. Each one is a no-op by default; override in a subclass.

--- Called around the whole training run. NOTE(review): these are not called
-- anywhere in this class; presumably the driver script invokes them.
function trainer:training_preprocess() end
function trainer:training_afterprocess() end

--- Called by process() at the start and at the end of each pass over a dataset.
function trainer:epoch_preprocess(dataset, do_train) end
function trainer:epoch_afterprocess(dataset, do_train) end

--- Called by process() for each mini-batch. `cnt` is the 1-based batch index
-- and `info` is the table handed to network:mini_batch_init.
-- preprocess runs before the forward pass.
-- middleprocess runs after the forward pass and before any back-propagation.
-- afterprocess runs after back-propagation and update.
function trainer:mini_batch_preprocess(cnt, info) end
function trainer:mini_batch_middleprocess(cnt, info) end
function trainer:mini_batch_afterprocess(cnt, info) end
-- Abstract methods: subclasses must override all of these.

--- Build the layer repository.
-- @param param_repo parameters already copied to the training location
-- @return the layer repo (stored by __init as self.layer_repo)
function trainer:make_layer_repo(param_repo)
    nerv.error_method_not_implemented()
end

--- Build the network graph from the layer repository.
-- @param layer_repo the repo returned by make_layer_repo
-- @return the graph passed to nerv.Network as `network`
function trainer:get_network(layer_repo)
    nerv.error_method_not_implemented()
end

--- Create the data readers for a dataset.
-- @param dataset the dataset passed to process
-- @return readers handed to make_buffer
function trainer:get_readers(dataset)
    nerv.error_method_not_implemented()
end

--- Map network input positions to keys of the buffer's data table.
-- @return a sequence whose i-th entry is the data key for network input i
function trainer:get_input_order()
    nerv.error_method_not_implemented()
end

--- Report the error statistic accumulated during a pass.
-- process() returns this value. Declaring the stub here means a subclass
-- that forgets to override it gets a clear not-implemented error instead of
-- "attempt to call a nil value".
function trainer:get_error()
    nerv.error_method_not_implemented()
end