require 'lfs'
require 'pl'
-- =======================================================
-- Parse command-line input and initialize the training environment
-- =======================================================
--- Populate the global `gconf` from trainer defaults and command-line options,
-- or replace it wholesale with a saved state when `--resume-from` is given.
-- Without --resume-from, each key of `spec` is resolved as follows:
--   1. if the option was not given on the command line and `gconf` already
--      holds a value (set by the network config), that value is kept;
--   2. otherwise the option's value (explicit or its default) is stored.
-- @tparam table spec map of snake_case config key -> default value
-- @tparam table opts options returned by nerv.parse_args, keyed by the
--   dash-separated option name; each entry exposes `.val` and `.specified`
local function check_and_add_defaults(spec, opts)
    -- config keys use '_' but command-line option names use '-'
    local function lookup(key)
        local entry = opts[(string.gsub(key, '_', '-'))]
        return entry.val, entry.specified
    end

    local resume_path = lookup("resume_from")
    if resume_path then
        nerv.info("resuming from previous training state")
        gconf = dofile(resume_path)
        return
    end

    for key in pairs(spec) do
        local val, specified = lookup(key)
        local keep_config_value = (not specified) and gconf[key] ~= nil
        if keep_config_value then
            nerv.info("using setting in network config file: %s = %s", key, gconf[key])
        elseif val ~= nil then
            nerv.info("using setting in options: %s = %s", key, val)
            gconf[key] = val
        end
    end
end
--- Turn a table of defaults into an option list for nerv.parse_args.
-- Each key `foo_bar` becomes the long option `foo-bar`, with no short
-- alias, whose type is the Lua type of its default value.
-- @tparam table spec map of config key -> default value
-- @treturn table array of option specs `{name, nil, type, default = value}`
local function make_options(spec)
    local result = {}
    for name, default in pairs(spec) do
        local long_name = (string.gsub(name, '_', '-'))
        result[#result + 1] = {long_name, nil, type(default), default = default}
    end
    return result
end
--- Print the usage line followed by the option table.
-- @tparam table options option specs, as passed to nerv.parse_args
local function print_help(options)
    local usage_line = "Usage: <asr_trainer.lua> [options] network_config.lua\n"
    nerv.printf(usage_line)
    nerv.print_usage(options)
end
--- Print every entry of the global `gconf` as an aligned "key = value" table.
-- An info banner and a "Key = Value" header row precede the entries.
-- Keys and values go through tostring, so boolean settings such as
-- `randomize = false` are shown rather than blanked out, and function or
-- table values (e.g. `param_random`) are displayed instead of being handed
-- raw to the %s formatter. Rows are sorted by key so the output is stable.
local function print_gconf()
    local rows = {}
    -- the column must be at least as wide as the "Key" header
    local key_maxlen = #"Key"
    for k, v in pairs(gconf) do
        local key_str, val_str = tostring(k), tostring(v)
        rows[#rows + 1] = {key_str, val_str}
        key_maxlen = math.max(key_maxlen, #key_str)
    end
    table.sort(rows, function(a, b) return a[1] < b[1] end)
    -- e.g. "%-12s = %s\n": left-align keys in a fixed-width column
    local pattern = string.format("%%-%ds = %%s\n", key_maxlen)
    nerv.info("ready to train with the following gconf settings:")
    nerv.printf(pattern, "Key", "Value")
    for _, row in ipairs(rows) do
        nerv.printf(pattern, row[1], row[2])
    end
end
--- Write the global `gconf` to `fname` as a Lua chunk of the form
-- "return <table>", so that dofile(fname) yields the table again.
-- The table is serialized before the file is opened. A serialization
-- error therefore neither truncates an existing file nor leaks a handle.
-- @tparam string fname destination path (overwritten)
local function dump_gconf(fname)
    local body = table.tostring(gconf)
    local f, err = io.open(fname, "w")
    if not f then
        nerv.error("[trainer] cannot open %s for writing: %s", fname, tostring(err))
    end
    f:write("return ", body)
    f:close()
end
-- Trainer-level defaults. Each key is also exposed as a command-line
-- option with underscores replaced by dashes (e.g. batch_size -> --batch-size).
local trainer_defaults = {
    lrate = 0.8,
    batch_size = 256,
    chunk_size = 1,
    buffer_size = 81920,
    wcost = 1e-6,
    momentum = 0.9,
    cur_iter = 1,
    max_iter = 20,
    randomize = true,
    cumat_tname = "nerv.CuMatrixFloat",
    mmat_tname = "nerv.MMatrixFloat",
    trainer_tname = "nerv.Trainer",
}
local options = make_options(trainer_defaults)
-- options that do not correspond to a trainer default
table.extend(options, {
    {"resume-from", nil, "string"},
    {"help", "h", "boolean", default = false, desc = "show this help information"},
    {"dir", nil, "string", desc = "specify the working directory"},
})
local opts
arg, opts = nerv.parse_args(arg, options)
if opts["help"].val or #arg < 1 then
    print_help(options)
    return
end
-- The first positional argument is the network config script. The remaining
-- positional arguments are handed to that script through the global `arg`.
local script = arg[1]
local forwarded = {}
for idx = 2, #arg do
    forwarded[#forwarded + 1] = arg[idx]
end
arg = forwarded
dofile(script)
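--[[
Rule: command-line options override the network config, which overrides
the trainer defaults.
Note: a config key like aaa_bbbb_cc can be overridden by passing
--aaa-bbbb-cc on the command line. When --resume-from is given, gconf is
instead loaded from the saved file and this merge is skipped.
]]--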
check_and_add_defaults(trainer_defaults, opts)
-- resolve the configured type names into the actual types
gconf.mmat_type = nerv.get_type(gconf.mmat_tname)
gconf.cumat_type = nerv.get_type(gconf.cumat_tname)
gconf.trainer = nerv.get_type(gconf.trainer_tname)
-- If the global `econf` is absent, treat that like use_cpu being unset
-- instead of raising "attempt to index a nil value". The result is
-- unchanged when econf exists.
gconf.use_cpu = (econf ~= nil and econf.use_cpu) or false
if gconf.initialized_param == nil then
    gconf.initialized_param = {}
end
if gconf.param_random == nil then
    -- math.random() is in [0, 1), so this draws uniformly from [-0.1, 0.1)
    gconf.param_random = function() return math.random() / 5 - 0.1 end
end
local date_pattern = "%Y-%m-%d_%H:%M:%S"
local logfile_name = "log"
local working_dir = opts["dir"].val or string.format("nerv_%s", os.date(date_pattern))
gconf.working_dir = working_dir
gconf.date_pattern = date_pattern
print_gconf()
-- On failure, lfs.mkdir returns nil plus a message. "Already exists" is only
-- one possible cause; a missing parent directory or missing permissions are
-- others. Report the real reason.
local mkdir_ok, mkdir_err = lfs.mkdir(working_dir)
if not mkdir_ok then
    nerv.error("[trainer] cannot create working directory %s: %s",
               working_dir, tostring(mkdir_err))
end
-- copy the network config
dir.copyfile(script, working_dir)
-- set logfile path
nerv.set_logfile(path.join(working_dir, logfile_name))
-- =============
-- main function
-- =============
-- Build the trainer and record the initial cross-validation error in
-- gconf.best_cv. Then run iterations gconf.cur_iter .. gconf.max_iter.
-- Before each iteration, gconf is dumped to <working_dir>/iter_<n>.meta,
-- which is the kind of file --resume-from loads with dofile.
local trainer = gconf.trainer(gconf)
trainer:training_preprocess()
gconf.best_cv = trainer:process('validate', false)
nerv.info("initial cross validation: %.3f", gconf.best_cv)
for iter = gconf.cur_iter, gconf.max_iter do
    gconf.cur_iter = iter
    local meta_path = path.join(working_dir, string.format("iter_%d.meta", iter))
    dump_gconf(meta_path)
    nerv.info("[NN] begin iteration %d with lrate = %.6f", iter, gconf.lrate)
    local tr_err = trainer:process('train', true)
    nerv.info("[TR] training set %d: %.3f", iter, tr_err)
    local cv_err = trainer:process('validate', false)
    nerv.info("[CV] cross validation %d: %.3f", iter, cv_err)
    if gconf.test then
        local te_err = trainer:process('test', false)
        nerv.info('[TE] testset error %d: %.3f', iter, te_err)
    end
    -- trainer:halving receives both errors; presumably it adjusts the
    -- learning rate / best model (defined in the trainer class, not here)
    trainer:halving(tr_err, cv_err)
end
trainer:training_afterprocess()