aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/asr_trainer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/asr_trainer.lua')
-rw-r--r--nerv/examples/asr_trainer.lua256
1 files changed, 202 insertions, 54 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 3fa2653..5bf28bd 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -1,17 +1,33 @@
-function build_trainer(ifname)
- local param_repo = nerv.ParamRepo()
- param_repo:import(ifname, nil, gconf)
- local layer_repo = make_layer_repo(param_repo)
- local network = get_network(layer_repo)
- local global_transf = get_global_transf(layer_repo)
- local input_order = get_input_order()
+require 'lfs'
+require 'pl'
+local function build_trainer(ifname)
+ local host_param_repo = nerv.ParamRepo()
local mat_type
+ local src_loc_type
+ local train_loc_type
+ host_param_repo:import(ifname, nil, gconf)
if gconf.use_cpu then
mat_type = gconf.mmat_type
+ src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
+ train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
else
mat_type = gconf.cumat_type
+ src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
+ train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
end
- local iterative_trainer = function (prefix, scp_file, bp)
+ local param_repo = host_param_repo:copy(train_loc_type)
+ local layer_repo = make_layer_repo(param_repo)
+ local network = get_network(layer_repo)
+ local global_transf = get_global_transf(layer_repo)
+ local input_order = get_input_order()
+ local iterative_trainer = function (prefix, scp_file, bp, rebind_param_repo)
+ -- rebind the params if necessary
+ if rebind_param_repo then
+ host_param_repo = rebind_param_repo
+ param_repo = host_param_repo:copy(train_loc_type)
+ layer_repo:rebind(param_repo)
+ rebind_param_repo = nil
+ end
gconf.randomize = bp
-- build buffer
local buffer = make_buffer(make_readers(scp_file, layer_repo))
@@ -64,61 +80,193 @@ function build_trainer(ifname)
print_stat(layer_repo)
mat_type.print_profile()
mat_type.clear_profile()
- if (not bp) and prefix ~= nil then
- nerv.info("writing back...")
- local fname = string.format("%s_cv%.3f.nerv",
- prefix, get_accuracy(layer_repo))
- network:get_params():export(fname, nil)
+ local fname
+ if (not bp) then
+ host_param_repo = param_repo:copy(src_loc_type)
+ if prefix ~= nil then
+ nerv.info("writing back...")
+ fname = string.format("%s_cv%.3f.nerv",
+ prefix, get_accuracy(layer_repo))
+ host_param_repo:export(fname, nil)
+ end
end
- return get_accuracy(layer_repo)
+ return get_accuracy(layer_repo), host_param_repo, fname
end
return iterative_trainer
end
-dofile(arg[1])
-start_halving_inc = 0.5
-halving_factor = 0.6
-end_halving_inc = 0.1
-min_iter = 1
-max_iter = 20
-min_halving = 5
-gconf.batch_size = 256
-gconf.buffer_size = 81920
+local function check_and_add_defaults(spec, opts)
+ local function get_opt_val(k)
+ return opts[string.gsub(k, '_', '-')].val
+ end
+ local opt_v = get_opt_val("resume_from")
+ if opt_v then
+ gconf = dofile(opt_v)
+ else
+ for k, v in pairs(spec) do
+ local opt_v = get_opt_val(k)
+ if opt_v ~= nil then
+ gconf[k] = opt_v
+ elseif gconf[k] ~= nil then
+ elseif v ~= nil then
+ gconf[k] = v
+ end
+ end
+ end
+end
-local pf0 = gconf.initialized_param
-local trainer = build_trainer(pf0)
---local trainer = build_trainer("c3.nerv")
-local accu_best = trainer(nil, gconf.cv_scp, false)
-local do_halving = false
-
-nerv.info("initial cross validation: %.3f", accu_best)
-for i = 1, max_iter do
- nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
- local accu_tr = trainer(nil, gconf.tr_scp, true)
- nerv.info("[TR] training set %d: %.3f", i, accu_tr)
- local accu_new = trainer(
- string.format("%s_%s_iter_%d_lr%f_tr%.3f",
- string.gsub(
- (string.gsub(pf0[1], "(.*/)(.*)", "%2")),
- "(.*)%..*", "%1"),
- os.date("%Y%m%d%H%M%S"),
- i, gconf.lrate,
- accu_tr),
- gconf.cv_scp, false)
- nerv.info("[CV] cross validation %d: %.3f", i, accu_new)
- -- TODO: revert the weights
- local accu_diff = accu_new - accu_best
- if do_halving and accu_diff < end_halving_inc and i > min_iter then
- break
+local function make_options(spec)
+ local options = {}
+ for k, v in pairs(spec) do
+ table.insert(options,
+ {string.gsub(k, '_', '-'), nil, type(v), default = v})
end
- if accu_diff < start_halving_inc and i >= min_halving then
- do_halving = true
+ return options
+end
+
+local function print_help(options)
+ nerv.printf("Usage: <asr_trainer.lua> [options] network_config.lua\n")
+ nerv.print_usage(options)
+end
+
+local function print_gconf()
+ local key_maxlen = 0
+ for k, v in pairs(gconf) do
+ key_maxlen = math.max(key_maxlen, #k or 0)
end
- if do_halving then
- gconf.lrate = gconf.lrate * halving_factor
+ local function pattern_gen()
+ return string.format("%%-%ds = %%s\n", key_maxlen)
end
- if accu_new > accu_best then
- accu_best = accu_new
+ nerv.info("ready to train with the following gconf settings:")
+ nerv.printf(pattern_gen(), "Key", "Value")
+ for k, v in pairs(gconf) do
+ nerv.printf(pattern_gen(), k or "", v or "")
end
+end
+
+local function dump_gconf(fname)
+ local f = io.open(fname, "w")
+ f:write("return ")
+ f:write(table.tostring(gconf))
+ f:close()
+end
+
+local trainer_defaults = {
+ lrate = 0.8,
+ batch_size = 256,
+ buffer_size = 81920,
+ wcost = 1e-6,
+ momentum = 0.9,
+ start_halving_inc = 0.5,
+ halving_factor = 0.6,
+ end_halving_inc = 0.1,
+ cur_iter = 1,
+ min_iter = 1,
+ max_iter = 20,
+ min_halving = 5,
+ do_halving = false,
+ cumat_tname = "nerv.CuMatrixFloat",
+ mmat_tname = "nerv.MMatrixFloat",
+ debug = false,
+}
+
+local options = make_options(trainer_defaults)
+local extra_opt_spec = {
+ {"tr-scp", nil, "string"},
+ {"cv-scp", nil, "string"},
+ {"resume-from", nil, "string"},
+ {"help", "h", "boolean", default = false, desc = "show this help information"},
+ {"dir", nil, "string", desc = "specify the working directory"},
+}
+
+table.extend(options, extra_opt_spec)
+
+arg, opts = nerv.parse_args(arg, options)
+
+if #arg < 1 or opts["help"].val then
+ print_help(options)
+ return
+end
+
+dofile(arg[1])
+
+--[[
+
+Rule: command-line option overrides network config overrides trainer default.
+Note: config key like aaa_bbbb_cc could be overriden by specifying
+--aaa-bbbb-cc to command-line arguments.
+
+]]--
+
+check_and_add_defaults(trainer_defaults, opts)
+gconf.mmat_type = nerv.get_type(gconf.mmat_tname)
+gconf.cumat_type = nerv.get_type(gconf.cumat_tname)
+gconf.use_cpu = econf.use_cpu or false
+
+local pf0 = gconf.initialized_param
+local date_pattern = "%Y%m%d%H%M%S"
+local logfile_name = "log"
+local working_dir = opts["dir"].val or string.format("nerv_%s", os.date(date_pattern))
+local rebind_param_repo = nil
+
+print_gconf()
+if not lfs.mkdir(working_dir) then
+ nerv.error("[asr_trainer] working directory already exists")
+end
+-- copy the network config
+dir.copyfile(arg[1], working_dir)
+-- set logfile path
+nerv.set_logfile(path.join(working_dir, logfile_name))
+path.chdir(working_dir)
+
+-- start the training
+local trainer = build_trainer(pf0)
+local pr_prev
+gconf.accu_best, pr_prev = trainer(nil, gconf.cv_scp, false)
+nerv.info("initial cross validation: %.3f", gconf.accu_best)
+for i = gconf.cur_iter, gconf.max_iter do
+ local stop = false
+ gconf.cur_iter = i
+ dump_gconf(string.format("iter_%d.meta", i))
+ repeat -- trick to implement `continue` statement
+ nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
+ local accu_tr = trainer(nil, gconf.tr_scp, true, rebind_param_repo)
+ nerv.info("[TR] training set %d: %.3f", i, accu_tr)
+ local param_prefix = string.format("%s_%s_iter_%d_lr%f_tr%.3f",
+ string.gsub(
+ (string.gsub(pf0[1], "(.*/)(.*)", "%2")),
+ "(.*)%..*", "%1"),
+ os.date(date_pattern),
+ i, gconf.lrate,
+ accu_tr)
+ local accu_new, pr_new, param_fname = trainer(param_prefix, gconf.cv_scp, false)
+ nerv.info("[CV] cross validation %d: %.3f", i, accu_new)
+ local accu_prev = gconf.accu_best
+ if accu_new < gconf.accu_best then
+ nerv.info("rejecting the trained params, rollback to the previous one")
+ file.move(param_fname, param_fname .. ".rejected")
+ rebind_param_repo = pr_prev
+ break -- `continue` equivalent
+ else
+ nerv.info("accepting the trained params")
+ gconf.accu_best = accu_new
+ pr_prev = pr_new
+ gconf.initialized_param = {path.join(path.currentdir(), param_fname)}
+ end
+ if gconf.do_halving and
+ gconf.accu_best - accu_prev < gconf.end_halving_inc and
+ i > gconf.min_iter then
+ stop = true
+ break
+ end
+ if gconf.accu_best - accu_prev < gconf.start_halving_inc and
+ i >= gconf.min_halving then
+ gconf.do_halving = true
+ end
+ if gconf.do_halving then
+ gconf.lrate = gconf.lrate * gconf.halving_factor
+ end
+ until true
+ if stop then break end
-- nerv.Matrix.print_profile()
end