about summary refs log tree commit diff
path: root/nerv/examples/asr_trainer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/asr_trainer.lua')
-rw-r--r-- nerv/examples/asr_trainer.lua 7
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 52cb754..aa1019d 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -248,7 +248,7 @@ end
dir.copyfile(arg[1], working_dir)
-- set logfile path
nerv.set_logfile(path.join(working_dir, logfile_name))
-path.chdir(working_dir)
+--path.chdir(working_dir)
-- start the training
local trainer = build_trainer(pf0)
@@ -258,7 +258,7 @@ nerv.info("initial cross validation: %.3f", gconf.accu_best)
for i = gconf.cur_iter, gconf.max_iter do
local stop = false
gconf.cur_iter = i
- dump_gconf(string.format("iter_%d.meta", i))
+ dump_gconf(path.join(working_dir, string.format("iter_%d.meta", i)))
repeat -- trick to implement `continue` statement
nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
local accu_tr = trainer(nil, gconf.tr_scp, true, rebind_param_repo)
@@ -270,7 +270,8 @@ for i = gconf.cur_iter, gconf.max_iter do
os.date(date_pattern),
i, gconf.lrate,
accu_tr)
- local accu_new, pr_new, param_fname = trainer(param_prefix, gconf.cv_scp, false)
+ local accu_new, pr_new, param_fname =
+ trainer(path.join(working_dir, param_prefix), gconf.cv_scp, false)
nerv.info("[CV] cross validation %d: %.3f", i, accu_new)
local accu_prev = gconf.accu_best
if accu_new < gconf.accu_best then