aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2016-03-03 15:49:10 +0800
committerDeterminant <[email protected]>2016-03-03 15:49:10 +0800
commite1c004f89ec3a783cf1948165baae5975489f775 (patch)
treef6b630853eac2719ef290d5bb8b7573e4f2c8567 /nerv/examples
parentc964d74b1927ffd9fd431f457f7e385b2d2ba5ba (diff)
add Penlight to facilitate file copying and dir making
Diffstat (limited to 'nerv/examples')
-rw-r--r--nerv/examples/asr_trainer.lua17
1 files changed, 16 insertions, 1 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 684ea30..5001e12 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -1,3 +1,5 @@
+require 'lfs'
+require 'pl'
local function build_trainer(ifname)
local param_repo = nerv.ParamRepo()
param_repo:import(ifname, nil, gconf)
@@ -133,6 +135,8 @@ local trainer_defaults = {
local options = make_options(trainer_defaults)
table.insert(options, {"help", "h", "boolean",
default = false, desc = "show this help information"})
+table.insert(options, {"dir", nil, "string",
+ default = nil, desc = "specify the working directory"})
arg, opts = nerv.parse_args(arg, options)
@@ -156,8 +160,19 @@ check_and_add_defaults(trainer_defaults)
local pf0 = gconf.initialized_param
local trainer = build_trainer(pf0)
local accu_best = trainer(nil, gconf.cv_scp, false)
+local date_pattern = "%Y%m%d%H%M%S"
+local logfile_name = "log"
+local working_dir = opts["dir"].val or string.format("nerv_%s", os.date(date_pattern))
print_gconf()
+if not lfs.mkdir(working_dir) then
+ nerv.error("[asr_trainer] working directory already exists")
+end
+-- copy the network config
+dir.copyfile(arg[1], working_dir)
+-- set logfile path
+nerv.set_logfile(path.join(working_dir, logfile_name))
+path.chdir(working_dir)
nerv.info("initial cross validation: %.3f", accu_best)
for i = 1, gconf.max_iter do
nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
@@ -168,7 +183,7 @@ for i = 1, gconf.max_iter do
string.gsub(
(string.gsub(pf0[1], "(.*/)(.*)", "%2")),
"(.*)%..*", "%1"),
- os.date("%Y%m%d%H%M%S"),
+ os.date(date_pattern),
i, gconf.lrate,
accu_tr),
gconf.cv_scp, false)