aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/trainer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/trainer.lua')
-rw-r--r--nerv/examples/trainer.lua166
1 files changed, 166 insertions, 0 deletions
diff --git a/nerv/examples/trainer.lua b/nerv/examples/trainer.lua
new file mode 100644
index 0000000..b691f5b
--- /dev/null
+++ b/nerv/examples/trainer.lua
@@ -0,0 +1,166 @@
+nerv.include('trainer_class.lua')
+
+require 'lfs'
+require 'pl'
+
+-- =======================================================
+-- Deal with command line input & init training envrioment
+-- =======================================================
+
+local function check_and_add_defaults(spec, opts)
+ local function get_opt_val(k)
+ local k = string.gsub(k, '_', '-')
+ return opts[k].val, opts[k].specified
+ end
+ local opt_v = get_opt_val("resume_from")
+ if opt_v then
+ nerv.info("resuming from previous training state")
+ gconf = dofile(opt_v)
+ else
+ for k, v in pairs(spec) do
+ local opt_v, specified = get_opt_val(k)
+ if (not specified) and gconf[k] ~= nil then
+ nerv.info("using setting in network config file: %s = %s", k, gconf[k])
+ elseif opt_v ~= nil then
+ nerv.info("using setting in options: %s = %s", k, opt_v)
+ gconf[k] = opt_v
+ end
+ end
+ end
+end
+
+local function make_options(spec)
+ local options = {}
+ for k, v in pairs(spec) do
+ table.insert(options,
+ {string.gsub(k, '_', '-'), nil, type(v), default = v})
+ end
+ return options
+end
+
+local function print_help(options)
+ nerv.printf("Usage: <asr_trainer.lua> [options] network_config.lua\n")
+ nerv.print_usage(options)
+end
+
+local function print_gconf()
+ local key_maxlen = 0
+ for k, v in pairs(gconf) do
+ key_maxlen = math.max(key_maxlen, #k or 0)
+ end
+ local function pattern_gen()
+ return string.format("%%-%ds = %%s\n", key_maxlen)
+ end
+ nerv.info("ready to train with the following gconf settings:")
+ nerv.printf(pattern_gen(), "Key", "Value")
+ for k, v in pairs(gconf) do
+ nerv.printf(pattern_gen(), k or "", v or "")
+ end
+end
+
+local function dump_gconf(fname)
+ local f = io.open(fname, "w")
+ f:write("return ")
+ f:write(table.tostring(gconf))
+ f:close()
+end
+
+local trainer_defaults = {
+ lrate = 0.8,
+ batch_size = 256,
+ chunk_size = 1,
+ buffer_size = 81920,
+ wcost = 1e-6,
+ momentum = 0.9,
+ cur_iter = 1,
+ max_iter = 20,
+ cumat_tname = "nerv.CuMatrixFloat",
+ mmat_tname = "nerv.MMatrixFloat",
+ trainer_tname = "nerv.Trainer",
+}
+
+local options = make_options(trainer_defaults)
+local extra_opt_spec = {
+ {"resume-from", nil, "string"},
+ {"help", "h", "boolean", default = false, desc = "show this help information"},
+ {"dir", nil, "string", desc = "specify the working directory"},
+}
+
+table.extend(options, extra_opt_spec)
+
+local opts
+arg, opts = nerv.parse_args(arg, options)
+
+if #arg < 1 or opts["help"].val then
+ print_help(options)
+ return
+end
+
+local script = arg[1]
+local script_arg = {}
+for i = 2, #arg do
+ table.insert(script_arg, arg[i])
+end
+arg = script_arg
+dofile(script)
+
+--[[
+
+Rule: command-line option overrides network config overrides trainer default.
+Note: config key like aaa_bbbb_cc could be overriden by specifying
+--aaa-bbbb-cc to command-line arguments.
+
+]]--
+
+check_and_add_defaults(trainer_defaults, opts)
+gconf.mmat_type = nerv.get_type(gconf.mmat_tname)
+gconf.cumat_type = nerv.get_type(gconf.cumat_tname)
+gconf.trainer = nerv.get_type(gconf.trainer_tname)
+gconf.use_cpu = econf.use_cpu or false
+if gconf.initialized_param == nil then
+ gconf.initialized_param = {}
+end
+if gconf.param_random == nil then
+ gconf.param_random = function() return math.random() / 5 - 0.1 end
+end
+
+local date_pattern = "%Y-%m-%d_%H:%M:%S"
+local logfile_name = "log"
+local working_dir = opts["dir"].val or string.format("nerv_%s", os.date(date_pattern))
+gconf.working_dir = working_dir
+gconf.date_pattern = date_pattern
+
+print_gconf()
+if not lfs.mkdir(working_dir) then
+ nerv.error("[trainer] working directory already exists")
+end
+
+-- copy the network config
+dir.copyfile(script, working_dir)
+-- set logfile path
+nerv.set_logfile(path.join(working_dir, logfile_name))
+
+-- =============
+-- main function
+-- =============
+
+local trainer = gconf.trainer(gconf)
+trainer:training_preprocess()
+gconf.best_cv = trainer:process('validate', false)
+nerv.info("initial cross validation: %.3f", gconf.best_cv)
+
+for i = gconf.cur_iter, gconf.max_iter do
+ gconf.cur_iter = i
+ dump_gconf(path.join(working_dir, string.format("iter_%d.meta", i)))
+ nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
+ local train_err = trainer:process('train', true)
+ nerv.info("[TR] training set %d: %.3f", i, train_err)
+ local cv_err = trainer:process('validate', false)
+ nerv.info("[CV] cross validation %d: %.3f", i, cv_err)
+ if gconf.test then
+ local test_err = trainer:process('test', false)
+ nerv.info('[TE] testset error %d: %.3f', i, test_err)
+ end
+ trainer:halving(train_err, cv_err)
+end
+trainer:training_afterprocess()