-- nerv/nn/trainer.lua
-- The generic nerv.Trainer class: a training skeleton whose dataset- and
-- model-specific pieces are supplied by subclasses via the abstract
-- methods at the bottom of this file.
local trainer = nerv.class('nerv.Trainer')

-- construct a trainer from the global configuration table gconf
function trainer:__init(gconf)
    local mat_type
    self.gconf = gconf
    self.src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
    local src_loc_type = self.src_loc_type
    if gconf.use_cpu then
        mat_type = gconf.mmat_type
        self.train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
    else
        mat_type = gconf.cumat_type
        self.train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
    end
    local train_loc_type = self.train_loc_type
    local host_param_repo = nerv.ParamRepo()
    -- import the parameters from chunk files
    host_param_repo:import(gconf.initialized_param, gconf)
    local param_repo = host_param_repo:copy(train_loc_type, gconf)
    -- create layers and establish initial bindings
    self.layer_repo = self:make_layer_repo(param_repo)
    local layer_repo = self.layer_repo
    -- compile the network to be trained
    local graph = self:get_network(layer_repo)
    self.input_order = self:get_input_order()
    self.network = nerv.Network('network', gconf,
                                {network = graph,
                                 nn_act_default = gconf.nn_act_default})
    local network = self.network
    network:init(gconf.batch_size, gconf.chunk_size)

    local dim_in, dim_out = network.dim_in, network.dim_out
    -- allocate gradient buffers for the network inputs; their contents are
    -- never consumed, so a single dummy matrix is shared across time steps
    self.err_output = {}
    local err_output = self.err_output
    for i = 1, #dim_in do
        err_output[i] = {}
        local dummy = mat_type(gconf.batch_size, dim_in[i])
        for t = 1, gconf.chunk_size do
            table.insert(err_output[i], dummy)
        end
    end
    -- allocate output buffers and the gradient seeds fed into the outputs
    self.output = {}
    self.err_input = {}
    local output = self.output
    local err_input = self.err_input
    for i = 1, #dim_out do
        output[i] = {}
        for t = 1, gconf.chunk_size do
            table.insert(output[i], mat_type(gconf.batch_size, dim_out[i]))
        end
        err_input[i] = {}
        if dim_out[i] ~= 1 then
            nerv.warning("the dimension of output " .. i .. " is not 1, " ..
                         "the default `err_input` will be zero")
        end
        for t = 1, gconf.chunk_size do
            if dim_out[i] == 1 then
                -- a one-dimensional output is treated as a per-frame loss;
                -- its gradient seed is the frame mask
                table.insert(err_input[i], gconf.mask[t])
            else
                table.insert(err_input[i], mat_type(gconf.batch_size, dim_out[i]))
                err_input[i][t]:fill(0)
            end
        end
    end
end
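
--[[ A sketch of the fields __init reads from gconf (the concrete values
below are illustrative assumptions, not defaults defined in this file):

    gconf = {
        use_cpu = false,                   -- train on the device
        mmat_type = nerv.MMatrixFloat,     -- host matrix type
        cumat_type = nerv.CuMatrixFloat,   -- device matrix type
        batch_size = 256,
        chunk_size = 1,                    -- 1 => feed-forward training
        initialized_param = {'init.nerv'}, -- parameter chunk files
        nn_act_default = 0,
        mask = masks,                      -- per-step batch_size x 1 frame
                                           -- masks (hypothetical variable,
                                           -- prepared by the driver)
    }
--]]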

function trainer:make_buffer(readers)
    local gconf = self.gconf
    if gconf.chunk_size == 1 then
        -- frame-level training: use a randomizing frame buffer
        return nerv.FrmBuffer(gconf, {
            buffer_size = gconf.buffer_size,
            batch_size = gconf.batch_size,
            chunk_size = gconf.chunk_size,
            randomize = gconf.randomize,
            readers = readers,
            use_gpu = true,
        })
    else
        -- sequence-level training: keep the frames of a sequence in order
        return nerv.SeqBuffer(gconf, {
            batch_size = gconf.batch_size,
            chunk_size = gconf.chunk_size,
            readers = readers,
        })
    end
end
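
--[[ The readers argument comes from get_readers; each entry is assumed to
pair a reader object with the data ids (and frame widths) it produces:

    {
        {reader = some_reader,        -- anything the buffers can poll
         data = {main_scp = 440,      -- data id -> frame width
                 phone_state = 1}},
    }

(some_reader, main_scp and phone_state are hypothetical names.)
--]]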

-- run one epoch over dataset; update parameters only when do_train is true
function trainer:process(dataset, do_train)
    self:epoch_preprocess(dataset, do_train)
    local buffer = self:make_buffer(self:get_readers(dataset))
    local cnt = 0
    local network = self.network
    local input_order = self.input_order
    local output = self.output
    local err_input = self.err_input
    local err_output = self.err_output
    network:epoch_init()

    for data in buffer.get_data, buffer do
        cnt = cnt + 1
        local info = {input = {},
                      output = output,
                      err_input = err_input,
                      err_output = err_output,
                      do_train = do_train,
                      seq_length = data.seq_length,
                      new_seq = data.new_seq}

        -- arrange the readers' outputs to match the network input order
        for i = 1, #network.dim_in do
            info.input[i] = data.data[input_order[i]]
        end

        self:mini_batch_preprocess(cnt, info)
        network:mini_batch_init(info)
        network:propagate()
        self:mini_batch_inprocess(cnt, info)
        if do_train then
            -- backward pass and parameter update only when training
            network:back_propagate()
            network:update()
        end
        self:mini_batch_afterprocess(cnt, info)

        -- reclaim the temporaries allocated during this mini-batch
        collectgarbage('collect')
    end

    self:epoch_afterprocess(dataset, do_train)
    return self:get_error()
end
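
--[[ Each data item yielded by the buffer is assumed to have the shape
consumed above:

    {
        data = {main_scp = {...},     -- one matrix per time step, keyed
                phone_state = {...}}, -- by reader data id
        seq_length = {...},           -- current length of each stream
        new_seq = {...},              -- streams starting a new sequence
    }

(main_scp and phone_state are hypothetical ids; see get_readers.)
--]]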

function trainer:if_accept(cv_err)
    return cv_err < self.gconf.best_cv
end

function trainer:do_halving()
    local gconf = self.gconf
    gconf.lrate = gconf.lrate * gconf.hfactor
end
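
-- For example, with gconf.lrate = 0.8 and gconf.hfactor = 0.5, successive
-- rejected iterations drop the rate to 0.4, then 0.2, and so on.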

-- export the trained parameters, then either accept them or roll back to
-- the previous ones and halve the learning rate, based on the cv error
function trainer:save_params(train_err, cv_err)
    local gconf = self.gconf
    local src_loc_type = self.src_loc_type
    local train_loc_type = self.train_loc_type
    local layer_repo = self.layer_repo
    local param_fname = string.format('%s_iter_%d_lr%f_tr%.3f_cv%.3f.nerv',
                                      os.date(gconf.date_pattern),
                                      gconf.cur_iter,
                                      gconf.lrate,
                                      train_err,
                                      cv_err)
    param_fname = path.join(gconf.working_dir, param_fname)
    local network = self.network
    -- copy the trained parameters back to the host before exporting
    local host_param_repo = network:get_params():copy(src_loc_type, gconf)
    host_param_repo:export(param_fname)

    if self:if_accept(cv_err) then
        nerv.info("accepting the trained params")
        gconf.best_cv = cv_err
        gconf.initialized_param = {param_fname}
    else
        nerv.info("rejecting the trained params, rollback to the previous one")
        file.move(param_fname, param_fname .. '.rejected')
        host_param_repo = nerv.ParamRepo()
        host_param_repo:import(gconf.initialized_param, gconf)
        local param_repo = host_param_repo:copy(train_loc_type, gconf)
        -- rebind the parameters
        layer_repo:rebind(param_repo)
        self:do_halving()
    end
end
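
--[[ Usage sketch of the trainer as a whole (a hypothetical driver script;
MyTrainer, max_iter, train_set and cv_set are assumptions, not names from
this file):

    local trainer = nerv.MyTrainer(gconf)
    trainer:training_preprocess()
    gconf.best_cv = trainer:process(cv_set, false)  -- baseline cv error
    for i = 1, gconf.max_iter do
        gconf.cur_iter = i
        local train_err = trainer:process(train_set, true)
        local cv_err = trainer:process(cv_set, false)
        trainer:save_params(train_err, cv_err)      -- accept or roll back
    end
    trainer:training_afterprocess()
--]]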

-- overridable hooks; the default implementations are no-ops
function trainer:training_preprocess()
end

function trainer:training_afterprocess()
end

function trainer:epoch_preprocess(dataset, do_train)
end

function trainer:epoch_afterprocess(dataset, do_train)
end

function trainer:mini_batch_preprocess(cnt, info)
end

function trainer:mini_batch_inprocess(cnt, info)
end

function trainer:mini_batch_afterprocess(cnt, info)
end

-- abstract: build and return the nerv.LayerRepo holding all layers
function trainer:make_layer_repo(param_repo)
    nerv.error_method_not_implemented()
end

-- abstract: return the graph layer to be compiled into the network
function trainer:get_network(layer_repo)
    nerv.error_method_not_implemented()
end

-- abstract: return the reader list for a dataset (see make_buffer)
function trainer:get_readers(dataset)
    nerv.error_method_not_implemented()
end

-- abstract: return the reader data ids in network input order
function trainer:get_input_order()
    nerv.error_method_not_implemented()
end

-- abstract: return the error measure accumulated over the last epoch
function trainer:get_error()
    nerv.error_method_not_implemented()
end
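
--[[ A minimal sketch of a concrete subclass (layer ids, dimensions and the
reader construction below are illustrative assumptions, not names from this
file):

    local MyTrainer = nerv.class('nerv.MyTrainer', 'nerv.Trainer')

    function MyTrainer:make_layer_repo(param_repo)
        return nerv.LayerRepo(
            {
                ['nerv.AffineLayer'] =
                {
                    affine0 = {dim_in = {440}, dim_out = {1024},
                               params = {ltp = 'affine0_ltp',
                                         bp = 'affine0_bp'}},
                },
                ['nerv.SigmoidLayer'] =
                {
                    sigmoid0 = {dim_in = {1024}, dim_out = {1024}},
                },
                -- ... criterion layer, graph layer 'main', etc.
            }, param_repo, self.gconf)
    end

    function MyTrainer:get_network(layer_repo)
        return layer_repo:get_layer('main')
    end

    function MyTrainer:get_readers(dataset)
        local reader = make_reader(dataset) -- hypothetical reader factory
        return {{reader = reader,
                 data = {main_scp = 440, phone_state = 1}}}
    end

    function MyTrainer:get_input_order()
        return {'main_scp', 'phone_state'}
    end

    function MyTrainer:get_error()
        -- e.g. derive the frame error from an accumulating criterion layer
        local ce = self.layer_repo:get_layer('ce_crit')
        return ce.total_ce / ce.total_frames
    end
--]]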