diff options
author | Determinant <[email protected]> | 2016-03-02 18:24:09 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2016-03-02 18:24:09 +0800 |
commit | ad704f2623cc9e0a5d702434bfdebc345465ca12 (patch) | |
tree | 898d0688e913efc3ff098ba51e5c1a5488f5771d /nerv | |
parent | d3abc6459a776ff7fa3777f4f561bc4f5d5e2075 (diff) |
major changes in asr_trainer.lua; unified settings in `gconf`
Diffstat (limited to 'nerv')
-rw-r--r-- | nerv/examples/asr_trainer.lua | 104 | ||||
-rw-r--r-- | nerv/examples/swb_baseline.lua | 7 | ||||
-rw-r--r-- | nerv/examples/swb_baseline2.lua | 7 | ||||
-rw-r--r-- | nerv/examples/timit_baseline2.lua | 9 | ||||
-rw-r--r-- | nerv/init.lua | 22 | ||||
-rw-r--r-- | nerv/io/sgd_buffer.lua | 7 | ||||
-rw-r--r-- | nerv/nerv | 11 | ||||
-rw-r--r-- | nerv/test/parse_args.lua | 12 |
8 files changed, 126 insertions, 53 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua index 3fa2653..684ea30 100644 --- a/nerv/examples/asr_trainer.lua +++ b/nerv/examples/asr_trainer.lua @@ -1,4 +1,4 @@ -function build_trainer(ifname) +local function build_trainer(ifname) local param_repo = nerv.ParamRepo() param_repo:import(ifname, nil, gconf) local layer_repo = make_layer_repo(param_repo) @@ -75,24 +75,91 @@ function build_trainer(ifname) return iterative_trainer end +local function check_and_add_defaults(spec) + for k, v in pairs(spec) do + gconf[k] = opts[string.gsub(k, '_', '-')].val or gconf[k] or v + end +end + +local function make_options(spec) + local options = {} + for k, v in pairs(spec) do + table.insert(options, + {string.gsub(k, '_', '-'), nil, type(v), default = v}) + end + return options +end + +local function print_help(options) + nerv.printf("Usage: <asr_trainer.lua> [options] network_config.lua\n") + nerv.print_usage(options) +end + +local function print_gconf() + local key_maxlen = 0 + for k, v in pairs(gconf) do + key_maxlen = math.max(key_maxlen, #k or 0) + end + local function pattern_gen() + return string.format("%%-%ds = %%s\n", key_maxlen) + end + nerv.info("ready to train with the following gconf settings:") + nerv.printf(pattern_gen(), "Key", "Value") + for k, v in pairs(gconf) do + nerv.printf(pattern_gen(), k or "", v or "") + end +end + +local trainer_defaults = { + lrate = 0.8, + batch_size = 256, + buffer_size = 81920, + wcost = 1e-6, + momentum = 0.9, + start_halving_inc = 0.5, + halving_factor = 0.6, + end_halving_inc = 0.1, + min_iter = 1, + max_iter = 20, + min_halving = 5, + do_halving = false, + tr_scp = nil, + cv_scp = nil, + cumat_type = nerv.CuMatrixFloat, + mmat_type = nerv.MMatrixFloat, + debug = false +} + +local options = make_options(trainer_defaults) +table.insert(options, {"help", "h", "boolean", + default = false, desc = "show this help information"}) + +arg, opts = nerv.parse_args(arg, options) + +if #arg < 1 or opts["help"].val then + print_help(options) + return +end + dofile(arg[1]) -start_halving_inc = 0.5 -halving_factor = 0.6 -end_halving_inc = 0.1 -min_iter = 1 -max_iter = 20 -min_halving = 5 -gconf.batch_size = 256 -gconf.buffer_size = 81920 + +--[[ + +Rule: command-line option overrides network config overrides trainer default. +Note: config key like aaa_bbbb_cc could be overriden by specifying +--aaa-bbbb-cc to command-line arguments. + +]]-- + +check_and_add_defaults(trainer_defaults) local pf0 = gconf.initialized_param local trainer = build_trainer(pf0) ---local trainer = build_trainer("c3.nerv") local accu_best = trainer(nil, gconf.cv_scp, false) -local do_halving = false +print_gconf() nerv.info("initial cross validation: %.3f", accu_best) -for i = 1, max_iter do +for i = 1, gconf.max_iter do nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate) local accu_tr = trainer(nil, gconf.tr_scp, true) nerv.info("[TR] training set %d: %.3f", i, accu_tr) @@ -108,14 +175,17 @@ for i = 1, max_iter do nerv.info("[CV] cross validation %d: %.3f", i, accu_new) -- TODO: revert the weights local accu_diff = accu_new - accu_best - if do_halving and accu_diff < end_halving_inc and i > min_iter then + if gconf.do_halving and + accu_diff < gconf.end_halving_inc and + i > gconf.min_iter then break end - if accu_diff < start_halving_inc and i >= min_halving then - do_halving = true + if accu_diff < gconf.start_halving_inc and + i >= gconf.min_halving then + gconf.do_halving = true end - if do_halving then - gconf.lrate = gconf.lrate * halving_factor + if gconf.do_halving then + gconf.lrate = gconf.lrate * gconf.halving_factor end if accu_new > accu_best then accu_best = accu_new diff --git a/nerv/examples/swb_baseline.lua b/nerv/examples/swb_baseline.lua index cacc401..4cb2389 100644 --- a/nerv/examples/swb_baseline.lua +++ b/nerv/examples/swb_baseline.lua @@ -1,7 +1,5 @@ require 'htk_io' gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9, - cumat_type = nerv.CuMatrixFloat, - mmat_type = nerv.MMatrixFloat, rearrange = true, -- just to make the context order consistent with old results, deprecated frm_ext = 5, frm_trim = 5, -- trim the first and last 5 frames, TNet just does this, deprecated @@ -173,6 +171,7 @@ function make_buffer(readers) return nerv.SGDBuffer(gconf, { buffer_size = gconf.buffer_size, + batch_size = gconf.batch_size, randomize = gconf.randomize, readers = readers, use_gpu = true @@ -184,6 +183,10 @@ function get_input_order() {id = "phone_state"}} end +function get_decode_input_order() + return {{id = "main_scp", global_transf = true}} +end + function get_accuracy(layer_repo) local ce_crit = layer_repo:get_layer("ce_crit") return ce_crit.total_correct / ce_crit.total_frames * 100 diff --git a/nerv/examples/swb_baseline2.lua b/nerv/examples/swb_baseline2.lua index 0e2a6e0..b0b9689 100644 --- a/nerv/examples/swb_baseline2.lua +++ b/nerv/examples/swb_baseline2.lua @@ -1,7 +1,5 @@ require 'htk_io' gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9, - cumat_type = nerv.CuMatrixFloat, - mmat_type = nerv.MMatrixFloat, rearrange = true, -- just to make the context order consistent with old results, deprecated frm_ext = 5, frm_trim = 5, -- trim the first and last 5 frames, TNet just does this, deprecated @@ -173,6 +171,7 @@ function make_buffer(readers) return nerv.SGDBuffer(gconf, { buffer_size = gconf.buffer_size, + batch_size = gconf.batch_size, randomize = gconf.randomize, readers = readers, use_gpu = true @@ -184,6 +183,10 @@ function get_input_order() {id = "phone_state"}} end +function get_decode_input_order() + return {{id = "main_scp", global_transf = true}} +end + function get_accuracy(layer_repo) local ce_crit = layer_repo:get_layer("ce_crit") return ce_crit.total_correct / ce_crit.total_frames * 100 diff --git a/nerv/examples/timit_baseline2.lua b/nerv/examples/timit_baseline2.lua index 174b9e7..103d89d 100644 --- a/nerv/examples/timit_baseline2.lua +++ b/nerv/examples/timit_baseline2.lua @@ -1,8 +1,5 @@ require 'kaldi_io' -gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9, - cumat_type = nerv.CuMatrixFloat, - mmat_type = nerv.MMatrixFloat, - frm_ext = 5, +gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9, frm_ext = 5, tr_scp = "ark:/speechlab/tools/KALDI/kaldi-master/src/featbin/copy-feats " .. "scp:/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/train.scp ark:- |", cv_scp = "ark:/speechlab/tools/KALDI/kaldi-master/src/featbin/copy-feats " .. @@ -11,8 +8,7 @@ gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9, "/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/nnet_output.nerv", "/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/nnet_trans.nerv"}, decode_param = {"/speechlab/users/mfy43/timit/nnet_init_20160229015745_iter_13_lr0.013437_tr72.434_cv58.729.nerv", - "/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/nnet_trans.nerv"}, - debug = false} + "/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/nnet_trans.nerv"}} function make_layer_repo(param_repo) local layer_repo = nerv.LayerRepo( @@ -183,6 +179,7 @@ function make_buffer(readers) return nerv.SGDBuffer(gconf, { buffer_size = gconf.buffer_size, + batch_size = gconf.batch_size, randomize = gconf.randomize, readers = readers, use_gpu = true diff --git a/nerv/init.lua b/nerv/init.lua index d72d8b8..a5b032c 100644 --- a/nerv/init.lua +++ b/nerv/init.lua @@ -182,15 +182,15 @@ end -- value and description of the option. -- -- An example of specification: --- {{"aaa", "a", "bool", default = false, desc = "an option called aaa"}, --- {"bbb", "b", "bool", default = true, desc = "bbb is set to be true if --bbb=no does not present"}, +-- {{"aaa", "a", "boolean", default = false, desc = "an option called aaa"}, +-- {"bbb", "b", "boolean", default = true, desc = "bbb is set to be true if --bbb=no does not present"}, -- {"ccc", nil, "int", default = 0, desc = "ccc expects an integeral value"}}` -- -- @return args, opts The non-option arguments and parsed options. `opts` is -- again a list of tables, each of which corresponds to one table in parameter -- `options`. The parsed value could be accessed by `opts["aaa"].val` (which is -- `true` if "--aaa" or "-a" is specified). -function nerv.parse_args(argv, options) +function nerv.parse_args(argv, options, unordered) local is_opt_exp = "^[-](.*)$" local sim_opt_exp = "^[-]([a-z]+)$" local opt_exp = "^[-][-]([^=]+)$" @@ -198,6 +198,7 @@ function nerv.parse_args(argv, options) local opts = {} local sopts = {} local args = {} + local arg_start = false local function err() nerv.error("invalid format of option specification") end @@ -215,7 +216,7 @@ function nerv.parse_args(argv, options) val = v.default} if opt_short ~= nil then if type(opt_short) ~= "string" or #opt_short ~= 1 then err() end - if opt_type ~= "bool" then + if opt_type ~= "boolean" then nerv.error("only boolean option could have short form") end sopts[opt_short] = opt_meta @@ -226,7 +227,7 @@ function nerv.parse_args(argv, options) end end for _, token in ipairs(argv) do - if token:match(is_opt_exp) then + if ((not arg_start) or unordered) and token:match(is_opt_exp) then local k = token:match(sim_opt_exp) if k then for c in k:gmatch"." do @@ -242,7 +243,7 @@ function nerv.parse_args(argv, options) if opts[k] == nil then nerv.error("invalid option %s", token) end - if opts[k].type ~= "bool" then + if opts[k].type ~= "boolean" then nerv.error("invalid option --%s: " .. "a %s value needs to be specified", k, opts[k].type) @@ -255,13 +256,13 @@ function nerv.parse_args(argv, options) if opts[k] == nil then nerv.error("invalid option %s", token) end - if opts[k].type == "bool" then + if opts[k].type == "boolean" then if v == "yes" then opts[k].val = true elseif v == "no" then opts[k].val = false else - nerv.error("bool value should be \"yes\" or \"no\"") + nerv.error("boolean value should be \"yes\" or \"no\"") end elseif opts[k].type == "int" then local t = tonumber(v) @@ -269,11 +270,11 @@ function nerv.parse_args(argv, options) if t == nil or math.floor(t) ~= t then nerv.error("int value is expected") end - elseif opts[k].type == "float" then + elseif opts[k].type == "number" then local t = tonumber(v) opts[k].val = t if t == nil then - nerv.error("float value is expected") + nerv.error("numeric value is expected") end elseif opts[k].type == "string" then opts[k].val = v @@ -287,6 +288,7 @@ function nerv.parse_args(argv, options) end else table.insert(args, token) + arg_start = true end end return args, opts diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/sgd_buffer.lua index 3cf4f5a..d78f6d1 100644 --- a/nerv/io/sgd_buffer.lua +++ b/nerv/io/sgd_buffer.lua @@ -2,8 +2,9 @@ local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer") function SGDBuffer:__init(global_conf, buffer_conf) self.gconf = global_conf + self.batch_size = buffer_conf.batch_size self.buffer_size = math.floor(buffer_conf.buffer_size / - global_conf.batch_size) * global_conf.batch_size + self.batch_size) * self.batch_size self.randomize = buffer_conf.randomize self.consume = buffer_conf.consume local cumat_type = global_conf.cumat_type @@ -112,11 +113,11 @@ function SGDBuffer:saturate() end self.rand_map = self.perm_gen(self.tail) -- generate shuffled index collectgarbage("collect") - return self.tail >= self.gconf.batch_size + return self.tail >= self.batch_size end function SGDBuffer:get_data() - local batch_size = self.gconf.batch_size + local batch_size = self.batch_size if self.head >= self.tail then -- buffer is empty local t = os.clock() if (not self:saturate()) and (not self.consume) then @@ -1,8 +1,8 @@ #! /usr/bin/env luajit require 'nerv' -local options = {{"help", "h", "bool", default = false, desc = "print this help message"}, - {"use-cpu", "c", "bool", default = false, desc = "use CPU by default (instead of gpu by default)"}, - {"select-gpu", nil, "int", default = nil, desc = "select the GPU for computation, fallback to auto mode if not specified"}} +local options = {{"help", "h", "boolean", default = false, desc = "print this help message"}, + {"use-cpu", "c", "boolean", default = false, desc = "use CPU by default (instead of gpu by default)"}, + {"select-gpu", nil, "int", default = -1, desc = "select the GPU for computation, fallback to auto mode if not specified"}} local function print_help() nerv.printf("Usage: <nerv_prog> [options] script.lua\n") @@ -24,10 +24,7 @@ local function _add_profile_method(cls) end if not opts["use-cpu"].val then - local dev = -1 - if opts["select-gpu"].val then - dev = opts["select-gpu"].val - end + local dev = opts["select-gpu"].val nerv.info("automatically initialize a default CuContext...") nerv.CuMatrix._default_context = nerv.CuContext(dev) nerv.info("the default CuContext is ok") diff --git a/nerv/test/parse_args.lua b/nerv/test/parse_args.lua index 0d280a1..34ad55e 100644 --- a/nerv/test/parse_args.lua +++ b/nerv/test/parse_args.lua @@ -1,9 +1,9 @@ -local options = {{"abandon", "a", "bool", default = false, desc = "abandon your belief"}, - {"bullshit", "b", "bool", default = false, desc = "start to bullshit"}, - {"cheat", "c", "bool", default = false, desc = "try to cheat"}, - {"delete", "d", "bool", default = false, desc = "remove everything"}, - {"hehe", "h", "bool", default = false, desc = "233333"}, - {"oh", "o", "bool", default = true, desc = "oh yes!"}, +local options = {{"abandon", "a", "boolean", default = false, desc = "abandon your belief"}, + {"bullshit", "b", "boolean", default = false, desc = "start to bullshit"}, + {"cheat", "c", "boolean", default = false, desc = "try to cheat"}, + {"delete", "d", "boolean", default = false, desc = "remove everything"}, + {"hehe", "h", "boolean", default = false, desc = "233333"}, + {"oh", "o", "boolean", default = true, desc = "oh yes!"}, {"uid", nil, "int", desc = "user uid"}, {"str", nil, "string", desc = "test string"}} |