-- Generic training driver for NERV networks: loads an initial parameter
-- repository, builds the computation network, iterates minibatches through
-- propagate / back-propagate / update, and manages learning-rate halving
-- with accept-or-rollback of trained parameters.
local trainer = nerv.class('nerv.Trainer')

-- Construct the trainer: import the initial parameters, build the layer
-- repo and network via the subclass hooks, and pre-allocate the output and
-- error matrices reused by every minibatch.
function trainer:__init(gconf)
    self.gconf = gconf
    -- Parameters are imported/exported on the host; training happens on the
    -- device unless use_cpu is set, in which case host matrices are used.
    self.src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
    local mat_ctor
    if gconf.use_cpu then
        mat_ctor = gconf.mmat_type
        self.train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
    else
        mat_ctor = gconf.cumat_type
        self.train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
    end

    -- Load the initial parameters on the host, then copy them to the
    -- training location (host or device).
    local host_repo = nerv.ParamRepo()
    host_repo:import(gconf.initialized_param, gconf)
    local train_repo = host_repo:copy(self.train_loc_type, gconf)

    -- Subclass hooks supply the layer repo, the graph and the input order.
    self.layer_repo = self:make_layer_repo(train_repo)
    local graph = self:get_network(self.layer_repo)
    self.input_order = self:get_input_order()

    self.network = nerv.Network('network', gconf,
                                {network = graph, clip = gconf.clip})
    self.network:init(gconf.batch_size, gconf.chunk_size)

    local dim_in, dim_out = self.network.dim_in, self.network.dim_out

    -- err_output receives back-propagated errors w.r.t. the network inputs;
    -- its contents are never consumed, so a single matrix is shared across
    -- all timesteps of each input to save memory.
    self.err_output = {}
    for i = 1, #dim_in do
        local shared = mat_ctor(gconf.batch_size, dim_in[i])
        local per_step = {}
        for t = 1, gconf.chunk_size do
            per_step[t] = shared
        end
        self.err_output[i] = per_step
    end

    -- output holds forward results (a fresh matrix per timestep).
    -- err_input seeds back-propagation: outputs of width 1 (loss-like
    -- scalars) take the per-timestep frame mask, all others take a shared
    -- zero matrix.
    self.output = {}
    self.err_input = {}
    for i = 1, #dim_out do
        local outs = {}
        for t = 1, gconf.chunk_size do
            outs[t] = mat_ctor(gconf.batch_size, dim_out[i])
        end
        self.output[i] = outs

        local zeros = mat_ctor(gconf.batch_size, dim_out[i])
        zeros:fill(0)
        local errs = {}
        for t = 1, gconf.chunk_size do
            if dim_out[i] == 1 then
                errs[t] = gconf.mask[t]
            else
                errs[t] = zeros
            end
        end
        self.err_input[i] = errs
    end
end

-- Build a data buffer over the given readers: a frame-shuffling buffer when
-- chunk_size == 1, otherwise a sequence-level buffer.
function trainer:make_buffer(readers)
    local gconf = self.gconf
    if gconf.chunk_size == 1 then
        return nerv.FrmBuffer(gconf, {
            buffer_size = gconf.buffer_size,
            batch_size = gconf.batch_size,
            chunk_size = gconf.chunk_size,
            randomize = gconf.randomize,
            readers = readers,
            use_gpu = true,
        })
    end
    return nerv.SeqBuffer(gconf, {
        batch_size = gconf.batch_size,
        chunk_size = gconf.chunk_size,
        readers = readers,
    })
end

-- Run one pass over a dataset. When do_train is true, also back-propagate
-- and update the parameters. Returns the epoch error from the subclass's
-- get_error().
function trainer:process(dataset, do_train)
    self:epoch_preprocess(dataset, do_train)
    local buffer = self:make_buffer(self:get_readers(dataset))
    local batch_count = 0
    local network = self.network
    local order = self.input_order
    local output = self.output
    local err_input = self.err_input
    local err_output = self.err_output
    network:epoch_init()

    local batch = buffer:get_data()
    while batch ~= nil do
        batch_count = batch_count + 1
        local info = {
            input = {},
            output = output,
            err_input = err_input,
            err_output = err_output,
            do_train = do_train,
            seq_length = batch.seq_length,
            new_seq = batch.new_seq,
        }
        -- Map buffered reader streams onto network inputs by input_order.
        for i = 1, #network.dim_in do
            info.input[i] = batch.data[order[i]]
        end
        self:mini_batch_preprocess(batch_count, info)
        network:mini_batch_init(info)
        network:propagate()
        self:mini_batch_middleprocess(batch_count, info)
        if do_train then
            network:back_propagate()
            network:update()
        end
        self:mini_batch_afterprocess(batch_count, info)
        -- Full GC each minibatch releases no-longer-referenced matrices
        -- promptly (important for device memory).
        collectgarbage('collect')
        batch = buffer:get_data()
    end

    self:epoch_afterprocess(dataset, do_train)
    return self:get_error()
end

-- Save the iteration's parameters to disk. Accept them when the CV error
-- improved; otherwise mark the file rejected, roll back to the previous
-- parameters and halve the learning rate.
function trainer:halving(train_err, cv_err)
    local gconf = self.gconf
    local fname = string.format('%s_iter_%d_lr%f_tr%.3f_cv%.3f.nerv',
                                os.date(gconf.date_pattern), gconf.cur_iter,
                                gconf.lrate, train_err, cv_err)
    fname = path.join(gconf.working_dir, fname)
    -- Export always goes through a host-located copy of the parameters.
    local host_repo = self.network:get_params():copy(self.src_loc_type, gconf)
    host_repo:export(fname)
    if cv_err < gconf.best_cv then
        nerv.info("accepting the trained params")
        gconf.best_cv = cv_err
        gconf.initialized_param = {fname}
    else
        nerv.info("rejecting the trained params, rollback to the previous one")
        file.move(fname, fname .. '.rejected')
        -- Re-import the last accepted parameters and rebind them into the
        -- existing layer repo.
        host_repo = nerv.ParamRepo()
        host_repo:import(gconf.initialized_param, gconf)
        local rolled_back = host_repo:copy(self.train_loc_type, gconf)
        self.layer_repo:rebind(rolled_back)
        gconf.lrate = gconf.lrate * 0.5
    end
end

-- No-op lifecycle hooks; subclasses may override any of them.
function trainer:training_preprocess()
end

function trainer:training_afterprocess()
end

function trainer:epoch_preprocess(dataset, do_train)
end

function trainer:epoch_afterprocess(dataset, do_train)
end

function trainer:mini_batch_preprocess(cnt, info)
end

function trainer:mini_batch_middleprocess(cnt, info)
end

function trainer:mini_batch_afterprocess(cnt, info)
end

-- Abstract methods; subclasses must implement all of these.
function trainer:make_layer_repo(param_repo)
    nerv.error_method_not_implemented()
end

function trainer:get_network(layer_repo)
    nerv.error_method_not_implemented()
end

function trainer:get_readers(dataset)
    nerv.error_method_not_implemented()
end

function trainer:get_input_order()
    nerv.error_method_not_implemented()
end