diff options
Diffstat (limited to 'nerv/examples/trainer.lua')
-rw-r--r-- | nerv/examples/trainer.lua | 166 |
1 files changed, 166 insertions, 0 deletions
nerv.include('trainer_class.lua')

require 'lfs'
require 'pl'

-- ========================================================
-- Deal with command line input & init training environment
-- ========================================================

--- Merge parsed command-line options into the global `gconf`.
-- If `--resume-from` was given, `gconf` is replaced wholesale by the table
-- returned from that file. Otherwise, for every key in `spec`:
--   * an option not specified on the command line leaves an existing
--     `gconf[k]` (set by the network config) untouched;
--   * otherwise a non-nil option value is written to `gconf[k]`.
-- @param spec table mapping config keys (underscore style) to defaults
-- @param opts option table returned by `nerv.parse_args`; each entry is
--             read through its `.val` and `.specified` fields
local function check_and_add_defaults(spec, opts)
    -- config keys use '_' while command-line option names use '-'
    local function get_opt_val(key)
        local opt_name = string.gsub(key, '_', '-')
        return opts[opt_name].val, opts[opt_name].specified
    end
    local resume_file = get_opt_val("resume_from")
    if resume_file then
        nerv.info("resuming from previous training state")
        gconf = dofile(resume_file)
    else
        for k in pairs(spec) do
            local opt_v, specified = get_opt_val(k)
            if (not specified) and gconf[k] ~= nil then
                nerv.info("using setting in network config file: %s = %s", k, gconf[k])
            elseif opt_v ~= nil then
                nerv.info("using setting in options: %s = %s", k, opt_v)
                gconf[k] = opt_v
            end
        end
    end
end

--- Build an option specification list from a table of defaults.
-- Each key becomes a long option with '_' replaced by '-', no short name,
-- the Lua type of its default as the option type, and the default itself.
-- @param spec table mapping config keys to default values
-- @return array of option descriptors
local function make_options(spec)
    local options = {}
    for k, v in pairs(spec) do
        -- string.gsub is not the last expression in the constructor, so
        -- only its first result (the converted name) is kept
        table.insert(options,
                    {string.gsub(k, '_', '-'), nil, type(v), default = v})
    end
    return options
end

--- Print the usage line followed by the description of all options.
-- @param options option descriptors to hand to `nerv.print_usage`
local function print_help(options)
    nerv.printf("Usage: <asr_trainer.lua> [options] network_config.lua\n")
    nerv.print_usage(options)
end

--- Print every key/value pair of the global `gconf` as an aligned table.
-- Keys and values are converted with `tostring` so that non-string keys and
-- table/function/boolean values are printed instead of raising an error or
-- being shown as an empty string.
local function print_gconf()
    local key_maxlen = 0
    for k in pairs(gconf) do
        key_maxlen = math.max(key_maxlen, #tostring(k))
    end
    -- left-align keys in a column as wide as the longest key
    local pattern = string.format("%%-%ds = %%s\n", key_maxlen)
    nerv.info("ready to train with the following gconf settings:")
    nerv.printf(pattern, "Key", "Value")
    for k, v in pairs(gconf) do
        nerv.printf(pattern, tostring(k), tostring(v))
    end
end

--- Write the global `gconf` to `fname` as a Lua chunk (`return <table>`).
-- Raises an error carrying the system message if the file cannot be opened.
-- @param fname path of the file to write
local function dump_gconf(fname)
    -- serialize before opening so a serialization failure neither truncates
    -- an existing file nor leaves a handle open
    local serialized = table.tostring(gconf)
    local f = assert(io.open(fname, "w"))
    f:write("return ")
    f:write(serialized)
    f:close()
end

local trainer_defaults = {
    lrate = 0.8,
    batch_size = 256,
    chunk_size = 1,
    buffer_size = 81920,
    wcost = 1e-6,
    momentum = 0.9,
    cur_iter = 1,
    max_iter = 20,
    cumat_tname = "nerv.CuMatrixFloat",
    mmat_tname = "nerv.MMatrixFloat",
    trainer_tname = "nerv.Trainer",
}

local options = make_options(trainer_defaults)
local extra_opt_spec = {
    {"resume-from", nil, "string"},
    {"help", "h", "boolean", default = false, desc = "show this help information"},
    {"dir", nil, "string", desc = "specify the working directory"},
}

table.extend(options, extra_opt_spec)

local opts
arg, opts = nerv.parse_args(arg, options)

if #arg < 1 or opts["help"].val then
    print_help(options)
    return
end

-- the first positional argument is the network config script; the remaining
-- ones are exposed to that script through the global `arg`
local script = arg[1]
local script_arg = {}
for i = 2, #arg do
    table.insert(script_arg, arg[i])
end
arg = script_arg
dofile(script)

--[[

Rule: command-line option overrides network config overrides trainer default.
Note: config key like aaa_bbbb_cc could be overridden by specifying
--aaa-bbbb-cc to command-line arguments.

]]--

check_and_add_defaults(trainer_defaults, opts)
gconf.mmat_type = nerv.get_type(gconf.mmat_tname)
gconf.cumat_type = nerv.get_type(gconf.cumat_tname)
gconf.trainer = nerv.get_type(gconf.trainer_tname)
-- `econf` is not defined anywhere in this file; presumably the network
-- config script may define it as a global (TODO confirm). Guard the lookup so
-- a config without it falls back to `false` instead of raising
-- "attempt to index global 'econf'".
gconf.use_cpu = (econf ~= nil and econf.use_cpu) or false
if gconf.initialized_param == nil then
    gconf.initialized_param = {}
end
if gconf.param_random == nil then
    gconf.param_random = function() return math.random() / 5 - 0.1 end
end

local date_pattern = "%Y-%m-%d_%H:%M:%S"
local logfile_name = "log"
local working_dir = opts["dir"].val or string.format("nerv_%s", os.date(date_pattern))
gconf.working_dir = working_dir
gconf.date_pattern = date_pattern

print_gconf()
-- NOTE(review): lfs.mkdir can fail for reasons other than an existing
-- directory (e.g. permissions); the message below is kept unchanged.
if not lfs.mkdir(working_dir) then
    nerv.error("[trainer] working directory already exists")
end

-- copy the network config
dir.copyfile(script, working_dir)
-- set logfile path
nerv.set_logfile(path.join(working_dir, logfile_name))

-- =============
-- main function
-- =============

local trainer = gconf.trainer(gconf)
trainer:training_preprocess()
gconf.best_cv = trainer:process('validate', false)
nerv.info("initial cross validation: %.3f", gconf.best_cv)

for i = gconf.cur_iter, gconf.max_iter do
    gconf.cur_iter = i
    -- snapshot the configuration at the start of each iteration; this is the
    -- kind of file `--resume-from` loads with dofile
    dump_gconf(path.join(working_dir, string.format("iter_%d.meta", i)))
    nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
    local train_err = trainer:process('train', true)
    nerv.info("[TR] training set %d: %.3f", i, train_err)
    local cv_err = trainer:process('validate', false)
    nerv.info("[CV] cross validation %d: %.3f", i, cv_err)
    if gconf.test then
        local test_err = trainer:process('test', false)
        nerv.info('[TE] testset error %d: %.3f', i, test_err)
    end
    trainer:halving(train_err, cv_err)
end
trainer:training_afterprocess()