diff options
author | Determinant <[email protected]> | 2016-05-08 11:38:28 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2016-05-08 11:38:28 +0800 |
commit | 88b3f2a13fa3c01a684259e85ee8298e35f2ac09 (patch) | |
tree | 1c5ff4e2759ea88f6a9daa5fcafbc07d91951c00 /nerv/examples/trainer.lua | |
parent | e3ed809bb7d5d11b5b2cec559955b15db18db915 (diff) |
prepare for the replacement of `asr_trainer.lua` with `trainer.lua`
Diffstat (limited to 'nerv/examples/trainer.lua')
-rw-r--r-- | nerv/examples/trainer.lua | 36 |
1 files changed, 19 insertions, 17 deletions
diff --git a/nerv/examples/trainer.lua b/nerv/examples/trainer.lua index 783ff1d..8e3efcb 100644 --- a/nerv/examples/trainer.lua +++ b/nerv/examples/trainer.lua @@ -1,9 +1,9 @@ require 'lfs' require 'pl' --- ======================================================= --- Deal with command line input & init training envrioment --- ======================================================= +-- ========================================================= +-- Deal with command line input & init training envrioment +-- ========================================================= local function check_and_add_defaults(spec, opts) local function get_opt_val(k) @@ -14,15 +14,14 @@ local function check_and_add_defaults(spec, opts) if opt_v then nerv.info("resuming from previous training state") gconf = dofile(opt_v) - else - for k, v in pairs(spec) do - local opt_v, specified = get_opt_val(k) - if (not specified) and gconf[k] ~= nil then - nerv.info("using setting in network config file: %s = %s", k, gconf[k]) - elseif opt_v ~= nil then - nerv.info("using setting in options: %s = %s", k, opt_v) - gconf[k] = opt_v - end + end + for k, v in pairs(spec) do + local opt_v, specified = get_opt_val(k) + if (not specified) and gconf[k] ~= nil then + nerv.info("using setting in network config file: %s = %s", k, gconf[k]) + elseif opt_v ~= nil then + nerv.info("using setting in options: %s = %s", k, opt_v) + gconf[k] = opt_v end end end @@ -65,6 +64,7 @@ end local trainer_defaults = { lrate = 0.8, + hfactor = 0.5, batch_size = 256, chunk_size = 1, buffer_size = 81920, @@ -125,7 +125,8 @@ end local date_pattern = "%Y-%m-%d_%H:%M:%S" local logfile_name = "log" -local working_dir = opts["dir"].val or string.format("nerv_%s", os.date(date_pattern)) +local working_dir = opts["dir"].val or + string.format("nerv_%s", os.date(date_pattern)) gconf.working_dir = working_dir gconf.date_pattern = date_pattern @@ -139,9 +140,9 @@ dir.copyfile(script, working_dir) -- set logfile path nerv.set_logfile(path.join(working_dir, logfile_name)) --- ============= --- main function --- ============= +-- ============ +-- Main loop +-- ============ local trainer = gconf.trainer(gconf) trainer:training_preprocess() @@ -160,6 +161,7 @@ for i = gconf.cur_iter, gconf.max_iter do local test_err = trainer:process('test', false) nerv.info('[TE] testset error %d: %.3f', i, test_err) end - trainer:halving(train_err, cv_err) + trainer:save_params(train_err, cv_err) end +dump_gconf(path.join(working_dir, string.format("iter_%d.meta", gconf.max_iter + 1))) trainer:training_afterprocess() |