aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/asr_trainer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/asr_trainer.lua')
-rw-r--r--nerv/examples/asr_trainer.lua17
1 files changed, 10 insertions, 7 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 38ba6e9..9a764fc 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -39,12 +39,8 @@ local function build_trainer(ifname)
local buffer = make_buffer(make_readers(scp_file, layer_repo))
-- initialize the network
gconf.cnt = 0
- local err_input = {{}}
local output = {{}}
for i = 1, gconf.chunk_size do
- local mini_batch = mat_type(gconf.batch_size, 1)
- mini_batch:fill(1)
- table.insert(err_input[1], mini_batch)
table.insert(output[1], mat_type(gconf.batch_size, 1))
end
network:epoch_init()
@@ -91,7 +87,7 @@ local function build_trainer(ifname)
do_train = bp,
input = input,
output = output,
- err_input = err_input,
+ err_input = {gconf.mask},
err_output = err_output})
network:propagate()
if bp then
@@ -254,8 +250,15 @@ nerv.set_logfile(path.join(working_dir, logfile_name))
-- start the training
local trainer = build_trainer(pf0)
local pr_prev
-gconf.accu_best, pr_prev = trainer(nil, gconf.cv_scp, false)
+-- initial cross-validation
+local param_prefix = string.format("%s_%s",
+ string.gsub(
+ (string.gsub(pf0[1], "(.*/)(.*)", "%2")),
+ "(.*)%..*", "%1"),
+ os.date(date_pattern))
+gconf.accu_best, pr_prev = trainer(path.join(working_dir, param_prefix), gconf.cv_scp, false)
nerv.info("initial cross validation: %.3f", gconf.accu_best)
+-- main loop
for i = gconf.cur_iter, gconf.max_iter do
local stop = false
gconf.cur_iter = i
@@ -264,7 +267,7 @@ for i = gconf.cur_iter, gconf.max_iter do
nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
local accu_tr = trainer(nil, gconf.tr_scp, true, rebind_param_repo)
nerv.info("[TR] training set %d: %.3f", i, accu_tr)
- local param_prefix = string.format("%s_%s_iter_%d_lr%f_tr%.3f",
+ param_prefix = string.format("%s_%s_iter_%d_lr%f_tr%.3f",
string.gsub(
(string.gsub(pf0[1], "(.*/)(.*)", "%2")),
"(.*)%..*", "%1"),