diff options
Diffstat (limited to 'nerv/examples/asr_trainer.lua')
-rw-r--r-- | nerv/examples/asr_trainer.lua | 104 |
1 files changed, 87 insertions, 17 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua index 3fa2653..684ea30 100644 --- a/nerv/examples/asr_trainer.lua +++ b/nerv/examples/asr_trainer.lua @@ -1,4 +1,4 @@ -function build_trainer(ifname) +local function build_trainer(ifname) local param_repo = nerv.ParamRepo() param_repo:import(ifname, nil, gconf) local layer_repo = make_layer_repo(param_repo) @@ -75,24 +75,91 @@ function build_trainer(ifname) return iterative_trainer end +local function check_and_add_defaults(spec) + for k, v in pairs(spec) do + gconf[k] = opts[string.gsub(k, '_', '-')].val or gconf[k] or v + end +end + +local function make_options(spec) + local options = {} + for k, v in pairs(spec) do + table.insert(options, + {string.gsub(k, '_', '-'), nil, type(v), default = v}) + end + return options +end + +local function print_help(options) + nerv.printf("Usage: <asr_trainer.lua> [options] network_config.lua\n") + nerv.print_usage(options) +end + +local function print_gconf() + local key_maxlen = 0 + for k, v in pairs(gconf) do + key_maxlen = math.max(key_maxlen, #k or 0) + end + local function pattern_gen() + return string.format("%%-%ds = %%s\n", key_maxlen) + end + nerv.info("ready to train with the following gconf settings:") + nerv.printf(pattern_gen(), "Key", "Value") + for k, v in pairs(gconf) do + nerv.printf(pattern_gen(), k or "", v or "") + end +end + +local trainer_defaults = { + lrate = 0.8, + batch_size = 256, + buffer_size = 81920, + wcost = 1e-6, + momentum = 0.9, + start_halving_inc = 0.5, + halving_factor = 0.6, + end_halving_inc = 0.1, + min_iter = 1, + max_iter = 20, + min_halving = 5, + do_halving = false, + tr_scp = nil, + cv_scp = nil, + cumat_type = nerv.CuMatrixFloat, + mmat_type = nerv.MMatrixFloat, + debug = false +} + +local options = make_options(trainer_defaults) +table.insert(options, {"help", "h", "boolean", + default = false, desc = "show this help information"}) + +arg, opts = nerv.parse_args(arg, options) + +if #arg < 1 or opts["help"].val then + print_help(options) + return +end + dofile(arg[1]) -start_halving_inc = 0.5 -halving_factor = 0.6 -end_halving_inc = 0.1 -min_iter = 1 -max_iter = 20 -min_halving = 5 -gconf.batch_size = 256 -gconf.buffer_size = 81920 + +--[[ + +Rule: command-line option overrides network config overrides trainer default. +Note: config key like aaa_bbbb_cc could be overriden by specifying +--aaa-bbbb-cc to command-line arguments. + +]]-- + +check_and_add_defaults(trainer_defaults) local pf0 = gconf.initialized_param local trainer = build_trainer(pf0) ---local trainer = build_trainer("c3.nerv") local accu_best = trainer(nil, gconf.cv_scp, false) -local do_halving = false +print_gconf() nerv.info("initial cross validation: %.3f", accu_best) -for i = 1, max_iter do +for i = 1, gconf.max_iter do nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate) local accu_tr = trainer(nil, gconf.tr_scp, true) nerv.info("[TR] training set %d: %.3f", i, accu_tr) @@ -108,14 +175,17 @@ for i = 1, max_iter do nerv.info("[CV] cross validation %d: %.3f", i, accu_new) -- TODO: revert the weights local accu_diff = accu_new - accu_best - if do_halving and accu_diff < end_halving_inc and i > min_iter then + if gconf.do_halving and + accu_diff < gconf.end_halving_inc and + i > gconf.min_iter then break end - if accu_diff < start_halving_inc and i >= min_halving then - do_halving = true + if accu_diff < gconf.start_halving_inc and + i >= gconf.min_halving then + gconf.do_halving = true end - if do_halving then - gconf.lrate = gconf.lrate * halving_factor + if gconf.do_halving then + gconf.lrate = gconf.lrate * gconf.halving_factor end if accu_new > accu_best then accu_best = accu_new |