From 89d57b6fae6bcb0195a73fb97ab6870ee0d0ce20 Mon Sep 17 00:00:00 2001 From: Determinant Date: Wed, 30 Mar 2016 13:54:14 +0800 Subject: fix bug in passing err_input to network; gen zero vectors for bias --- nerv/examples/asr_trainer.lua | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) (limited to 'nerv/examples/asr_trainer.lua') diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua index 38ba6e9..9a764fc 100644 --- a/nerv/examples/asr_trainer.lua +++ b/nerv/examples/asr_trainer.lua @@ -39,12 +39,8 @@ local function build_trainer(ifname) local buffer = make_buffer(make_readers(scp_file, layer_repo)) -- initialize the network gconf.cnt = 0 - local err_input = {{}} local output = {{}} for i = 1, gconf.chunk_size do - local mini_batch = mat_type(gconf.batch_size, 1) - mini_batch:fill(1) - table.insert(err_input[1], mini_batch) table.insert(output[1], mat_type(gconf.batch_size, 1)) end network:epoch_init() @@ -91,7 +87,7 @@ local function build_trainer(ifname) do_train = bp, input = input, output = output, - err_input = err_input, + err_input = {gconf.mask}, err_output = err_output}) network:propagate() if bp then @@ -254,8 +250,15 @@ nerv.set_logfile(path.join(working_dir, logfile_name)) -- start the training local trainer = build_trainer(pf0) local pr_prev -gconf.accu_best, pr_prev = trainer(nil, gconf.cv_scp, false) +-- initial cross-validation +local param_prefix = string.format("%s_%s", + string.gsub( + (string.gsub(pf0[1], "(.*/)(.*)", "%2")), + "(.*)%..*", "%1"), + os.date(date_pattern)) +gconf.accu_best, pr_prev = trainer(path.join(working_dir, param_prefix), gconf.cv_scp, false) nerv.info("initial cross validation: %.3f", gconf.accu_best) +-- main loop for i = gconf.cur_iter, gconf.max_iter do local stop = false gconf.cur_iter = i @@ -264,7 +267,7 @@ for i = gconf.cur_iter, gconf.max_iter do nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate) local accu_tr = trainer(nil, gconf.tr_scp, true, rebind_param_repo) nerv.info("[TR] training set %d: %.3f", i, accu_tr) - local param_prefix = string.format("%s_%s_iter_%d_lr%f_tr%.3f", + param_prefix = string.format("%s_%s_iter_%d_lr%f_tr%.3f", string.gsub( (string.gsub(pf0[1], "(.*/)(.*)", "%2")), "(.*)%..*", "%1"), -- cgit v1.2.3