From ad704f2623cc9e0a5d702434bfdebc345465ca12 Mon Sep 17 00:00:00 2001
From: Determinant <ted.sybil@gmail.com>
Date: Wed, 2 Mar 2016 18:24:09 +0800
Subject: major changes in asr_trainer.lua; unified settings in `gconf`

---
 nerv/examples/asr_trainer.lua     | 104 +++++++++++++++++++++++++++++++-------
 nerv/examples/swb_baseline.lua    |   7 ++-
 nerv/examples/swb_baseline2.lua   |   7 ++-
 nerv/examples/timit_baseline2.lua |   9 ++--
 nerv/init.lua                     |  22 ++++----
 nerv/io/sgd_buffer.lua            |   7 +--
 nerv/nerv                         |  11 ++--
 nerv/test/parse_args.lua          |  12 ++---
 8 files changed, 126 insertions(+), 53 deletions(-)

diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 3fa2653..684ea30 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -1,4 +1,4 @@
-function build_trainer(ifname)
+local function build_trainer(ifname)
     local param_repo = nerv.ParamRepo()
     param_repo:import(ifname, nil, gconf)
     local layer_repo = make_layer_repo(param_repo)
@@ -75,24 +75,91 @@ function build_trainer(ifname)
     return iterative_trainer
 end
 
