aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/asr_trainer.lua
diff options
context:
space:
mode:
authorDeterminant <ted.sybil@gmail.com>2016-03-02 18:24:09 +0800
committerDeterminant <ted.sybil@gmail.com>2016-03-02 18:24:09 +0800
commitad704f2623cc9e0a5d702434bfdebc345465ca12 (patch)
tree898d0688e913efc3ff098ba51e5c1a5488f5771d /nerv/examples/asr_trainer.lua
parentd3abc6459a776ff7fa3777f4f561bc4f5d5e2075 (diff)
major changes in asr_trainer.lua; unified settings in `gconf`
Diffstat (limited to 'nerv/examples/asr_trainer.lua')
-rw-r--r--nerv/examples/asr_trainer.lua104
1 files changed, 87 insertions, 17 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 3fa2653..684ea30 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -1,4 +1,4 @@
-function build_trainer(ifname)
+local function build_trainer(ifname)
local param_repo = nerv.ParamRepo()
param_repo:import(ifname, nil, gconf)
local layer_repo = make_layer_repo(param_repo)
@@ -75,24 +75,91 @@ function build_trainer(ifname)
return iterative_trainer
end
+local function check_and_add_defaults(spec)
+ for k, v in pairs(spec) do
+ gconf[k] = opts[string.gsub(k, '_', '-')].val or gconf[k] or v
+ end
+end
+
+local function make_options(spec)
+ local options = {}
+ for k, v in pairs(spec) do
+ table.insert(options,
+ {string.gsub(k, '_', '-'), nil, type(v), default = v})
+ end
+ return options
+end
+
+local function print_help(options)
+ nerv.printf("Usage: <asr_trainer.lua> [options] network_config.lua\n")
+ nerv.print_usage(options)
+end
+
+local function print_gconf()
+ local key_maxlen = 0
+ for k, v in pairs(gconf) do
+ key_maxlen = math.max(key_maxlen, #k or 0)
+ end
+ local function pattern_gen()
+ return string.format("%%-%ds = %%s\n", key_maxlen)
+ end
+ nerv.info("ready to train with the following gconf settings:")
+ nerv.printf(pattern_gen(), "Key", "Value")
+ for k, v in pairs(gconf) do
+ nerv.printf(pattern_gen(), k or "", v or "")
+ end
+end
+
+local trainer_defaults = {
+ lrate = 0.8,
+ batch_size = 256,
+ buffer_size = 81920,
+ wcost = 1e-6,
+ momentum = 0.9,
+ start_halving_inc = 0.5,
+ halving_factor = 0.6,
+ end_halving_inc = 0.1,
+ min_iter = 1,
+ max_iter = 20,
+ min_halving = 5,
+ do_halving = false,
+ tr_scp = nil,
+ cv_scp = nil,
+ cumat_type = nerv.CuMatrixFloat,
+ mmat_type = nerv.MMatrixFloat,
+ debug = false
+}
+
+local options = make_options(trainer_defaults)
+table.insert(options, {"help", "h", "boolean",
+ default = false, desc = "show this help information"})
+
+arg, opts = nerv.parse_args(arg, options)
+
+if #arg < 1 or opts["help"].val then
+ print_help(options)
+ return
+end
+
dofile(arg[1])
-start_halving_inc = 0.5
-halving_factor = 0.6
-end_halving_inc = 0.1
-min_iter = 1
-max_iter = 20
-min_halving = 5
-gconf.batch_size = 256
-gconf.buffer_size = 81920
+
+--[[
+
+Rule: command-line option overrides network config overrides trainer default.
+Note: config key like aaa_bbbb_cc could be overriden by specifying
+--aaa-bbbb-cc to command-line arguments.
+
+]]--
+
+check_and_add_defaults(trainer_defaults)
local pf0 = gconf.initialized_param
local trainer = build_trainer(pf0)
---local trainer = build_trainer("c3.nerv")
local accu_best = trainer(nil, gconf.cv_scp, false)
-local do_halving = false
+print_gconf()
nerv.info("initial cross validation: %.3f", accu_best)
-for i = 1, max_iter do
+for i = 1, gconf.max_iter do
nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
local accu_tr = trainer(nil, gconf.tr_scp, true)
nerv.info("[TR] training set %d: %.3f", i, accu_tr)
@@ -108,14 +175,17 @@ for i = 1, max_iter do
nerv.info("[CV] cross validation %d: %.3f", i, accu_new)
-- TODO: revert the weights
local accu_diff = accu_new - accu_best
- if do_halving and accu_diff < end_halving_inc and i > min_iter then
+ if gconf.do_halving and
+ accu_diff < gconf.end_halving_inc and
+ i > gconf.min_iter then
break
end
- if accu_diff < start_halving_inc and i >= min_halving then
- do_halving = true
+ if accu_diff < gconf.start_halving_inc and
+ i >= gconf.min_halving then
+ gconf.do_halving = true
end
- if do_halving then
- gconf.lrate = gconf.lrate * halving_factor
+ if gconf.do_halving then
+ gconf.lrate = gconf.lrate * gconf.halving_factor
end
if accu_new > accu_best then
accu_best = accu_new