aboutsummaryrefslogblamecommitdiff
path: root/nerv/nn/trainer.lua
blob: 4ae08d9b3956b8e20e47665f0cff36c91981f6ab (plain) (tree)






















































































































































































                                                                                                                                                        
-- nerv.Trainer: base class for training drivers.
-- Subclasses must implement make_layer_repo, get_network, get_readers and
-- get_input_order (the base versions raise "method not implemented"), and may
-- override the empty *_preprocess / *_middleprocess / *_afterprocess hooks.
local trainer = nerv.class('nerv.Trainer')

--- Construct a trainer from the global configuration `gconf`.
-- Loads `gconf.initialized_param` into a host-side ParamRepo, copies it to
-- the training location (host if `gconf.use_cpu`, device otherwise), builds
-- the layer repo, graph and input order through the subclass hooks, creates
-- and initializes the nerv.Network, and preallocates the per-time-step
-- matrices handed to the network on every mini-batch.
-- @param gconf global configuration table (reads use_cpu, mmat_type,
--   cumat_type, initialized_param, clip, batch_size, chunk_size, mask)
function trainer:__init(gconf)
    self.gconf = gconf

    -- Parameters are loaded from / saved to the host; the working copy lives
    -- on the host or on the device depending on `use_cpu`.
    self.src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
    local mat_type
    if gconf.use_cpu then
        mat_type = gconf.mmat_type
        self.train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
    else
        mat_type = gconf.cumat_type
        self.train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
    end

    local host_param_repo = nerv.ParamRepo()
    host_param_repo:import(gconf.initialized_param, gconf)
    local param_repo = host_param_repo:copy(self.train_loc_type, gconf)
    self.layer_repo = self:make_layer_repo(param_repo)
    local graph = self:get_network(self.layer_repo)
    self.input_order = self:get_input_order()

    local network = nerv.Network('network', gconf, {network = graph, clip = gconf.clip})
    self.network = network
    local batch_size, chunk_size = gconf.batch_size, gconf.chunk_size
    network:init(batch_size, chunk_size)

    local dim_in, dim_out = network.dim_in, network.dim_out

    -- Errors flowing back into the network inputs are never read by this
    -- base class, so a single matrix per input is reused for every time step.
    local err_output = {}
    self.err_output = err_output
    for i = 1, #dim_in do
        err_output[i] = {}
        local shared = mat_type(batch_size, dim_in[i])
        for t = 1, chunk_size do
            err_output[i][t] = shared
        end
    end

    -- Each network output gets its own matrix per time step. The error fed
    -- into an output is either gconf.mask[t] (for single-column outputs,
    -- presumably the criterion/loss output -- confirm) or one shared matrix
    -- filled with zeros.
    local output, err_input = {}, {}
    self.output = output
    self.err_input = err_input
    for i = 1, #dim_out do
        output[i] = {}
        for t = 1, chunk_size do
            output[i][t] = mat_type(batch_size, dim_out[i])
        end
        err_input[i] = {}
        local zeros = mat_type(batch_size, dim_out[i])
        zeros:fill(0)
        for t = 1, chunk_size do
            if dim_out[i] == 1 then
                err_input[i][t] = gconf.mask[t]
            else
                err_input[i][t] = zeros
            end
        end
    end
end

--- Create the data buffer that feeds mini-batches to `process`.
-- Frame-level training (chunk_size == 1) uses a nerv.FrmBuffer (optionally
-- randomized); sequence training uses a nerv.SeqBuffer.
-- @param readers value returned by `get_readers(dataset)`
-- @return buffer object; `process` pulls batches with `get_data()` until nil
function trainer:make_buffer(readers)
    local gconf = self.gconf
    if gconf.chunk_size == 1 then
        return nerv.FrmBuffer(gconf, {
            buffer_size = gconf.buffer_size,
            batch_size = gconf.batch_size,
            chunk_size = gconf.chunk_size,
            randomize = gconf.randomize,
            readers = readers,
            -- Keep the buffer on the same side as the training matrices
            -- (host when use_cpu, device otherwise); previously this was
            -- hard-coded to true, requesting a GPU buffer for CPU training.
            use_gpu = not gconf.use_cpu,
        })
    else
        return nerv.SeqBuffer(gconf, {
            batch_size = gconf.batch_size,
            chunk_size = gconf.chunk_size,
            readers = readers,
        })
    end
end