+local function check_and_add_defaults(spec)
+    for k, v in pairs(spec) do
+        gconf[k] = opts[string.gsub(k, '_', '-')].val or gconf[k] or v
+    end
+end
+
+local function make_options(spec)
+    local options = {}
+    for k, v in pairs(spec) do
+        table.insert(options,
+                    {string.gsub(k, '_', '-'), nil, type(v), default = v})
+    end
+    return options
+end
+
+local function print_help(options)
+    nerv.printf("Usage: <asr_trainer.lua> [options] network_config.lua\n")
+    nerv.print_usage(options)
+end
+
+local function print_gconf()
+    local key_maxlen = 0
+    for k, v in pairs(gconf) do
+        key_maxlen = math.max(key_maxlen, #k or 0)
+    end
+    local function pattern_gen()
+        return string.format("%%-%ds = %%s\n", key_maxlen)
+    end
+    nerv.info("ready to train with the following gconf settings:")
+    nerv.printf(pattern_gen(), "Key", "Value")
+    for k, v in pairs(gconf) do
+        nerv.printf(pattern_gen(), k or "", v or "")
+    end
+end
+
+local trainer_defaults = {
+    lrate = 0.8,
+    batch_size = 256,
+    buffer_size = 81920,
+    wcost = 1e-6,
+    momentum = 0.9,
+    start_halving_inc = 0.5,
+    halving_factor = 0.6,
+    end_halving_inc = 0.1,
+    min_iter = 1,
+    max_iter = 20,
+    min_halving = 5,
+    do_halving = false,
+    tr_scp = nil,
+    cv_scp = nil,
+    cumat_type = nerv.CuMatrixFloat,
+    mmat_type = nerv.MMatrixFloat,
+    debug = false
+}
+
+local options = make_options(trainer_defaults)
+table.insert(options, {"help", "h", "boolean",
+                        default = false, desc = "show this help information"})
+
+arg, opts = nerv.parse_args(arg, options)
+
+if #arg < 1 or opts["help"].val then
+    print_help(options)
+    return
+end
+
 dofile(arg[1])
-start_halving_inc = 0.5
-halving_factor = 0.6
-end_halving_inc = 0.1
-min_iter = 1
-max_iter = 20
-min_halving = 5
-gconf.batch_size = 256
-gconf.buffer_size = 81920
+
+--[[
+
+Rule: command-line option overrides network config overrides trainer default.
+Note: a config key like aaa_bbbb_cc can be overridden by specifying
+--aaa-bbbb-cc as a command-line argument.
+
+]]--
+
+check_and_add_defaults(trainer_defaults)
 
 local pf0 = gconf.initialized_param
 local trainer = build_trainer(pf0)
---local trainer = build_trainer("c3.nerv")
 local accu_best = trainer(nil, gconf.cv_scp, false)
-local do_halving = false
 
+print_gconf()
 nerv.info("initial cross validation: %.3f", accu_best)
-for i = 1, max_iter do
+for i = 1, gconf.max_iter do
     nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
     local accu_tr = trainer(nil, gconf.tr_scp, true)
     nerv.info("[TR] training set %d: %.3f", i, accu_tr)
@@ -108,14 +175,17 @@ for i = 1, max_iter do
     nerv.info("[CV] cross validation %d: %.3f", i, accu_new)
     -- TODO: revert the weights
     local accu_diff = accu_new - accu_best
-    if do_halving and accu_diff < end_halving_inc and i > min_iter then
+    if gconf.do_halving and
+        accu_diff < gconf.end_halving_inc and
+        i > gconf.min_iter then
         break
     end
-    if accu_diff < start_halving_inc and i >= min_halving then
-        do_halving = true
+    if accu_diff < gconf.start_halving_inc and
+        i >= gconf.min_halving then
+        gconf.do_halving = true
     end
-    if do_halving then
-        gconf.lrate = gconf.lrate * halving_factor
+    if gconf.do_halving then
+        gconf.lrate = gconf.lrate * gconf.halving_factor
     end
     if accu_new > accu_best then
         accu_best = accu_new
diff --git a/nerv/examples/swb_baseline.lua b/nerv/examples/swb_baseline.lua
index cacc401..4cb2389 100644
--- a/nerv/examples/swb_baseline.lua
+++ b/nerv/examples/swb_baseline.lua
@@ -1,7 +1,5 @@
 require 'htk_io'
 gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
-        cumat_type = nerv.CuMatrixFloat,
-        mmat_type = nerv.MMatrixFloat,
         rearrange = true, -- just to make the context order consistent with old results, deprecated
         frm_ext = 5,
         frm_trim = 5, -- trim the first and last 5 frames, TNet just does this, deprecated
@@ -173,6 +171,7 @@ function make_buffer(readers)
     return nerv.SGDBuffer(gconf,
         {
             buffer_size = gconf.buffer_size,
+            batch_size = gconf.batch_size,
             randomize = gconf.randomize,
             readers = readers,
             use_gpu = true
@@ -184,6 +183,10 @@ function get_input_order()
             {id = "phone_state"}}
 end
 
+function get_decode_input_order()
+    return {{id = "main_scp", global_transf = true}}
+end
+
 function get_accuracy(layer_repo)
     local ce_crit = layer_repo:get_layer("ce_crit")
     return ce_crit.total_correct / ce_crit.total_frames * 100
diff --git a/nerv/examples/swb_baseline2.lua b/nerv/examples/swb_baseline2.lua
index 0e2a6e0..b0b9689 100644
--- a/nerv/examples/swb_baseline2.lua
+++ b/nerv/examples/swb_baseline2.lua
@@ -1,7 +1,5 @@
 require 'htk_io'
 gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
-        cumat_type = nerv.CuMatrixFloat,
-        mmat_type = nerv.MMatrixFloat,
         rearrange = true, -- just to make the context order consistent with old results, deprecated
         frm_ext = 5,
         frm_trim = 5, -- trim the first and last 5 frames, TNet just does this, deprecated
@@ -173,6 +171,7 @@ function make_buffer(readers)
     return nerv.SGDBuffer(gconf,
         {
             buffer_size = gconf.buffer_size,
+            batch_size = gconf.batch_size,
             randomize = gconf.randomize,
             readers = readers,
             use_gpu = true
@@ -184,6 +183,10 @@ function get_input_order()
             {id = "phone_state"}}
 end
 
+function get_decode_input_order()
+    return {{id = "main_scp", global_transf = true}}
+end
+
 function get_accuracy(layer_repo)
     local ce_crit = layer_repo:get_layer("ce_crit")
     return ce_crit.total_correct / ce_crit.total_frames * 100
diff --git a/nerv/examples/timit_baseline2.lua b/nerv/examples/timit_baseline2.lua
index 174b9e7..103d89d 100644
--- a/nerv/examples/timit_baseline2.lua
+++ b/nerv/examples/timit_baseline2.lua
@@ -1,8 +1,5 @@
 require 'kaldi_io'
-gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
-        cumat_type = nerv.CuMatrixFloat,
-        mmat_type = nerv.MMatrixFloat,
-        frm_ext = 5,
+gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9, frm_ext = 5,
         tr_scp = "ark:/speechlab/tools/KALDI/kaldi-master/src/featbin/copy-feats " ..
                     "scp:/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/train.scp ark:- |",
         cv_scp = "ark:/speechlab/tools/KALDI/kaldi-master/src/featbin/copy-feats " ..
@@ -11,8 +8,7 @@ gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
                             "/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/nnet_output.nerv",
                             "/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/nnet_trans.nerv"},
         decode_param = {"/speechlab/users/mfy43/timit/nnet_init_20160229015745_iter_13_lr0.013437_tr72.434_cv58.729.nerv",
-                        "/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/nnet_trans.nerv"},
-        debug = false}
+                        "/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_prepare/nnet_trans.nerv"}}
 
 function make_layer_repo(param_repo)
     local layer_repo = nerv.LayerRepo(
@@ -183,6 +179,7 @@ function make_buffer(readers)
     return nerv.SGDBuffer(gconf,
         {
             buffer_size = gconf.buffer_size,
+            batch_size = gconf.batch_size,
             randomize = gconf.randomize,
             readers = readers,
             use_gpu = true
diff --git a/nerv/init.lua b/nerv/init.lua
index d72d8b8..a5b032c 100644
--- a/nerv/init.lua
+++ b/nerv/init.lua
@@ -182,15 +182,15 @@ end
 -- value and description of the option.
 --
 -- An example of specification:
--- {{"aaa", "a", "bool", default = false, desc = "an option called aaa"},
--- {"bbb", "b", "bool", default = true, desc = "bbb is set to be true if --bbb=no does not present"},
+-- {{"aaa", "a", "boolean", default = false, desc = "an option called aaa"},
+-- {"bbb", "b", "boolean", default = true, desc = "bbb is set to be true if --bbb=no is not present"},
 -- {"ccc", nil, "int", default = 0, desc = "ccc expects an integeral value"}}`
 --
 -- @return args, opts The non-option arguments and parsed options. `opts` is
 -- again a list of tables, each of which corresponds to one table in parameter
 -- `options`. The parsed value could be accessed by `opts["aaa"].val` (which is
 -- `true` if "--aaa" or "-a" is specified).
-function nerv.parse_args(argv, options)
+function nerv.parse_args(argv, options, unordered)
     local is_opt_exp = "^[-](.*)$"
     local sim_opt_exp = "^[-]([a-z]+)$"
     local opt_exp = "^[-][-]([^=]+)$"
@@ -198,6 +198,7 @@ function nerv.parse_args(argv, options)
     local opts = {}
     local sopts = {}
     local args = {}
+    local arg_start = false
     local function err()
         nerv.error("invalid format of option specification")
     end
@@ -215,7 +216,7 @@ function nerv.parse_args(argv, options)
                             val = v.default}
         if opt_short ~= nil then
             if type(opt_short) ~= "string" or #opt_short ~= 1 then err() end
-            if opt_type ~= "bool" then
+            if opt_type ~= "boolean" then
                 nerv.error("only boolean option could have short form")
             end
             sopts[opt_short] = opt_meta
@@ -226,7 +227,7 @@ function nerv.parse_args(argv, options)
         end
     end
     for _, token in ipairs(argv) do
-        if token:match(is_opt_exp) then
+        if ((not arg_start) or unordered) and token:match(is_opt_exp) then
             local k = token:match(sim_opt_exp)
             if k then
                 for c in k:gmatch"." do
@@ -242,7 +243,7 @@ function nerv.parse_args(argv, options)
                     if opts[k] == nil then
                         nerv.error("invalid option %s", token)
                     end
-                    if opts[k].type ~= "bool" then
+                    if opts[k].type ~= "boolean" then
                         nerv.error("invalid option --%s: " ..
                                     "a %s value needs to be specified",
                                     k, opts[k].type)
@@ -255,13 +256,13 @@ function nerv.parse_args(argv, options)
                         if opts[k] == nil then
                             nerv.error("invalid option %s", token)
                         end
-                        if opts[k].type == "bool" then
+                        if opts[k].type == "boolean" then
                             if v == "yes" then
                                 opts[k].val = true
                             elseif v == "no" then
                                 opts[k].val = false
                             else
-                                nerv.error("bool value should be \"yes\" or \"no\"")
+                                nerv.error("boolean value should be \"yes\" or \"no\"")
                             end
                         elseif opts[k].type == "int" then
                             local t = tonumber(v)
@@ -269,11 +270,11 @@ function nerv.parse_args(argv, options)
                             if t == nil or math.floor(t) ~= t then
                                 nerv.error("int value is expected")
                             end
-                        elseif opts[k].type == "float" then
+                        elseif opts[k].type == "number" then
                             local t = tonumber(v)
                             opts[k].val = t
                             if t == nil then
-                                nerv.error("float value is expected")
+                                nerv.error("numeric value is expected")
                             end
                         elseif opts[k].type == "string" then
                             opts[k].val = v
@@ -287,6 +288,7 @@ function nerv.parse_args(argv, options)
             end
         else
             table.insert(args, token)
+            arg_start = true
         end
     end
     return args, opts
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/sgd_buffer.lua
index 3cf4f5a..d78f6d1 100644
--- a/nerv/io/sgd_buffer.lua
+++ b/nerv/io/sgd_buffer.lua
@@ -2,8 +2,9 @@ local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer")
 
 function SGDBuffer:__init(global_conf, buffer_conf)
     self.gconf = global_conf
+    self.batch_size = buffer_conf.batch_size
     self.buffer_size = math.floor(buffer_conf.buffer_size /
-                                global_conf.batch_size) * global_conf.batch_size
+                                    self.batch_size) * self.batch_size
     self.randomize = buffer_conf.randomize
     self.consume = buffer_conf.consume
     local cumat_type = global_conf.cumat_type
@@ -112,11 +113,11 @@ function SGDBuffer:saturate()
     end
     self.rand_map = self.perm_gen(self.tail) -- generate shuffled index
     collectgarbage("collect")
-    return self.tail >= self.gconf.batch_size
+    return self.tail >= self.batch_size
 end
 
 function SGDBuffer:get_data()
-    local batch_size = self.gconf.batch_size
+    local batch_size = self.batch_size
     if self.head >= self.tail then -- buffer is empty
         local t = os.clock()
         if (not self:saturate()) and (not self.consume) then
diff --git a/nerv/nerv b/nerv/nerv
index 0b75a9b..f73d517 100644
--- a/nerv/nerv
+++ b/nerv/nerv
@@ -1,8 +1,8 @@
 #! /usr/bin/env luajit
 require 'nerv'
-local options = {{"help", "h", "bool", default = false, desc = "print this help message"},
-                 {"use-cpu", "c", "bool", default = false, desc = "use CPU by default (instead of gpu by default)"},
-                 {"select-gpu", nil, "int", default = nil, desc = "select the GPU for computation, fallback to auto mode if not specified"}}
+local options = {{"help", "h", "boolean", default = false, desc = "print this help message"},
+                 {"use-cpu", "c", "boolean", default = false, desc = "use CPU by default (instead of gpu by default)"},
+                 {"select-gpu", nil, "int", default = -1, desc = "select the GPU for computation, fallback to auto mode if not specified"}}
 
 local function print_help()
     nerv.printf("Usage: <nerv_prog> [options] script.lua\n")
@@ -24,10 +24,7 @@ local function _add_profile_method(cls)
 end
 
 if not opts["use-cpu"].val then
-    local dev = -1
-    if opts["select-gpu"].val then
-        dev = opts["select-gpu"].val
-    end
+    local dev = opts["select-gpu"].val
     nerv.info("automatically initialize a default CuContext...")
     nerv.CuMatrix._default_context = nerv.CuContext(dev)
     nerv.info("the default CuContext is ok")
diff --git a/nerv/test/parse_args.lua b/nerv/test/parse_args.lua
index 0d280a1..34ad55e 100644
--- a/nerv/test/parse_args.lua
+++ b/nerv/test/parse_args.lua
@@ -1,9 +1,9 @@
-local options = {{"abandon", "a", "bool", default = false, desc = "abandon your belief"},
-                 {"bullshit", "b", "bool", default = false, desc = "start to bullshit"},
-                 {"cheat", "c", "bool", default = false, desc = "try to cheat"},
-                 {"delete", "d", "bool", default = false, desc = "remove everything"},
-                 {"hehe", "h", "bool", default = false, desc = "233333"},
-                 {"oh", "o", "bool", default = true, desc = "oh yes!"},
+local options = {{"abandon", "a", "boolean", default = false, desc = "abandon your belief"},
+                 {"bullshit", "b", "boolean", default = false, desc = "start to bullshit"},
+                 {"cheat", "c", "boolean", default = false, desc = "try to cheat"},
+                 {"delete", "d", "boolean", default = false, desc = "remove everything"},
+                 {"hehe", "h", "boolean", default = false, desc = "233333"},
+                 {"oh", "o", "boolean", default = true, desc = "oh yes!"},
                  {"uid", nil, "int", desc = "user uid"},
                  {"str", nil, "string", desc = "test string"}}
 
-- 
cgit v1.2.3-70-g09d2