about | summary | refs | log | tree | commit | diff
path: root/nerv/nn/trainer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/nn/trainer.lua')
-rw-r--r-- nerv/nn/trainer.lua 183
1 files changed, 183 insertions, 0 deletions
diff --git a/nerv/nn/trainer.lua b/nerv/nn/trainer.lua
new file mode 100644
index 0000000..4ae08d9
--- /dev/null
+++ b/nerv/nn/trainer.lua
@@ -0,0 +1,183 @@
+local trainer = nerv.class('nerv.Trainer') -- class object returned by nerv.class for the name 'nerv.Trainer'; every method in this file is defined on it
+
+--- Build the training state from a global configuration table.
+-- Reads from gconf: use_cpu, mmat_type, cumat_type, initialized_param, clip,
+-- batch_size, chunk_size and, only when some output has dimension 1, mask.
+-- Sets on self: gconf, src_loc_type, train_loc_type, layer_repo, input_order,
+-- network, err_output, output and err_input.
+-- Calls self:make_layer_repo, self:get_network and self:get_input_order, which
+-- this class leaves unimplemented; a subclass has to provide them.
+-- @param gconf global configuration table, kept by reference in self.gconf
+function trainer:__init(gconf)
+    self.gconf = gconf
+    self.src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
+
+    -- Pick the matrix constructor and the parameter location used for training.
+    local mat_type
+    if gconf.use_cpu then
+        mat_type = gconf.mmat_type
+        self.train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
+    else
+        mat_type = gconf.cumat_type
+        self.train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
+    end
+    local train_loc_type = self.train_loc_type
+
+    -- Import the initial parameters, then copy them to the training location.
+    local host_param_repo = nerv.ParamRepo()
+    host_param_repo:import(gconf.initialized_param, gconf)
+    local param_repo = host_param_repo:copy(train_loc_type, gconf)
+    self.layer_repo = self:make_layer_repo(param_repo)
+    local layer_repo = self.layer_repo
+    local graph = self:get_network(layer_repo)
+    self.input_order = self:get_input_order()
+
+    self.network = nerv.Network('network', gconf, {network = graph, clip = gconf.clip})
+    local network = self.network
+    network:init(gconf.batch_size, gconf.chunk_size)
+
+    local dim_in, dim_out = network.dim_in, network.dim_out
+
+    -- err_output[i][t]: one matrix per network input, shared by every time
+    -- step t (the same object is stored chunk_size times).
+    self.err_output = {}
+    local err_output = self.err_output
+    for i = 1, #dim_in do
+        err_output[i] = {}
+        local shared = mat_type(gconf.batch_size, dim_in[i])
+        for t = 1, gconf.chunk_size do
+            err_output[i][t] = shared
+        end
+    end
+
+    -- output[i][t]: a distinct matrix per network output and time step.
+    -- err_input[i][t]: gconf.mask[t] for outputs of dimension 1, otherwise a
+    -- single zero-filled matrix shared by every time step.
+    self.output = {}
+    self.err_input = {}
+    local output = self.output
+    local err_input = self.err_input
+    for i = 1, #dim_out do
+        output[i] = {}
+        for t = 1, gconf.chunk_size do
+            output[i][t] = mat_type(gconf.batch_size, dim_out[i])
+        end
+        err_input[i] = {}
+        if dim_out[i] == 1 then
+            -- NOTE(review): presumably a scalar criterion output whose error
+            -- signal is the per-step mask; confirm against the network layers.
+            for t = 1, gconf.chunk_size do
+                err_input[i][t] = gconf.mask[t]
+            end
+        else
+            -- Allocate the zero matrix only on the branch that uses it.
+            local zeros = mat_type(gconf.batch_size, dim_out[i])
+            zeros:fill(0)
+            for t = 1, gconf.chunk_size do
+                err_input[i][t] = zeros
+            end
+        end
+    end
+end
+
+--- Create the data buffer that feeds mini-batches to the network.
+-- Returns a nerv.SeqBuffer unless gconf.chunk_size is exactly 1, in which
+-- case a nerv.FrmBuffer is returned instead.
+-- @param readers value stored under the `readers` key of the buffer config
+-- @return the newly constructed buffer object
+function trainer:make_buffer(readers)
+    local conf = self.gconf
+
+    if conf.chunk_size ~= 1 then
+        return nerv.SeqBuffer(conf, {
+            batch_size = conf.batch_size,
+            chunk_size = conf.chunk_size,
+            readers = readers,
+        })
+    end
+
+    return nerv.FrmBuffer(conf, {
+        buffer_size = conf.buffer_size,
+        batch_size = conf.batch_size,
+        chunk_size = conf.chunk_size,
+        randomize = conf.randomize,
+        readers = readers,
+        use_gpu = true,
+    })
+end
+
+--- Run one pass over a dataset, optionally updating the network.
+-- For every mini-batch returned by the buffer: assemble the `info` table,
+-- call the mini-batch hooks around mini_batch_init / propagate and, when
+-- do_train is truthy, back_propagate / update. A full garbage collection is
+-- forced after each mini-batch.
+-- self:get_error() is not defined in this file; a subclass must supply it.
+-- @param dataset passed to the epoch hooks and to self:get_readers
+-- @param do_train when truthy, back-propagation and parameter update are run
+-- @return whatever self:get_error() returns
+function trainer:process(dataset, do_train)
+    self:epoch_preprocess(dataset, do_train)
+    local buffer = self:make_buffer(self:get_readers(dataset))
+
+    local net = self.network
+    local order = self.input_order
+    local out, err_in, err_out = self.output, self.err_input, self.err_output
+    net:epoch_init()
+
+    local batch_no = 0
+    local data = buffer:get_data()
+    while data ~= nil do
+        batch_no = batch_no + 1
+
+        local info = {
+            input = {},
+            output = out,
+            err_input = err_in,
+            err_output = err_out,
+            do_train = do_train,
+            seq_length = data.seq_length,
+            new_seq = data.new_seq,
+        }
+        -- The i-th network input is looked up in data.data by input_order[i].
+        for i = 1, #net.dim_in do
+            info.input[i] = data.data[order[i]]
+        end
+
+        self:mini_batch_preprocess(batch_no, info)
+        net:mini_batch_init(info)
+        net:propagate()
+        self:mini_batch_middleprocess(batch_no, info)
+        if do_train then
+            net:back_propagate()
+            net:update()
+        end
+        self:mini_batch_afterprocess(batch_no, info)
+
+        collectgarbage('collect')
+        data = buffer:get_data()
+    end
+
+    self:epoch_afterprocess(dataset, do_train)
+    return self:get_error()
+end
+
+--- Export the current parameters, then accept them or roll back.
+-- The parameters are always written to a file under gconf.working_dir whose
+-- name encodes the date, iteration, learning rate and both error values.
+-- If cv_err is below gconf.best_cv the file becomes gconf.initialized_param
+-- and best_cv is updated. Otherwise the file gets a '.rejected' suffix, the
+-- parameters listed in gconf.initialized_param are re-imported and rebound to
+-- the layer repo, and gconf.lrate is multiplied by 0.5.
+-- @param train_err training error, used only in the file name
+-- @param cv_err cross-validation error, compared against gconf.best_cv
+function trainer:halving(train_err, cv_err)
+    local gconf = self.gconf
+    local host_loc, train_loc = self.src_loc_type, self.train_loc_type
+    local layers = self.layer_repo
+
+    local basename = string.format('%s_iter_%d_lr%f_tr%.3f_cv%.3f.nerv',
+                                   os.date(gconf.date_pattern), gconf.cur_iter,
+                                   gconf.lrate, train_err, cv_err)
+    local param_fname = path.join(gconf.working_dir, basename)
+
+    local snapshot = self.network:get_params():copy(host_loc, gconf)
+    snapshot:export(param_fname)
+
+    if cv_err < gconf.best_cv then
+        nerv.info("accepting the trained params")
+        gconf.best_cv = cv_err
+        gconf.initialized_param = {param_fname}
+        return
+    end
+
+    nerv.info("rejecting the trained params, rollback to the previous one")
+    file.move(param_fname, param_fname .. '.rejected')
+    local previous = nerv.ParamRepo()
+    previous:import(gconf.initialized_param, gconf)
+    layers:rebind(previous:copy(train_loc, gconf))
+    gconf.lrate = gconf.lrate * 0.5
+end
+
+--- Overridable hook; does nothing by default.
+-- Not called anywhere in this file. Presumably invoked by the training
+-- driver before the first epoch; confirm against the caller.
+function trainer:training_preprocess() end
+
+--- Overridable hook; does nothing by default.
+-- Not called anywhere in this file. Presumably invoked by the training
+-- driver after the last epoch; confirm against the caller.
+function trainer:training_afterprocess() end
+
+--- Overridable hook; does nothing by default.
+-- trainer:process calls it first, before the data buffer is created.
+-- @param dataset the dataset argument given to trainer:process
+-- @param do_train the do_train argument given to trainer:process
+function trainer:epoch_preprocess(dataset, do_train) end
+
+--- Overridable hook; does nothing by default.
+-- trainer:process calls it once the buffer is exhausted, just before
+-- self:get_error() is evaluated.
+-- @param dataset the dataset argument given to trainer:process
+-- @param do_train the do_train argument given to trainer:process
+function trainer:epoch_afterprocess(dataset, do_train) end
+
+--- Overridable hook; does nothing by default.
+-- trainer:process calls it after the `info` table is assembled and before
+-- network:mini_batch_init(info).
+-- @param cnt 1-based index of the mini-batch within the current pass
+-- @param info table with input, output, err_input, err_output, do_train,
+--   seq_length and new_seq fields
+function trainer:mini_batch_preprocess(cnt, info) end
+
+--- Overridable hook; does nothing by default.
+-- trainer:process calls it right after network:propagate() and before any
+-- back-propagation.
+-- @param cnt 1-based index of the mini-batch within the current pass
+-- @param info the same table handed to network:mini_batch_init
+function trainer:mini_batch_middleprocess(cnt, info) end
+
+--- Overridable hook; does nothing by default.
+-- trainer:process calls it at the end of each mini-batch (after
+-- network:update() when training) and before the forced garbage collection.
+-- @param cnt 1-based index of the mini-batch within the current pass
+-- @param info the same table handed to network:mini_batch_init
+function trainer:mini_batch_afterprocess(cnt, info) end
+
+--- Abstract: build the layer repository from the training parameters.
+-- The base version only calls nerv.error_method_not_implemented().
+-- trainer:__init stores the return value in self.layer_repo.
+-- @param param_repo parameter repo already copied to the training location
+function trainer:make_layer_repo(param_repo)
+    nerv.error_method_not_implemented()
+end
+
+--- Abstract: choose the graph that will be wrapped in a nerv.Network.
+-- The base version only calls nerv.error_method_not_implemented().
+-- trainer:__init passes the return value as the `network` field of the
+-- nerv.Network configuration.
+-- @param layer_repo the value produced by trainer:make_layer_repo
+function trainer:get_network(layer_repo)
+    nerv.error_method_not_implemented()
+end
+
+--- Abstract: provide the readers for a dataset.
+-- The base version only calls nerv.error_method_not_implemented().
+-- trainer:process forwards the return value to trainer:make_buffer, which
+-- places it under the `readers` key of the buffer configuration.
+-- @param dataset the dataset argument given to trainer:process
+function trainer:get_readers(dataset)
+    nerv.error_method_not_implemented()
+end
+
+--- Abstract: give the key used to fetch each network input from a batch.
+-- The base version only calls nerv.error_method_not_implemented().
+-- trainer:__init stores the return value in self.input_order, and
+-- trainer:process reads data.data[input_order[i]] for the i-th network input.
+function trainer:get_input_order()
+    nerv.error_method_not_implemented()
+end