diff options
author | cloudygoose <[email protected]> | 2015-06-08 12:50:02 +0800 |
---|---|---|
committer | cloudygoose <[email protected]> | 2015-06-08 12:50:02 +0800 |
commit | 155b0c0803f5f7cd3f8780273f6b0bdfbaed5970 (patch) | |
tree | 967c6326b83cda2b92eee5f597dde0e74b071dbb /examples | |
parent | 31330d6c095b2b11b34f524169f56dc8d18355c3 (diff) | |
parent | 0f30b1a4b5e583cb1df7dbb349c1af4378e41369 (diff) |
...
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'examples')
-rw-r--r-- | examples/asr_trainer.lua | 29 | ||||
-rw-r--r-- | examples/swb_baseline.lua | 14 |
2 files changed, 27 insertions, 16 deletions
diff --git a/examples/asr_trainer.lua b/examples/asr_trainer.lua index b43a547..05d770f 100644 --- a/examples/asr_trainer.lua +++ b/examples/asr_trainer.lua @@ -4,7 +4,7 @@ function build_trainer(ifname) local layer_repo = make_layer_repo(sublayer_repo, param_repo) local crit = get_criterion_layer(sublayer_repo) local network = get_network(layer_repo) - local iterative_trainer = function (ofname, scp_file, bp) + local iterative_trainer = function (prefix, scp_file, bp) gconf.randomize = bp -- build buffer local buffer = make_buffer(make_reader(scp_file, layer_repo)) @@ -18,7 +18,7 @@ function build_trainer(ifname) print_stat(crit) gconf.cnt = 0 end - if gconf.cnt == 100 then break end +-- if gconf.cnt == 100 then break end input = {data.main_scp, data.phone_state} output = {} @@ -33,9 +33,12 @@ function build_trainer(ifname) collectgarbage("collect") end print_stat(crit) - if bp then + nerv.CuMatrix.print_profile() + if (not bp) and prefix ~= nil then nerv.info("writing back...") - cf = nerv.ChunkFile(ofname, "w") + local fname = string.format("%s_cv%.3f.nerv", + prefix, get_accuracy(crit)) + cf = nerv.ChunkFile(fname, "w") for i, p in ipairs(network:get_params()) do cf:write_chunk(p) end @@ -52,7 +55,7 @@ halving_factor = 0.6 end_halving_inc = 0.1 min_iter = 1 max_iter = 20 -min_halving = 6 +min_halving = 5 gconf.batch_size = 256 gconf.buffer_size = 81920 @@ -64,10 +67,18 @@ local do_halving = false nerv.info("initial cross validation: %.3f", accu_best) for i = 1, max_iter do - nerv.info("iteration %d with lrate = %.6f", i, gconf.lrate) - local accu_tr = trainer(pf0 .. "_iter" .. i .. ".nerv", gconf.tr_scp, true) + nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate) + local accu_tr = trainer(nil, gconf.tr_scp, true) nerv.info("[TR] training set %d: %.3f", i, accu_tr) - local accu_new = trainer(nil, gconf.cv_scp, false) + local accu_new = trainer( + string.format("%s_%s_iter_%d_lr%f_tr%.3f", + string.gsub( + (string.gsub(pf0, "(.*/)(.*)", "%2")), + "(.*)%..*", "%1"), + os.date("%Y%m%d%H%M%S"), + i, gconf.lrate, + accu_tr), + gconf.cv_scp, false) nerv.info("[CV] cross validation %d: %.3f", i, accu_new) -- TODO: revert the weights local accu_diff = accu_new - accu_best @@ -83,5 +94,5 @@ for i = 1, max_iter do if accu_new > accu_best then accu_best = accu_new end +-- nerv.Matrix.print_profile() end -nerv.Matrix.print_profile() diff --git a/examples/swb_baseline.lua b/examples/swb_baseline.lua index f536777..28cc6d5 100644 --- a/examples/swb_baseline.lua +++ b/examples/swb_baseline.lua @@ -6,8 +6,8 @@ gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9, tr_scp = "/slfs1/users/mfy43/swb_ivec/train_bp.scp", cv_scp = "/slfs1/users/mfy43/swb_ivec/train_cv.scp", htk_conf = "/slfs1/users/mfy43/swb_ivec/plp_0_d_a.conf", - global_transf = "global_transf.nerv", - initialized_param = "converted.nerv", + global_transf = "/slfs1/users/mfy43/swb_global_transf.nerv", + initialized_param = "/slfs1/users/mfy43/swb_init.nerv", debug = false} function make_param_repo(param_file) @@ -154,10 +154,10 @@ end function print_stat(crit) nerv.info("*** training stat begin ***") - nerv.utils.printf("cross entropy:\t%.8f\n", crit.total_ce) - nerv.utils.printf("correct:\t%d\n", crit.total_correct) - nerv.utils.printf("frames:\t%d\n", crit.total_frames) - nerv.utils.printf("err/frm:\t%.8f\n", crit.total_ce / crit.total_frames) - nerv.utils.printf("accuracy:\t%.3f%%\n", get_accuracy(crit)) + nerv.utils.printf("cross entropy:\t\t%.8f\n", crit.total_ce) + nerv.utils.printf("correct:\t\t%d\n", crit.total_correct) + nerv.utils.printf("frames:\t\t\t%d\n", crit.total_frames) + nerv.utils.printf("err/frm:\t\t%.8f\n", crit.total_ce / crit.total_frames) + nerv.utils.printf("accuracy:\t\t%.3f%%\n", get_accuracy(crit)) nerv.info("*** training stat end ***") end |