--- Run one pass over `dataset`, training on it when `do_train` is truthy.
-- For every mini-batch drawn from the buffer this builds an `info` table,
-- runs the forward pass and, if training, back-propagation and the update,
-- calling the mini-batch hooks around those stages and forcing a full GC
-- afterwards.
-- @param dataset passed through to get_readers and the epoch hooks
-- @param do_train truthy to back-propagate and update parameters
-- @return whatever `self:get_error()` returns
function trainer:process(dataset, do_train)
    self:epoch_preprocess(dataset, do_train)
    local buf = self:make_buffer(self:get_readers(dataset))
    local n_batches = 0
    -- Capture the shared state once per epoch, after the buffer is built.
    local net = self.network
    local order = self.input_order
    local outs, errs_in, errs_out = self.output, self.err_input, self.err_output
    net:epoch_init()

    local batch = buf:get_data()
    while batch ~= nil do
        n_batches = n_batches + 1

        -- Map network input slot k to the buffer entry named by order[k].
        local inputs = {}
        for k = 1, #net.dim_in do
            inputs[k] = batch.data[order[k]]
        end
        local info = {
            input = inputs,
            output = outs,
            err_input = errs_in,
            err_output = errs_out,
            do_train = do_train,
            seq_length = batch.seq_length,
            new_seq = batch.new_seq,
        }

        self:mini_batch_preprocess(n_batches, info)
        net:mini_batch_init(info)
        net:propagate()
        self:mini_batch_middleprocess(n_batches, info)
        if do_train then
            net:back_propagate()
            net:update()
        end
        self:mini_batch_afterprocess(n_batches, info)

        collectgarbage('collect')
        batch = buf:get_data()
    end

    self:epoch_afterprocess(dataset, do_train)
    return self:get_error()
end

--- Checkpoint the current parameters and apply learning-rate halving.
-- The network parameters are copied to the host and exported to a file in
-- `gconf.working_dir` whose name encodes date, iteration, learning rate and
-- both errors. If `cv_err` beats `gconf.best_cv`, that file becomes the new
-- `gconf.initialized_param`. Otherwise the file gets a `.rejected` suffix,
-- the previous parameters are reloaded and rebound into the layer repo, and
-- `gconf.lrate` is halved.
-- @param train_err training error of the finished iteration
-- @param cv_err cross-validation error of the finished iteration
function trainer:halving(train_err, cv_err)
    local gconf = self.gconf
    local stamp = os.date(gconf.date_pattern)
    local fname = path.join(gconf.working_dir,
        string.format('%s_iter_%d_lr%f_tr%.3f_cv%.3f.nerv',
            stamp, gconf.cur_iter, gconf.lrate, train_err, cv_err))

    local snapshot = self.network:get_params():copy(self.src_loc_type, gconf)
    snapshot:export(fname)

    if cv_err < gconf.best_cv then
        nerv.info("accepting the trained params")
        gconf.best_cv = cv_err
        gconf.initialized_param = {fname}
        return
    end

    -- Reject: keep the file for inspection, roll back to the last accepted
    -- parameters, and halve the learning rate.
    nerv.info("rejecting the trained params, rollback to the previous one")
    file.move(fname, fname .. '.rejected')
    local previous = nerv.ParamRepo()
    previous:import(gconf.initialized_param, gconf)
    self.layer_repo:rebind(previous:copy(self.train_loc_type, gconf))
    gconf.lrate = gconf.lrate * 0.5
end

--- Hook: no-op by default. Not called from this file; presumably invoked by
-- the driving script before training begins -- confirm with caller.
function trainer:training_preprocess()
end

--- Hook: no-op by default. Not called from this file; presumably invoked by
-- the driving script after training ends -- confirm with caller.
function trainer:training_afterprocess()
end

--- Hook: called at the start of `process`, before the buffer is created.
-- @param dataset the dataset passed to `process`
-- @param do_train whether this pass trains
function trainer:epoch_preprocess(dataset, do_train)
end

--- Hook: called at the end of `process`, just before `get_error()`.
-- @param dataset the dataset passed to `process`
-- @param do_train whether this pass trains
function trainer:epoch_afterprocess(dataset, do_train)
end

--- Hook: called for each mini-batch before `network:mini_batch_init(info)`,
-- so it may still modify `info`.
-- @param cnt 1-based mini-batch counter within the current pass
-- @param info table with input, output, err_input, err_output, do_train,
--   seq_length and new_seq
function trainer:mini_batch_preprocess(cnt, info)
end

--- Hook: called after the forward pass (`network:propagate()`) and before
-- any back-propagation.
-- @param cnt 1-based mini-batch counter within the current pass
-- @param info the mini-batch info table
function trainer:mini_batch_middleprocess(cnt, info)
end

--- Hook: called after the update step (or right after the middle hook when
-- not training) and before the per-batch garbage collection.
-- @param cnt 1-based mini-batch counter within the current pass
-- @param info the mini-batch info table
function trainer:mini_batch_afterprocess(cnt, info)
end

--- Abstract: build the layer repository from `param_repo`.
-- The result is stored as `self.layer_repo`, passed to `get_network`, and
-- later `rebind`-ed to restored parameters by `halving`.
-- @param param_repo ParamRepo located at the training location
function trainer:make_layer_repo(param_repo)
    nerv.error_method_not_implemented()
end

--- Abstract: return the network graph built from `layer_repo`.
-- The result is passed as the `network` option to nerv.Network.
-- @param layer_repo the repository returned by make_layer_repo
function trainer:get_network(layer_repo)
    nerv.error_method_not_implemented()
end

--- Abstract: return the readers for `dataset`; the result is passed to
-- make_buffer as its `readers` argument.
-- @param dataset the dataset passed to `process`
function trainer:get_readers(dataset)
    nerv.error_method_not_implemented()
end

--- Abstract: return a list whose i-th entry is the key in a batch's `data`
-- table that feeds network input i (see `process`).
function trainer:get_input_order()
    nerv.error_method_not_implemented()
end