path: root/nerv/examples
author    Determinant <ted.sybil@gmail.com>    2016-03-02 18:24:09 +0800
committer Determinant <ted.sybil@gmail.com>    2016-03-02 18:24:09 +0800
commit    ad704f2623cc9e0a5d702434bfdebc345465ca12 (patch)
tree      898d0688e913efc3ff098ba51e5c1a5488f5771d /nerv/examples
parent    d3abc6459a776ff7fa3777f4f561bc4f5d5e2075 (diff)
major changes in asr_trainer.lua; unified settings in `gconf`
Diffstat (limited to 'nerv/examples')
-rw-r--r--  nerv/examples/asr_trainer.lua     | 104
-rw-r--r--  nerv/examples/swb_baseline.lua    |   7
-rw-r--r--  nerv/examples/swb_baseline2.lua   |   7
-rw-r--r--  nerv/examples/timit_baseline2.lua |   9
4 files changed, 100 insertions, 27 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 3fa2653..684ea30 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -1,4 +1,4 @@
-function build_trainer(ifname)
+local function build_trainer(ifname)
local param_repo = nerv.ParamRepo()
param_repo:import(ifname, nil, gconf)
local layer_repo = make_layer_repo(param_repo)
@@ -75,24 +75,91 @@ function build_trainer(ifname)
return iterative_trainer
end
+local function check_and_add_defaults(spec)
+ for k, v in pairs(spec) do
+ gconf[k] = opts[string.gsub(k, '_', '-')].val or gconf[k] or v
+ end
+end
+
+local function make_options(spec)
+ local options = {}
+ for k, v in pairs(spec) do
+ table.insert(options,
+ {string.gsub(k, '_', '-'), nil, type(v), default = v})
+ end
+ return options
+end
+
+local function print_help(options)
+ nerv.printf("Usage: <asr_trainer.lua> [options] network_config.lua\n")
+ nerv.print_usage(options)
+end
+
+local function print_gconf()
+ local key_maxlen = 0
+ for k, v in pairs(gconf) do
+ key_maxlen = math.max(key_maxlen, #k or 0)
+ end
+ local function pattern_gen()
+ return string.format("%%-%ds = %%s\n", key_maxlen)
+ end
+ nerv.info("ready to train with the following gconf settings:")
+ nerv.printf(pattern_gen(), "Key", "Value")
+ for k, v in pairs(gconf) do
+ nerv.printf(pattern_gen(), k or "", v or "")
+ end
+end
+
+local trainer_defaults = {
+ lrate = 0.8,
+ batch_size = 256,
+ buffer_size = 81920,
+ wcost = 1e-6,
+ momentum = 0.9,
+ start_halving_inc = 0.5,
+ halving_factor = 0.6,
+ end_halving_inc = 0.1,
+ min_iter = 1,
+ max_iter = 20,
+ min_halving = 5,
+ do_halving = false,
+ tr_scp = nil,
+ cv_scp = nil,
+ cumat_type = nerv.CuMatrixFloat,
+ mmat_type = nerv.MMatrixFloat,
+ debug = false
+}
+
+local options = make_options(trainer_defaults)
+table.insert(options, {"help", "h", "boolean",
+ default = false, desc = "show this help information"})
+
+arg, opts = nerv.parse_args(arg, options)
+
+if #arg < 1 or opts["help"].val then
+ print_help(options)
+ return
+end
+
dofile(arg[1])
-start_halving_inc = 0.5
-halving_factor = 0.6
-end_halving_inc = 0.1
-min_iter = 1
-max_iter = 20
-min_halving = 5
-gconf.batch_size = 256
-gconf.buffer_size = 81920
+
+--[[
+
+Rule: a command-line option overrides the network config, which overrides
+the trainer default. A config key like aaa_bbbb_cc can be overridden by
+specifying --aaa-bbbb-cc on the command line.
+
+]]--
+
+check_and_add_defaults(trainer_defaults)
local pf0 = gconf.initialized_param
local trainer = build_trainer(pf0)
---local trainer = build_trainer("c3.nerv")
local accu_best = trainer(nil, gconf.cv_scp, false)
-local do_halving = false
+print_gconf()
nerv.info("initial cross validation: %.3f", accu_best)
-for i = 1, max_iter do
+for i = 1, gconf.max_iter do
nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
local accu_tr = trainer(nil, gconf.tr_scp, true)
nerv.info("[TR] training set %d: %.3f", i, accu_tr)
@@ -108,14 +175,17 @@ for i = 1, max_iter do
nerv.info("[CV] cross validation %d: %.3f", i, accu_new)
-- TODO: revert the weights
local accu_diff = accu_new - accu_best
- if do_halving and accu_diff < end_halving_inc and i > min_iter then
+ if gconf.do_halving and
+ accu_diff < gconf.end_halving_inc and
+ i > gconf.min_iter then
break
end
- if accu_diff < start_halving_inc and i >= min_halving then
- do_halving = true
+ if accu_diff < gconf.start_halving_inc and
+ i >= gconf.min_halving then
+ gconf.do_halving = true
end
- if do_halving then
- gconf.lrate = gconf.lrate * halving_factor
+ if gconf.do_halving then
+ gconf.lrate = gconf.lrate * gconf.halving_factor
end
if accu_new > accu_best then
accu_best = accu_new
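The precedence rule in the new comment block (command-line option overrides
network config, which overrides the trainer default) is implemented by the
single `or` chain in check_and_add_defaults. A minimal self-contained sketch
of that chain, with opts and gconf stubbed by hand rather than produced by
nerv.parse_args and dofile(arg[1]):

    -- sketch only: opts and gconf are hand-made stand-ins here
    local opts  = {["batch-size"] = {val = 128},  -- given on the command line
                   ["lrate"]      = {val = nil}}  -- not given
    local gconf = {lrate = 0.5}                   -- set by the network config
    local defaults = {lrate = 0.8, batch_size = 256}

    for k, v in pairs(defaults) do
        -- option value first, then config value, then trainer default
        gconf[k] = opts[string.gsub(k, '_', '-')].val or gconf[k] or v
    end

    print(gconf.batch_size) --> 128 (command line wins over the 256 default)
    print(gconf.lrate)      --> 0.5 (network config wins over the 0.8 default)

Per the new help text, invocation is asr_trainer.lua [options]
network_config.lua, so an option spelled batch_size in trainer_defaults is
passed as --batch-size. One Lua caveat: because the chain uses `or`, an
option whose value is false falls through to the config value or default,
since `or` treats false like nil.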
diff --git a/nerv/examples/swb_baseline.lua b/nerv/examples/swb_baseline.lua
index cacc401..4cb2389 100644
--- a/nerv/examples/swb_baseline.lua
+++ b/nerv/examples/swb_baseline.lua
@@ -1,7 +1,5 @@
require 'htk_io'
gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
- cumat_type = nerv.CuMatrixFloat,
- mmat_type = nerv.MMatrixFloat,
rearrange = true, -- just to make the context order consistent with old results, deprecated
frm_ext = 5,
frm_trim = 5, -- trim the first and last 5 frames, TNet just does this, deprecated
@@ -173,6 +171,7 @@ function make_buffer(readers)
return nerv.SGDBuffer(gconf,
{
buffer_size = gconf.buffer_size,
+ batch_size = gconf.batch_size,
randomize = gconf.randomize,
readers = readers,
use_gpu = true
@@ -184,6 +183,10 @@ function get_input_order()
{id = "phone_state"}}
end
+function get_decode_input_order()
+ return {{id = "main_scp", global_transf = true}}
+end
+
function get_accuracy(layer_repo)
local ce_crit = layer_repo:get_layer("ce_crit")
return ce_crit.total_correct / ce_crit.total_frames * 100
diff --git a/nerv/examples/swb_baseline2.lua b/nerv/examples/swb_baseline2.lua
index 0e2a6e0..b0b9689 100644
--- a/nerv/examples/swb_baseline2.lua
+++ b/nerv/examples/swb_baseline2.lua
@@ -1,7 +1,5 @@
require 'htk_io'
gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
- cumat_type = nerv.CuMatrixFloat,
- mmat_type = nerv.MMatrixFloat,
rearrange = true, -- just to make the context order consistent with old results, deprecated
frm_ext = 5,
frm_trim = 5, -- trim the first and last 5 frames, TNet just does this, deprecated
@@ -173,6 +171,7 @@ function make_buffer(readers)
return nerv.SGDBuffer(gconf,
{
buffer_size = gconf.buffer_size,
+ batch_size = gconf.batch_size,
randomize = gconf.randomize,
readers = readers,
use_gpu = true
@@ -184,6 +183,10 @@ function get_input_order()
{id = "phone_state"}}
end
+function get_decode_input_order()
+ return {{id = "main_scp", global_transf = true}}
+end
+
function get_accuracy(layer_repo)
local ce_crit = layer_repo:get_layer("ce_crit")
return ce_crit.total_correct / ce_crit.total_frames * 100
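Both Switchboard configs gain the identical get_decode_input_order hook. A
self-contained sketch showing the hook exactly as added above, plus an
illustrative loop over its return value (the loop is ours, not part of the
commit):

    function get_decode_input_order()
        return {{id = "main_scp", global_transf = true}}
    end

    for _, spec in ipairs(get_decode_input_order()) do
        -- each entry names a reader id and whether the global input
        -- transform should be applied to its features
        print(spec.id, spec.global_transf)  --> main_scp    true
    end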
diff --git a/nerv/examples/timit_baseline2.lua b/nerv/examples/timit_baseline2.lua
index 174b9e7..103d89d 100644
--- a/nerv/examples/timit_baseline2.lua
+++ b/nerv/examples/timit_baseline2.lua
@@ -1,8 +1,5 @@
require 'kaldi_io'
-gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
- cumat_type = nerv.CuMatrixFloat,
- mmat_type = nerv.MMatrixFloat,
- frm_ext = 5,
+gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9, frm_ext = 5,
tr_scp = "ark:/speechlab/tools/KALDI/kaldi-master/src/featbin/copy-feats " ..
"scp:/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/train.scp ark:- |",
cv_scp = "ark:/speechlab/tools/KALDI/kaldi-master/src/featbin/copy-feats " ..
@@ -11,8 +8,7 @@ gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
"/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/nnet_output.nerv",
"/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/nnet_trans.nerv"},
decode_param = {"/speechlab/users/mfy43/timit/nnet_init_20160229015745_iter_13_lr0.013437_tr72.434_cv58.729.nerv",
- "/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/nnet_trans.nerv"},
- debug = false}
+ "/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/nnet_trans.nerv"}}
function make_layer_repo(param_repo)
local layer_repo = nerv.LayerRepo(
@@ -183,6 +179,7 @@ function make_buffer(readers)
return nerv.SGDBuffer(gconf,
{
buffer_size = gconf.buffer_size,
+ batch_size = gconf.batch_size,
randomize = gconf.randomize,
readers = readers,
use_gpu = true
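The same one-line change lands in make_buffer in all three configs: the
batch size is now passed to nerv.SGDBuffer explicitly, alongside
buffer_size, with both values taken from gconf. Reconstructed from the
hunks above, each config's make_buffer now reads:

    function make_buffer(readers)
        return nerv.SGDBuffer(gconf,
            {
                buffer_size = gconf.buffer_size,
                batch_size = gconf.batch_size,
                randomize = gconf.randomize,
                readers = readers,
                use_gpu = true
            })
    end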