aboutsummaryrefslogtreecommitdiff
path: root/examples/asr_trainer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'examples/asr_trainer.lua')
-rw-r--r--examples/asr_trainer.lua20
1 files changed, 13 insertions, 7 deletions
diff --git a/examples/asr_trainer.lua b/examples/asr_trainer.lua
index d72b763..2993192 100644
--- a/examples/asr_trainer.lua
+++ b/examples/asr_trainer.lua
@@ -35,8 +35,9 @@ function build_trainer(ifname)
print_stat(crit)
if (not bp) and prefix ~= nil then
nerv.info("writing back...")
- local accu_cv = get_accuracy(crit)
- cf = nerv.ChunkFile(prefix .. "_cv" .. accu_cv .. ".nerv", "w")
+ local fname = string.format("%s_cv%.3f.nerv",
+ prefix, get_accuracy(crit))
+ cf = nerv.ChunkFile(fname, "w")
for i, p in ipairs(network:get_params()) do
cf:write_chunk(p)
end
@@ -53,7 +54,7 @@ halving_factor = 0.6
end_halving_inc = 0.1
min_iter = 1
max_iter = 20
-min_halving = 6
+min_halving = 5
gconf.batch_size = 256
gconf.buffer_size = 81920
@@ -65,11 +66,16 @@ local do_halving = false
nerv.info("initial cross validation: %.3f", accu_best)
for i = 1, max_iter do
- nerv.info("iteration %d with lrate = %.6f", i, gconf.lrate)
+ nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
local accu_tr = trainer(nil, gconf.tr_scp, true)
nerv.info("[TR] training set %d: %.3f", i, accu_tr)
- local accu_new = trainer(pf0 .. "_iter" .. i .. "_tr" .. accu_tr,
- gconf.cv_scp, false)
+ local accu_new = trainer(
+ string.format("%s_%s_iter_%d_lr%f_tr%.3f",
+ string.gsub(pf0, "(.*/)(.*)%..*", "%2"),
+ os.date("%Y%m%d%H%M%S"),
+ i, gconf.lrate,
+ accu_tr),
+ gconf.cv_scp, false)
nerv.info("[CV] cross validation %d: %.3f", i, accu_new)
-- TODO: revert the weights
local accu_diff = accu_new - accu_best
@@ -85,5 +91,5 @@ for i = 1, max_iter do
if accu_new > accu_best then
accu_best = accu_new
end
+ nerv.Matrix.print_profile()
end
-nerv.Matrix.print_profile()