Diffstat (limited to 'nerv/examples/asr_trainer.lua')
-rw-r--r--  nerv/examples/asr_trainer.lua | 183
1 file changed, 123 insertions(+), 60 deletions(-)
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 5001e12..5bf28bd 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -1,19 +1,33 @@
require 'lfs'
require 'pl'
local function build_trainer(ifname)
- local param_repo = nerv.ParamRepo()
- param_repo:import(ifname, nil, gconf)
- local layer_repo = make_layer_repo(param_repo)
- local network = get_network(layer_repo)
- local global_transf = get_global_transf(layer_repo)
- local input_order = get_input_order()
+ local host_param_repo = nerv.ParamRepo()
local mat_type
+ local src_loc_type
+ local train_loc_type
+ host_param_repo:import(ifname, nil, gconf)
if gconf.use_cpu then
mat_type = gconf.mmat_type
+ src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
+ train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
else
mat_type = gconf.cumat_type
+ src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
+ train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
end
- local iterative_trainer = function (prefix, scp_file, bp)
+ local param_repo = host_param_repo:copy(train_loc_type)
+ local layer_repo = make_layer_repo(param_repo)
+ local network = get_network(layer_repo)
+ local global_transf = get_global_transf(layer_repo)
+ local input_order = get_input_order()
+ local iterative_trainer = function (prefix, scp_file, bp, rebind_param_repo)
+ -- rebind the params if necessary
+ if rebind_param_repo then
+ host_param_repo = rebind_param_repo
+ param_repo = host_param_repo:copy(train_loc_type)
+ layer_repo:rebind(param_repo)
+ rebind_param_repo = nil
+ end
gconf.randomize = bp
-- build buffer
local buffer = make_buffer(make_readers(scp_file, layer_repo))
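The hunk above splits parameter storage into a host-side repository and a copy living wherever training actually runs (CPU or GPU), and lets the trainer swap in a replacement repository between iterations via layer_repo:rebind. A minimal sketch of that pattern, using only the calls visible in this diff (ParamRepo:import, ParamRepo:copy, the LOC_TYPES constants, LayerRepo:rebind); the parameter file name and make_layer_repo come from the task-specific config and are placeholders here:

    -- sketch: keep a master copy on the host, train on a device copy
    local host_pr = nerv.ParamRepo()
    host_pr:import({"init.nerv"}, nil, gconf)       -- "init.nerv" is hypothetical
    local loc = gconf.use_cpu and nerv.ParamRepo.LOC_TYPES.ON_HOST
                               or nerv.ParamRepo.LOC_TYPES.ON_DEVICE
    local train_pr = host_pr:copy(loc)              -- params where training runs
    local layer_repo = make_layer_repo(train_pr)    -- layers bound to that copy
    -- to roll back later, rebind the same layers to another repository
    layer_repo:rebind(host_pr:copy(loc))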
@@ -66,20 +80,38 @@ local function build_trainer(ifname)
print_stat(layer_repo)
mat_type.print_profile()
mat_type.clear_profile()
- if (not bp) and prefix ~= nil then
- nerv.info("writing back...")
- local fname = string.format("%s_cv%.3f.nerv",
- prefix, get_accuracy(layer_repo))
- network:get_params():export(fname, nil)
+ local fname
+ if (not bp) then
+ host_param_repo = param_repo:copy(src_loc_type)
+ if prefix ~= nil then
+ nerv.info("writing back...")
+ fname = string.format("%s_cv%.3f.nerv",
+ prefix, get_accuracy(layer_repo))
+ host_param_repo:export(fname, nil)
+ end
end
- return get_accuracy(layer_repo)
+ return get_accuracy(layer_repo), host_param_repo, fname
end
return iterative_trainer
end
-local function check_and_add_defaults(spec)
- for k, v in pairs(spec) do
- gconf[k] = opts[string.gsub(k, '_', '-')].val or gconf[k] or v
+local function check_and_add_defaults(spec, opts)
+ local function get_opt_val(k)
+ return opts[string.gsub(k, '_', '-')].val
+ end
+ local opt_v = get_opt_val("resume_from")
+ if opt_v then
+ gconf = dofile(opt_v)
+ else
+ for k, v in pairs(spec) do
+ local opt_v = get_opt_val(k)
+ if opt_v ~= nil then
+ gconf[k] = opt_v
+ elseif gconf[k] ~= nil then -- keep the value already set in gconf
+ elseif v ~= nil then
+ gconf[k] = v
+ end
+ end
end
end
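The rewritten check_and_add_defaults resolves each setting with a fixed precedence: a command-line option wins, then a value already present in gconf (e.g. one restored by --resume-from, which replaces gconf wholesale via dofile before the loop runs), then the built-in default. A standalone sketch of that rule, where resolve is a hypothetical stand-in for one pass of the loop:

    -- precedence: CLI option > existing gconf value > trainer default
    local function resolve(opt_v, gconf_v, default_v)
        if opt_v ~= nil then return opt_v
        elseif gconf_v ~= nil then return gconf_v
        else return default_v end
    end
    assert(resolve(0.2, 0.5, 0.8) == 0.2) -- e.g. --lrate 0.2 on the command line
    assert(resolve(nil, 0.5, 0.8) == 0.5) -- e.g. lrate restored from a .meta file
    assert(resolve(nil, nil, 0.8) == 0.8) -- fall back to trainer_defaults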
@@ -112,6 +144,13 @@ local function print_gconf()
end
end
+local function dump_gconf(fname)
+ local f = io.open(fname, "w")
+ f:write("return ")
+ f:write(table.tostring(gconf))
+ f:close()
+end
+
local trainer_defaults = {
lrate = 0.8,
batch_size = 256,
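dump_gconf is the writer half of the resume mechanism: every iteration serializes gconf as a Lua chunk that returns a table, so resuming amounts to a dofile in check_and_add_defaults. Roughly, with illustrative values (a real .meta file carries every gconf field):

    -- iter_3.meta as produced by dump_gconf, schematically:
    --   return { lrate = 0.4, cur_iter = 3, accu_best = 71.250, ... }
    -- and how --resume-from consumes it (path is hypothetical):
    gconf = dofile("nerv_20160101000000/iter_3.meta")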
@@ -121,22 +160,26 @@ local trainer_defaults = {
start_halving_inc = 0.5,
halving_factor = 0.6,
end_halving_inc = 0.1,
+ cur_iter = 1,
min_iter = 1,
max_iter = 20,
min_halving = 5,
do_halving = false,
- tr_scp = nil,
- cv_scp = nil,
- cumat_type = nerv.CuMatrixFloat,
- mmat_type = nerv.MMatrixFloat,
- debug = false
+ cumat_tname = "nerv.CuMatrixFloat",
+ mmat_tname = "nerv.MMatrixFloat",
+ debug = false,
}
local options = make_options(trainer_defaults)
-table.insert(options, {"help", "h", "boolean",
- default = false, desc = "show this help information"})
-table.insert(options, {"dir", nil, "string",
- default = nil, desc = "specify the working directory"})
+local extra_opt_spec = {
+ {"tr-scp", nil, "string"},
+ {"cv-scp", nil, "string"},
+ {"resume-from", nil, "string"},
+ {"help", "h", "boolean", default = false, desc = "show this help information"},
+ {"dir", nil, "string", desc = "specify the working directory"},
+}
+
+table.extend(options, extra_opt_spec)
arg, opts = nerv.parse_args(arg, options)
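Storing the matrix classes as the strings cumat_tname/mmat_tname rather than as the classes themselves is what keeps gconf serializable by dump_gconf: a class object would not survive table.tostring, but its name round-trips, and the hunk below resolves it back with nerv.get_type right after option parsing. A sketch of the round trip:

    -- the string form survives the .meta dump; the class is recovered on load
    gconf.cumat_type = nerv.get_type(gconf.cumat_tname) -- "nerv.CuMatrixFloat"
    assert(gconf.cumat_type == nerv.CuMatrixFloat)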
@@ -155,14 +198,16 @@ Note: config key like aaa_bbbb_cc could be overridden by specifying
]]--
-check_and_add_defaults(trainer_defaults)
+check_and_add_defaults(trainer_defaults, opts)
+gconf.mmat_type = nerv.get_type(gconf.mmat_tname)
+gconf.cumat_type = nerv.get_type(gconf.cumat_tname)
+gconf.use_cpu = gconf.use_cpu or false
local pf0 = gconf.initialized_param
-local trainer = build_trainer(pf0)
-local accu_best = trainer(nil, gconf.cv_scp, false)
local date_pattern = "%Y%m%d%H%M%S"
local logfile_name = "log"
local working_dir = opts["dir"].val or string.format("nerv_%s", os.date(date_pattern))
+local rebind_param_repo = nil
print_gconf()
if not lfs.mkdir(working_dir) then
@@ -173,37 +218,55 @@ dir.copyfile(arg[1], working_dir)
-- set logfile path
nerv.set_logfile(path.join(working_dir, logfile_name))
path.chdir(working_dir)
-nerv.info("initial cross validation: %.3f", accu_best)
-for i = 1, gconf.max_iter do
- nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
- local accu_tr = trainer(nil, gconf.tr_scp, true)
- nerv.info("[TR] training set %d: %.3f", i, accu_tr)
- local accu_new = trainer(
- string.format("%s_%s_iter_%d_lr%f_tr%.3f",
- string.gsub(
- (string.gsub(pf0[1], "(.*/)(.*)", "%2")),
- "(.*)%..*", "%1"),
- os.date(date_pattern),
- i, gconf.lrate,
- accu_tr),
- gconf.cv_scp, false)
- nerv.info("[CV] cross validation %d: %.3f", i, accu_new)
- -- TODO: revert the weights
- local accu_diff = accu_new - accu_best
- if gconf.do_halving and
- accu_diff < gconf.end_halving_inc and
- i > gconf.min_iter then
- break
- end
- if accu_diff < gconf.start_halving_inc and
- i >= gconf.min_halving then
- gconf.do_halving = true
- end
- if gconf.do_halving then
- gconf.lrate = gconf.lrate * gconf.halving_factor
- end
- if accu_new > accu_best then
- accu_best = accu_new
- end
+
+-- start the training
+local trainer = build_trainer(pf0)
+local pr_prev
+gconf.accu_best, pr_prev = trainer(nil, gconf.cv_scp, false)
+nerv.info("initial cross validation: %.3f", gconf.accu_best)
+for i = gconf.cur_iter, gconf.max_iter do
+ local stop = false
+ gconf.cur_iter = i
+ dump_gconf(string.format("iter_%d.meta", i))
+ repeat -- trick to implement `continue` statement
+ nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
+ local accu_tr = trainer(nil, gconf.tr_scp, true, rebind_param_repo)
+ nerv.info("[TR] training set %d: %.3f", i, accu_tr)
+ local param_prefix = string.format("%s_%s_iter_%d_lr%f_tr%.3f",
+ string.gsub(
+ (string.gsub(pf0[1], "(.*/)(.*)", "%2")),
+ "(.*)%..*", "%1"),
+ os.date(date_pattern),
+ i, gconf.lrate,
+ accu_tr)
+ local accu_new, pr_new, param_fname = trainer(param_prefix, gconf.cv_scp, false)
+ nerv.info("[CV] cross validation %d: %.3f", i, accu_new)
+ local accu_prev = gconf.accu_best
+ if accu_new < gconf.accu_best then
+ nerv.info("rejecting the trained params, rollback to the previous one")
+ file.move(param_fname, param_fname .. ".rejected")
+ rebind_param_repo = pr_prev
+ break -- `continue` equivalent
+ else
+ nerv.info("accepting the trained params")
+ gconf.accu_best = accu_new
+ pr_prev = pr_new
+ gconf.initialized_param = {path.join(path.currentdir(), param_fname)}
+ end
+ if gconf.do_halving and
+ gconf.accu_best - accu_prev < gconf.end_halving_inc and
+ i > gconf.min_iter then
+ stop = true
+ break
+ end
+ if gconf.accu_best - accu_prev < gconf.start_halving_inc and
+ i >= gconf.min_halving then
+ gconf.do_halving = true
+ end
+ if gconf.do_halving then
+ gconf.lrate = gconf.lrate * gconf.halving_factor
+ end
+ until true
+ if stop then break end
-- nerv.Matrix.print_profile()
end
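The repeat ... until true wrapper in the new main loop compensates for Lua 5.1's lack of a continue statement: a break inside the single-pass repeat jumps to until true and the enclosing for simply advances, while the separate stop flag distinguishes genuine termination (the end_halving_inc criterion) from a skipped iteration after a rejected parameter set. A self-contained illustration of the idiom:

    -- `break` inside repeat ... until true behaves like `continue`
    for i = 1, 4 do
        repeat
            if i % 2 == 0 then break end -- skip even iterations
            print(i)                     -- prints 1 and 3
        until true
    end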