aboutsummaryrefslogblamecommitdiff
path: root/nerv/examples/trainer.lua
blob: caed2e26bd38bceb88cc64d296f574b0c2f53c6a (plain) (tree)
1
2
3
4
5
6


             


                                                            









                                                          







                                                                                   













                                                                          
                                                                        


























                                                                  
                  






                        
                     






                                              
                            












































                                                                                   

                                                                   












                                                            


               

















                                                                          
                                          
   
                                                                                     
                               
require 'lfs'
require 'pl'

-- =========================================================
--  Deal with command line input & init training environment
-- =========================================================

--- Fill the global `gconf` with trainer settings.
-- For every key of `spec`, the matching command-line option (underscores in
-- the key replaced by dashes) is looked up in `opts`. A value already present
-- in `gconf` is kept unless the option was flagged as `specified`; otherwise
-- the option value (when not nil) is written into `gconf`.
-- If the `resume-from` option carries a value, `gconf` is first replaced by
-- the result of running that file with `dofile`.
-- @param spec table whose keys are the gconf setting names to process
-- @param opts table indexed by option name; each entry exposes `.val` and
--             `.specified` (presumably "explicitly given on the command
--             line" -- confirm against nerv.parse_args)
local function check_and_add_defaults(spec, opts)
    -- Map a gconf key to its option entry and return (value, specified).
    local function lookup(key)
        local opt_name = string.gsub(key, '_', '-')
        return opts[opt_name].val, opts[opt_name].specified
    end

    local resume_path = lookup("resume_from")
    if resume_path then
        nerv.info("resuming from previous training state")
        gconf = dofile(resume_path)
    end

    for key in pairs(spec) do
        local value, specified = lookup(key)
        local keep_config = (not specified) and gconf[key] ~= nil
        if keep_config then
            nerv.info("using setting in network config file: %s = %s",
                      key, gconf[key])
        elseif value ~= nil then
            nerv.info("using setting in options: %s = %s", key, value)
            gconf[key] = value
        end
    end
end

--- Build a list of option descriptors from a table of default settings.
-- Each key/value pair of `spec` yields one descriptor of the form
-- `{name, nil, type_name, default = value}`, where `name` is the key with
-- underscores replaced by dashes and `type_name` is `type(value)`.
-- Descriptor order follows `pairs` and is therefore unspecified.
-- @param spec table mapping setting names to their default values
-- @return array of option descriptors
local function make_options(spec)
    local result = {}
    for key, default_value in pairs(spec) do
        -- extra parentheses keep only the first result of gsub
        local opt_name = (string.gsub(key, '_', '-'))
        local descriptor = {opt_name, nil, type(default_value),
                            default = default_value}
        result[#result + 1] = descriptor
    end
    return result
end

--- Print the usage banner, then hand the option list to nerv.print_usage.
-- @param options array of option descriptors passed on to nerv.print_usage
local function print_help(options)
    local usage_banner = "Usage: <trainer.lua> [options] <network_config.lua>\n"
    nerv.printf(usage_banner)
    nerv.print_usage(options)
end

--- Print every entry of the global `gconf` as an aligned "key = value" list.
-- Keys are left-justified and padded to the width of the longest key. Keys
-- and values go through `tostring`, so `false` settings show up as "false"
-- (instead of an empty string) and non-string keys or values cannot break
-- the length computation or the formatting.
local function print_gconf()
    local key_maxlen = 0
    for k in pairs(gconf) do
        key_maxlen = math.max(key_maxlen, #tostring(k))
    end
    -- The row pattern depends only on the key width, so build it once.
    local row_pattern = string.format("%%-%ds = %%s\n", key_maxlen)
    nerv.info("ready to train with the following gconf settings:")
    nerv.printf(row_pattern, "Key", "Value")
    for k, v in pairs(gconf) do
        nerv.printf(row_pattern, tostring(k), tostring(v))
    end
end

--- Write the global `gconf` to `fname` as a Lua chunk of the form
-- `return <table.tostring(gconf)>`.
-- The serialization happens before the file is opened, so a failure inside
-- `table.tostring` neither truncates an existing file nor leaks a handle.
-- Raises an error carrying the message from `io.open` when the file cannot
-- be opened for writing.
-- @param fname path of the file to (over)write
local function dump_gconf(fname)
    local serialized = table.tostring(gconf)
    local f, open_err = io.open(fname, "w")
    if not f then
        error(string.format("[trainer] cannot write gconf dump: %s",
                            tostring(open_err)))
    end
    f:write("return ", serialized)
    f:close()
end

-- Default trainer settings. Each key becomes a command-line option (via
-- make_options, underscores turned into dashes) and is later merged into
-- gconf by check_and_add_defaults.
local trainer_defaults = {
    lrate = 0.8,
    hfactor = 0.5,
    batch_size = 256,
    chunk_size = 1,
    buffer_size = 81920,
    wcost = 1e-6,
    momentum = 0.9,
    cur_iter = 1,
    max_iter = 20,
    randomize = true,
    cumat_tname = "nerv.CuMatrixFloat",
    mmat_tname = "nerv.MMatrixFloat",
    trainer_tname = "nerv.Trainer",
}

local options = make_options(trainer_defaults)
-- Options that have no entry in trainer_defaults.
-- NOTE(review): "clip" is parsed here but its value is not read anywhere
-- else in this file -- confirm whether it should reach gconf.
local extra_opt_spec = {
    {"clip", nil, "number"},
    {"resume-from", nil, "string"},
    {"help", "h", "boolean", default = false, desc = "show this help information"},
    {"dir", nil, "string", desc = "specify the working directory"},
}

table.extend(options, extra_opt_spec)

-- Parse the command line. `arg` is replaced by the first result of
-- nerv.parse_args (presumably the remaining positional arguments -- confirm)
-- and `opts` is indexed by option name, each entry exposing `.val`.
local opts
arg, opts = nerv.parse_args(arg, options)

-- A network config script is mandatory; without it (or with --help) only
-- the usage text is printed and the chunk returns.
if #arg < 1 or opts["help"].val then
    print_help(options)
    return
end

-- The first positional argument is the network config script; the remaining
-- ones become the global `arg` seen by that script when it is run.
local script = arg[1]
local script_arg = {}
for i = 2, #arg do
    table.insert(script_arg, arg[i])
end
arg = script_arg
-- Run the network config; the code below expects it to define the global
-- gconf.
dofile(script)

--[[

Rule: command-line option overrides network config overrides trainer default.
Note: a config key like aaa_bbbb_cc can be overridden by passing
--aaa-bbbb-cc on the command line.

]]--

-- Merge trainer defaults and command-line options into the global gconf,
-- then resolve the configured type names through nerv.get_type.
check_and_add_defaults(trainer_defaults, opts)
gconf.mmat_type = nerv.get_type(gconf.mmat_tname)
gconf.cumat_type = nerv.get_type(gconf.cumat_tname)
gconf.trainer = nerv.get_type(gconf.trainer_tname)
-- NOTE(review): econf is a global that is not defined in this file; it is
-- presumably provided by the launcher -- confirm.
gconf.use_cpu = econf.use_cpu or false
if gconf.initialized_param == nil then
    gconf.initialized_param = {}
end
if gconf.param_random == nil then
    -- default initializer: math.random() / 5 - 0.1, i.e. values in [-0.1, 0.1)
    gconf.param_random = function() return math.random() / 5 - 0.1 end
end

-- The working directory is taken from --dir, or named after the current
-- date and time when --dir is not given.
local date_pattern = "%Y-%m-%d_%H:%M:%S"
local logfile_name = "log"
local working_dir = opts["dir"].val or
                    string.format("nerv_%s", os.date(date_pattern))
gconf.working_dir = working_dir
gconf.date_pattern = date_pattern

print_gconf()
-- lfs.mkdir fails for reasons other than "already exists" (missing parent,
-- permissions, ...), so report the message it returns instead of guessing.
local created, mkdir_err = lfs.mkdir(working_dir)
if not created then
    nerv.error("[trainer] cannot create working directory %s: %s",
               working_dir, tostring(mkdir_err))
end

-- copy the network config
dir.copyfile(script, working_dir)
-- set logfile path
nerv.set_logfile(path.join(working_dir, logfile_name))

-- ============
--  Main loop
-- ============

-- Instantiate the trainer type and record the validation result obtained
-- before any training as the initial gconf.best_cv.
local trainer = gconf.trainer(gconf)
trainer:training_preprocess()
gconf.best_cv = trainer:process('validate', false)
nerv.info("initial cross validation: %.3f", gconf.best_cv)

for i = gconf.cur_iter, gconf.max_iter do
    gconf.cur_iter = i
    -- dump the state at the start of the iteration; the file can later be
    -- passed to --resume-from
    dump_gconf(path.join(working_dir, string.format("iter_%d.meta", i)))
    nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
    local train_err = trainer:process('train', true)
    nerv.info("[TR] training set %d: %.3f", i, train_err)
    local cv_err = trainer:process('validate', false)
    nerv.info("[CV] cross validation %d: %.3f", i, cv_err)
    -- the test pass only runs when the config sets gconf.test
    if gconf.test then
        local test_err = trainer:process('test', false)
        nerv.info('[TE] testset error %d: %.3f', i, test_err)
    end
    trainer:save_params(train_err, cv_err)
end
-- NOTE(review): gconf.cur_iter is not advanced after the loop, so this final
-- dump records the last executed iteration -- confirm that is intended for
-- resuming.
dump_gconf(path.join(working_dir, string.format("iter_%d.meta", gconf.max_iter + 1)))
trainer:training_afterprocess()