require 'lfs'
require 'pl'

-- ==========================================================
--  Deal with command line input & init training environment
-- ==========================================================

-- Return the keys of `t` sorted by their string form, so that help and gconf
-- listings come out in a stable order (pairs() order is unspecified).
local function sorted_keys(t)
    local keys = {}
    for k in pairs(t) do
        keys[#keys + 1] = k
    end
    table.sort(keys, function(a, b) return tostring(a) < tostring(b) end)
    return keys
end

--- Merge trainer settings into the global `gconf`.
-- Precedence: command-line option > network config (`gconf`) > default in `spec`.
-- When `--resume-from` is given, `gconf` is first replaced by the table that
-- file returns (e.g. an `iter_N.meta` written by dump_gconf); command-line
-- options that were explicitly specified still override it afterwards.
-- @param spec table mapping config keys (underscore form) to default values
-- @param opts parsed options from nerv.parse_args, keyed by dashed option name
local function check_and_add_defaults(spec, opts)
    -- option names use dashes where config keys use underscores
    local function get_opt_val(k)
        local name = string.gsub(k, '_', '-')
        return opts[name].val, opts[name].specified
    end
    local resume_file = get_opt_val("resume_from")
    if resume_file then
        nerv.info("resuming from previous training state")
        gconf = dofile(resume_file)
    end
    if type(gconf) ~= "table" then
        nerv.error("[trainer] gconf is not a table; the network config (or resume file) must define it")
    end
    for k in pairs(spec) do
        local opt_v, specified = get_opt_val(k)
        if (not specified) and gconf[k] ~= nil then
            nerv.info("using setting in network config file: %s = %s", k, gconf[k])
        elseif opt_v ~= nil then
            -- either explicitly specified, or the default from `spec`
            nerv.info("using setting in options: %s = %s", k, opt_v)
            gconf[k] = opt_v
        end
    end
end

--- Build nerv.parse_args option specs from a table of defaults.
-- Each key `aaa_bbb` becomes option `--aaa-bbb`, typed after its default value.
-- @param spec table mapping config keys to default values
-- @return array of option specs `{name, short, type, default = v}`
local function make_options(spec)
    local options = {}
    for _, k in ipairs(sorted_keys(spec)) do
        local v = spec[k]
        -- extra parentheses keep only the first result of gsub
        table.insert(options, {(string.gsub(k, '_', '-')), nil, type(v), default = v})
    end
    return options
end

--- Print the usage line followed by the option listing.
local function print_help(options)
    nerv.printf("Usage: [options] <network config> [config args]\n")
    nerv.print_usage(options)
end

--- Print every gconf entry as an aligned "key = value" table, sorted by key.
local function print_gconf()
    local keys = sorted_keys(gconf)
    local key_maxlen = #"Key"
    for _, k in ipairs(keys) do
        key_maxlen = math.max(key_maxlen, #tostring(k))
    end
    local pattern = string.format("%%-%ds = %%s\n", key_maxlen)
    nerv.info("ready to train with the following gconf settings:")
    nerv.printf(pattern, "Key", "Value")
    for _, k in ipairs(keys) do
        -- tostring() so that `false` prints as "false" (not "") and
        -- tables/functions are formatted regardless of Lua version
        nerv.printf(pattern, tostring(k), tostring(gconf[k]))
    end
end

--- Write gconf to `fname` as a Lua chunk (`return <table>`), so it can be
-- reloaded with dofile() through --resume-from.
-- Raises if the file cannot be opened for writing.
local function dump_gconf(fname)
    local f = assert(io.open(fname, "w"))
    f:write("return ")
    f:write(table.tostring(gconf))
    f:close()
end

local trainer_defaults = {
    lrate = 0.8,
    hfactor = 0.5,
    batch_size = 256,
    chunk_size = 1,
    buffer_size = 81920,
    wcost = 1e-6,
    momentum = 0.9,
    cur_iter = 1,
    max_iter = 20,
    randomize = true,
    cumat_tname = "nerv.CuMatrixFloat",
    mmat_tname = "nerv.MMatrixFloat",
    trainer_tname = "nerv.Trainer",
}

local options = make_options(trainer_defaults)
-- options that have no entry in trainer_defaults
local extra_opt_spec = {
    {"clip", nil, "number"},
    {"resume-from", nil, "string"},
    {"help", "h", "boolean", default = false, desc = "show this help information"},
    {"dir", nil, "string", desc = "specify the working directory"},
}
table.extend(options, extra_opt_spec)

local opts
arg, opts = nerv.parse_args(arg, options)
if #arg < 1 or opts["help"].val then
    print_help(options)
    return
end

-- the first positional argument is the network config script; the remaining
-- ones are handed to it through the global `arg`
local script = arg[1]
local script_arg = {}
for i = 2, #arg do
    table.insert(script_arg, arg[i])
end
arg = script_arg
dofile(script)

--[[
Rule: command-line option overrides network config overrides trainer default.
Note: any trainer_defaults key like aaa_bbbb_cc can be overridden by
specifying --aaa-bbbb-cc on the command line.
]]--
check_and_add_defaults(trainer_defaults, opts)
-- `--clip` has no trainer default, so check_and_add_defaults never applies it;
-- copy it explicitly (when given it overrides the network config)
if opts["clip"].val ~= nil then
    nerv.info("using setting in options: %s = %s", "clip", opts["clip"].val)
    gconf.clip = opts["clip"].val
end
gconf.mmat_type = nerv.get_type(gconf.mmat_tname)
gconf.cumat_type = nerv.get_type(gconf.cumat_tname)
gconf.trainer = nerv.get_type(gconf.trainer_tname)
-- NOTE(review): `econf` is assumed to be a global set by the nerv launcher;
-- use_cpu defaults to false when it is absent
gconf.use_cpu = (econf ~= nil and econf.use_cpu) or false
if gconf.initialized_param == nil then
    gconf.initialized_param = {}
end
if gconf.param_random == nil then
    -- uniform in [-0.1, 0.1)
    gconf.param_random = function() return math.random() / 5 - 0.1 end
end
local date_pattern = "%Y-%m-%d_%H:%M:%S"
local logfile_name = "log"
local working_dir = opts["dir"].val or string.format("nerv_%s", os.date(date_pattern))
gconf.working_dir = working_dir
gconf.date_pattern = date_pattern

print_gconf()
-- lfs.mkdir fails if the directory already exists, so a previous run is never
-- overwritten; report lfs's actual reason for any failure
local mkdir_ok, mkdir_err = lfs.mkdir(working_dir)
if not mkdir_ok then
    nerv.error("[trainer] cannot create working directory %s: %s", working_dir, tostring(mkdir_err))
end
-- copy the network config
dir.copyfile(script, working_dir)
-- set logfile path
nerv.set_logfile(path.join(working_dir, logfile_name))

-- ============
--  Main loop
-- ============
local trainer = gconf.trainer(gconf)
trainer:training_preprocess()
gconf.best_cv = trainer:process('validate', false)
nerv.info("initial cross validation: %.3f", gconf.best_cv)
for i = gconf.cur_iter, gconf.max_iter do
    gconf.cur_iter = i
    -- the dump records cur_iter = i, so resuming from it restarts the loop here
    dump_gconf(path.join(working_dir, string.format("iter_%d.meta", i)))
    nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
    local train_err = trainer:process('train', true)
    nerv.info("[TR] training set %d: %.3f", i, train_err)
    local cv_err = trainer:process('validate', false)
    nerv.info("[CV] cross validation %d: %.3f", i, cv_err)
    if gconf.test then
        local test_err = trainer:process('test', false)
        nerv.info('[TE] testset error %d: %.3f', i, test_err)
    end
    trainer:save_params(train_err, cv_err)
end
dump_gconf(path.join(working_dir, string.format("iter_%d.meta", gconf.max_iter + 1)))
trainer:training_afterprocess()