From 5bcd5d79875587b08d598cc08bd5f8b1f5e14a23 Mon Sep 17 00:00:00 2001 From: Determinant Date: Sat, 6 Jun 2015 15:48:04 +0800 Subject: ... --- examples/asr_trainer.lua | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'examples/asr_trainer.lua') diff --git a/examples/asr_trainer.lua b/examples/asr_trainer.lua index b43a547..d72b763 100644 --- a/examples/asr_trainer.lua +++ b/examples/asr_trainer.lua @@ -4,7 +4,7 @@ function build_trainer(ifname) local layer_repo = make_layer_repo(sublayer_repo, param_repo) local crit = get_criterion_layer(sublayer_repo) local network = get_network(layer_repo) - local iterative_trainer = function (ofname, scp_file, bp) + local iterative_trainer = function (prefix, scp_file, bp) gconf.randomize = bp -- build buffer local buffer = make_buffer(make_reader(scp_file, layer_repo)) @@ -18,7 +18,7 @@ function build_trainer(ifname) print_stat(crit) gconf.cnt = 0 end - if gconf.cnt == 100 then break end +-- if gconf.cnt == 100 then break end input = {data.main_scp, data.phone_state} output = {} @@ -33,9 +33,10 @@ function build_trainer(ifname) collectgarbage("collect") end print_stat(crit) - if bp then + if (not bp) and prefix ~= nil then nerv.info("writing back...") - cf = nerv.ChunkFile(ofname, "w") + local accu_cv = get_accuracy(crit) + cf = nerv.ChunkFile(prefix .. "_cv" .. accu_cv .. ".nerv", "w") for i, p in ipairs(network:get_params()) do cf:write_chunk(p) end @@ -65,9 +66,10 @@ local do_halving = false nerv.info("initial cross validation: %.3f", accu_best) for i = 1, max_iter do nerv.info("iteration %d with lrate = %.6f", i, gconf.lrate) - local accu_tr = trainer(pf0 .. "_iter" .. i .. ".nerv", gconf.tr_scp, true) + local accu_tr = trainer(nil, gconf.tr_scp, true) nerv.info("[TR] training set %d: %.3f", i, accu_tr) - local accu_new = trainer(nil, gconf.cv_scp, false) + local accu_new = trainer(pf0 .. "_iter" .. i .. "_tr" .. accu_tr, + gconf.cv_scp, false) nerv.info("[CV] cross validation %d: %.3f", i, accu_new) -- TODO: revert the weights local accu_diff = accu_new - accu_best -- cgit v1.2.3 From 6e720b961f7edac9c3a41affe0ca40dd0ec9fc85 Mon Sep 17 00:00:00 2001 From: Determinant Date: Sun, 7 Jun 2015 11:55:09 +0800 Subject: fix memory leak in profiling; other minor changes --- examples/asr_trainer.lua | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) (limited to 'examples/asr_trainer.lua') diff --git a/examples/asr_trainer.lua b/examples/asr_trainer.lua index d72b763..2993192 100644 --- a/examples/asr_trainer.lua +++ b/examples/asr_trainer.lua @@ -35,8 +35,9 @@ function build_trainer(ifname) print_stat(crit) if (not bp) and prefix ~= nil then nerv.info("writing back...") - local accu_cv = get_accuracy(crit) - cf = nerv.ChunkFile(prefix .. "_cv" .. accu_cv .. ".nerv", "w") + local fname = string.format("%s_cv%.3f.nerv", + prefix, get_accuracy(crit)) + cf = nerv.ChunkFile(fname, "w") for i, p in ipairs(network:get_params()) do cf:write_chunk(p) end @@ -53,7 +54,7 @@ halving_factor = 0.6 end_halving_inc = 0.1 min_iter = 1 max_iter = 20 -min_halving = 6 +min_halving = 5 gconf.batch_size = 256 gconf.buffer_size = 81920 @@ -65,11 +66,16 @@ local do_halving = false nerv.info("initial cross validation: %.3f", accu_best) for i = 1, max_iter do - nerv.info("iteration %d with lrate = %.6f", i, gconf.lrate) + nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate) local accu_tr = trainer(nil, gconf.tr_scp, true) nerv.info("[TR] training set %d: %.3f", i, accu_tr) - local accu_new = trainer(pf0 .. "_iter" .. i .. "_tr" .. accu_tr, - gconf.cv_scp, false) + local accu_new = trainer( + string.format("%s_%s_iter_%d_lr%f_tr%.3f", + string.gsub(pf0, "(.*/)(.*)%..*", "%2"), + os.date("%Y%m%d%H%M%S"), + i, gconf.lrate, + accu_tr), + gconf.cv_scp, false) nerv.info("[CV] cross validation %d: %.3f", i, accu_new) -- TODO: revert the weights local accu_diff = accu_new - accu_best @@ -85,5 +91,5 @@ for i = 1, max_iter do if accu_new > accu_best then accu_best = accu_new end + nerv.Matrix.print_profile() end -nerv.Matrix.print_profile() -- cgit v1.2.3 From 0f30b1a4b5e583cb1df7dbb349c1af4378e41369 Mon Sep 17 00:00:00 2001 From: Determinant Date: Sun, 7 Jun 2015 21:59:10 +0800 Subject: fix minor bugs in cumatrix; clean up part of code --- examples/asr_trainer.lua | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'examples/asr_trainer.lua') diff --git a/examples/asr_trainer.lua b/examples/asr_trainer.lua index 2993192..05d770f 100644 --- a/examples/asr_trainer.lua +++ b/examples/asr_trainer.lua @@ -33,6 +33,7 @@ function build_trainer(ifname) collectgarbage("collect") end print_stat(crit) + nerv.CuMatrix.print_profile() if (not bp) and prefix ~= nil then nerv.info("writing back...") local fname = string.format("%s_cv%.3f.nerv", @@ -71,7 +72,9 @@ for i = 1, max_iter do nerv.info("[TR] training set %d: %.3f", i, accu_tr) local accu_new = trainer( string.format("%s_%s_iter_%d_lr%f_tr%.3f", - string.gsub(pf0, "(.*/)(.*)%..*", "%2"), + string.gsub( + (string.gsub(pf0, "(.*/)(.*)", "%2")), + "(.*)%..*", "%1"), os.date("%Y%m%d%H%M%S"), i, gconf.lrate, accu_tr), @@ -91,5 +94,5 @@ for i = 1, max_iter do if accu_new > accu_best then accu_best = accu_new end - nerv.Matrix.print_profile() +-- nerv.Matrix.print_profile() end -- cgit v1.2.3