aboutsummaryrefslogtreecommitdiff
path: root/nerv
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2016-03-02 18:24:09 +0800
committerDeterminant <[email protected]>2016-03-02 18:24:09 +0800
commitad704f2623cc9e0a5d702434bfdebc345465ca12 (patch)
tree898d0688e913efc3ff098ba51e5c1a5488f5771d /nerv
parentd3abc6459a776ff7fa3777f4f561bc4f5d5e2075 (diff)
major changes in asr_trainer.lua; unified settings in `gconf`
Diffstat (limited to 'nerv')
-rw-r--r--nerv/examples/asr_trainer.lua104
-rw-r--r--nerv/examples/swb_baseline.lua7
-rw-r--r--nerv/examples/swb_baseline2.lua7
-rw-r--r--nerv/examples/timit_baseline2.lua9
-rw-r--r--nerv/init.lua22
-rw-r--r--nerv/io/sgd_buffer.lua7
-rw-r--r--nerv/nerv11
-rw-r--r--nerv/test/parse_args.lua12
8 files changed, 126 insertions, 53 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 3fa2653..684ea30 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -1,4 +1,4 @@
-function build_trainer(ifname)
+local function build_trainer(ifname)
local param_repo = nerv.ParamRepo()
param_repo:import(ifname, nil, gconf)
local layer_repo = make_layer_repo(param_repo)
@@ -75,24 +75,91 @@ function build_trainer(ifname)
return iterative_trainer
end
+local function check_and_add_defaults(spec)
+ for k, v in pairs(spec) do
+ gconf[k] = opts[string.gsub(k, '_', '-')].val or gconf[k] or v
+ end
+end
+
+local function make_options(spec)
+ local options = {}
+ for k, v in pairs(spec) do
+ table.insert(options,
+ {string.gsub(k, '_', '-'), nil, type(v), default = v})
+ end
+ return options
+end
+
+local function print_help(options)
+ nerv.printf("Usage: <asr_trainer.lua> [options] network_config.lua\n")
+ nerv.print_usage(options)
+end
+
+local function print_gconf()
+ local key_maxlen = 0
+ for k, v in pairs(gconf) do
+ key_maxlen = math.max(key_maxlen, #k or 0)
+ end
+ local function pattern_gen()
+ return string.format("%%-%ds = %%s\n", key_maxlen)
+ end
+ nerv.info("ready to train with the following gconf settings:")
+ nerv.printf(pattern_gen(), "Key", "Value")
+ for k, v in pairs(gconf) do
+ nerv.printf(pattern_gen(), k or "", v or "")
+ end
+end
+
+local trainer_defaults = {
+ lrate = 0.8,
+ batch_size = 256,
+ buffer_size = 81920,
+ wcost = 1e-6,
+ momentum = 0.9,
+ start_halving_inc = 0.5,
+ halving_factor = 0.6,
+ end_halving_inc = 0.1,
+ min_iter = 1,
+ max_iter = 20,
+ min_halving = 5,
+ do_halving = false,
+ tr_scp = nil,
+ cv_scp = nil,
+ cumat_type = nerv.CuMatrixFloat,
+ mmat_type = nerv.MMatrixFloat,
+ debug = false
+}
+
+local options = make_options(trainer_defaults)
+table.insert(options, {"help", "h", "boolean",
+ default = false, desc = "show this help information"})
+
+arg, opts = nerv.parse_args(arg, options)
+
+if #arg < 1 or opts["help"].val then
+ print_help(options)
+ return
+end
+
dofile(arg[1])
-start_halving_inc = 0.5
-halving_factor = 0.6
-end_halving_inc = 0.1
-min_iter = 1
-max_iter = 20
-min_halving = 5
-gconf.batch_size = 256
-gconf.buffer_size = 81920
+
+--[[
+
+Rule: command-line option overrides network config overrides trainer default.
+Note: config key like aaa_bbbb_cc could be overriden by specifying
+--aaa-bbbb-cc to command-line arguments.
+
+]]--
+
+check_and_add_defaults(trainer_defaults)
local pf0 = gconf.initialized_param
local trainer = build_trainer(pf0)
---local trainer = build_trainer("c3.nerv")
local accu_best = trainer(nil, gconf.cv_scp, false)
-local do_halving = false
+print_gconf()
nerv.info("initial cross validation: %.3f", accu_best)
-for i = 1, max_iter do
+for i = 1, gconf.max_iter do
nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
local accu_tr = trainer(nil, gconf.tr_scp, true)
nerv.info("[TR] training set %d: %.3f", i, accu_tr)
@@ -108,14 +175,17 @@ for i = 1, max_iter do
nerv.info("[CV] cross validation %d: %.3f", i, accu_new)
-- TODO: revert the weights
local accu_diff = accu_new - accu_best
- if do_halving and accu_diff < end_halving_inc and i > min_iter then
+ if gconf.do_halving and
+ accu_diff < gconf.end_halving_inc and
+ i > gconf.min_iter then
break
end
- if accu_diff < start_halving_inc and i >= min_halving then
- do_halving = true
+ if accu_diff < gconf.start_halving_inc and
+ i >= gconf.min_halving then
+ gconf.do_halving = true
end
- if do_halving then
- gconf.lrate = gconf.lrate * halving_factor
+ if gconf.do_halving then
+ gconf.lrate = gconf.lrate * gconf.halving_factor
end
if accu_new > accu_best then
accu_best = accu_new
diff --git a/nerv/examples/swb_baseline.lua b/nerv/examples/swb_baseline.lua
index cacc401..4cb2389 100644
--- a/nerv/examples/swb_baseline.lua
+++ b/nerv/examples/swb_baseline.lua
@@ -1,7 +1,5 @@
require 'htk_io'
gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
- cumat_type = nerv.CuMatrixFloat,
- mmat_type = nerv.MMatrixFloat,
rearrange = true, -- just to make the context order consistent with old results, deprecated
frm_ext = 5,
frm_trim = 5, -- trim the first and last 5 frames, TNet just does this, deprecated
@@ -173,6 +171,7 @@ function make_buffer(readers)
return nerv.SGDBuffer(gconf,
{
buffer_size = gconf.buffer_size,
+ batch_size = gconf.batch_size,
randomize = gconf.randomize,
readers = readers,
use_gpu = true
@@ -184,6 +183,10 @@ function get_input_order()
{id = "phone_state"}}
end
+function get_decode_input_order()
+ return {{id = "main_scp", global_transf = true}}
+end
+
function get_accuracy(layer_repo)
local ce_crit = layer_repo:get_layer("ce_crit")
return ce_crit.total_correct / ce_crit.total_frames * 100
diff --git a/nerv/examples/swb_baseline2.lua b/nerv/examples/swb_baseline2.lua
index 0e2a6e0..b0b9689 100644
--- a/nerv/examples/swb_baseline2.lua
+++ b/nerv/examples/swb_baseline2.lua
@@ -1,7 +1,5 @@
require 'htk_io'
gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
- cumat_type = nerv.CuMatrixFloat,
- mmat_type = nerv.MMatrixFloat,
rearrange = true, -- just to make the context order consistent with old results, deprecated
frm_ext = 5,
frm_trim = 5, -- trim the first and last 5 frames, TNet just does this, deprecated
@@ -173,6 +171,7 @@ function make_buffer(readers)
return nerv.SGDBuffer(gconf,
{
buffer_size = gconf.buffer_size,
+ batch_size = gconf.batch_size,
randomize = gconf.randomize,
readers = readers,
use_gpu = true
@@ -184,6 +183,10 @@ function get_input_order()
{id = "phone_state"}}
end
+function get_decode_input_order()
+ return {{id = "main_scp", global_transf = true}}
+end
+
function get_accuracy(layer_repo)
local ce_crit = layer_repo:get_layer("ce_crit")
return ce_crit.total_correct / ce_crit.total_frames * 100
diff --git a/nerv/examples/timit_baseline2.lua b/nerv/examples/timit_baseline2.lua
index 174b9e7..103d89d 100644
--- a/nerv/examples/timit_baseline2.lua
+++ b/nerv/examples/timit_baseline2.lua
@@ -1,8 +1,5 @@
require 'kaldi_io'
-gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
- cumat_type = nerv.CuMatrixFloat,
- mmat_type = nerv.MMatrixFloat,
- frm_ext = 5,
+gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9, frm_ext = 5,
tr_scp = "ark:/speechlab/tools/KALDI/kaldi-master/src/featbin/copy-feats " ..
"scp:/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/train.scp ark:- |",
cv_scp = "ark:/speechlab/tools/KALDI/kaldi-master/src/featbin/copy-feats " ..
@@ -11,8 +8,7 @@ gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
"/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/nnet_output.nerv",
"/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/nnet_trans.nerv"},
decode_param = {"/speechlab/users/mfy43/timit/nnet_init_20160229015745_iter_13_lr0.013437_tr72.434_cv58.729.nerv",
- "/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/nnet_trans.nerv"},
- debug = false}
+ "/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/nnet_trans.nerv"}}
function make_layer_repo(param_repo)
local layer_repo = nerv.LayerRepo(
@@ -183,6 +179,7 @@ function make_buffer(readers)
return nerv.SGDBuffer(gconf,
{
buffer_size = gconf.buffer_size,
+ batch_size = gconf.batch_size,
randomize = gconf.randomize,
readers = readers,
use_gpu = true
diff --git a/nerv/init.lua b/nerv/init.lua
index d72d8b8..a5b032c 100644
--- a/nerv/init.lua
+++ b/nerv/init.lua
@@ -182,15 +182,15 @@ end
-- value and description of the option.
--
-- An example of specification:
--- {{"aaa", "a", "bool", default = false, desc = "an option called aaa"},
--- {"bbb", "b", "bool", default = true, desc = "bbb is set to be true if --bbb=no does not present"},
+-- {{"aaa", "a", "boolean", default = false, desc = "an option called aaa"},
+-- {"bbb", "b", "boolean", default = true, desc = "bbb is set to be true if --bbb=no does not present"},
-- {"ccc", nil, "int", default = 0, desc = "ccc expects an integeral value"}}`
--
-- @return args, opts The non-option arguments and parsed options. `opts` is
-- again a list of tables, each of which corresponds to one table in parameter
-- `options`. The parsed value could be accessed by `opts["aaa"].val` (which is
-- `true` if "--aaa" or "-a" is specified).
-function nerv.parse_args(argv, options)
+function nerv.parse_args(argv, options, unordered)
local is_opt_exp = "^[-](.*)$"
local sim_opt_exp = "^[-]([a-z]+)$"
local opt_exp = "^[-][-]([^=]+)$"
@@ -198,6 +198,7 @@ function nerv.parse_args(argv, options)
local opts = {}
local sopts = {}
local args = {}
+ local arg_start = false
local function err()
nerv.error("invalid format of option specification")
end
@@ -215,7 +216,7 @@ function nerv.parse_args(argv, options)
val = v.default}
if opt_short ~= nil then
if type(opt_short) ~= "string" or #opt_short ~= 1 then err() end
- if opt_type ~= "bool" then
+ if opt_type ~= "boolean" then
nerv.error("only boolean option could have short form")
end
sopts[opt_short] = opt_meta
@@ -226,7 +227,7 @@ function nerv.parse_args(argv, options)
end
end
for _, token in ipairs(argv) do
- if token:match(is_opt_exp) then
+ if ((not arg_start) or unordered) and token:match(is_opt_exp) then
local k = token:match(sim_opt_exp)
if k then
for c in k:gmatch"." do
@@ -242,7 +243,7 @@ function nerv.parse_args(argv, options)
if opts[k] == nil then
nerv.error("invalid option %s", token)
end
- if opts[k].type ~= "bool" then
+ if opts[k].type ~= "boolean" then
nerv.error("invalid option --%s: " ..
"a %s value needs to be specified",
k, opts[k].type)
@@ -255,13 +256,13 @@ function nerv.parse_args(argv, options)
if opts[k] == nil then
nerv.error("invalid option %s", token)
end
- if opts[k].type == "bool" then
+ if opts[k].type == "boolean" then
if v == "yes" then
opts[k].val = true
elseif v == "no" then
opts[k].val = false
else
- nerv.error("bool value should be \"yes\" or \"no\"")
+ nerv.error("boolean value should be \"yes\" or \"no\"")
end
elseif opts[k].type == "int" then
local t = tonumber(v)
@@ -269,11 +270,11 @@ function nerv.parse_args(argv, options)
if t == nil or math.floor(t) ~= t then
nerv.error("int value is expected")
end
- elseif opts[k].type == "float" then
+ elseif opts[k].type == "number" then
local t = tonumber(v)
opts[k].val = t
if t == nil then
- nerv.error("float value is expected")
+ nerv.error("numeric value is expected")
end
elseif opts[k].type == "string" then
opts[k].val = v
@@ -287,6 +288,7 @@ function nerv.parse_args(argv, options)
end
else
table.insert(args, token)
+ arg_start = true
end
end
return args, opts
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/sgd_buffer.lua
index 3cf4f5a..d78f6d1 100644
--- a/nerv/io/sgd_buffer.lua
+++ b/nerv/io/sgd_buffer.lua
@@ -2,8 +2,9 @@ local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer")
function SGDBuffer:__init(global_conf, buffer_conf)
self.gconf = global_conf
+ self.batch_size = buffer_conf.batch_size
self.buffer_size = math.floor(buffer_conf.buffer_size /
- global_conf.batch_size) * global_conf.batch_size
+ self.batch_size) * self.batch_size
self.randomize = buffer_conf.randomize
self.consume = buffer_conf.consume
local cumat_type = global_conf.cumat_type
@@ -112,11 +113,11 @@ function SGDBuffer:saturate()
end
self.rand_map = self.perm_gen(self.tail) -- generate shuffled index
collectgarbage("collect")
- return self.tail >= self.gconf.batch_size
+ return self.tail >= self.batch_size
end
function SGDBuffer:get_data()
- local batch_size = self.gconf.batch_size
+ local batch_size = self.batch_size
if self.head >= self.tail then -- buffer is empty
local t = os.clock()
if (not self:saturate()) and (not self.consume) then
diff --git a/nerv/nerv b/nerv/nerv
index 0b75a9b..f73d517 100644
--- a/nerv/nerv
+++ b/nerv/nerv
@@ -1,8 +1,8 @@
#! /usr/bin/env luajit
require 'nerv'
-local options = {{"help", "h", "bool", default = false, desc = "print this help message"},
- {"use-cpu", "c", "bool", default = false, desc = "use CPU by default (instead of gpu by default)"},
- {"select-gpu", nil, "int", default = nil, desc = "select the GPU for computation, fallback to auto mode if not specified"}}
+local options = {{"help", "h", "boolean", default = false, desc = "print this help message"},
+ {"use-cpu", "c", "boolean", default = false, desc = "use CPU by default (instead of gpu by default)"},
+ {"select-gpu", nil, "int", default = -1, desc = "select the GPU for computation, fallback to auto mode if not specified"}}
local function print_help()
nerv.printf("Usage: <nerv_prog> [options] script.lua\n")
@@ -24,10 +24,7 @@ local function _add_profile_method(cls)
end
if not opts["use-cpu"].val then
- local dev = -1
- if opts["select-gpu"].val then
- dev = opts["select-gpu"].val
- end
+ local dev = opts["select-gpu"].val
nerv.info("automatically initialize a default CuContext...")
nerv.CuMatrix._default_context = nerv.CuContext(dev)
nerv.info("the default CuContext is ok")
diff --git a/nerv/test/parse_args.lua b/nerv/test/parse_args.lua
index 0d280a1..34ad55e 100644
--- a/nerv/test/parse_args.lua
+++ b/nerv/test/parse_args.lua
@@ -1,9 +1,9 @@
-local options = {{"abandon", "a", "bool", default = false, desc = "abandon your belief"},
- {"bullshit", "b", "bool", default = false, desc = "start to bullshit"},
- {"cheat", "c", "bool", default = false, desc = "try to cheat"},
- {"delete", "d", "bool", default = false, desc = "remove everything"},
- {"hehe", "h", "bool", default = false, desc = "233333"},
- {"oh", "o", "bool", default = true, desc = "oh yes!"},
+local options = {{"abandon", "a", "boolean", default = false, desc = "abandon your belief"},
+ {"bullshit", "b", "boolean", default = false, desc = "start to bullshit"},
+ {"cheat", "c", "boolean", default = false, desc = "try to cheat"},
+ {"delete", "d", "boolean", default = false, desc = "remove everything"},
+ {"hehe", "h", "boolean", default = false, desc = "233333"},
+ {"oh", "o", "boolean", default = true, desc = "oh yes!"},
{"uid", nil, "int", desc = "user uid"},
{"str", nil, "string", desc = "test string"}}