aboutsummaryrefslogtreecommitdiff
path: root/examples/asr_trainer.lua
diff options
context:
space:
mode:
authorcloudygoose <[email protected]>2015-06-08 12:50:02 +0800
committercloudygoose <[email protected]>2015-06-08 12:50:02 +0800
commit155b0c0803f5f7cd3f8780273f6b0bdfbaed5970 (patch)
tree967c6326b83cda2b92eee5f597dde0e74b071dbb /examples/asr_trainer.lua
parent31330d6c095b2b11b34f524169f56dc8d18355c3 (diff)
parent0f30b1a4b5e583cb1df7dbb349c1af4378e41369 (diff)
...
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'examples/asr_trainer.lua')
-rw-r--r--examples/asr_trainer.lua29
1 files changed, 20 insertions, 9 deletions
diff --git a/examples/asr_trainer.lua b/examples/asr_trainer.lua
index b43a547..05d770f 100644
--- a/examples/asr_trainer.lua
+++ b/examples/asr_trainer.lua
@@ -4,7 +4,7 @@ function build_trainer(ifname)
local layer_repo = make_layer_repo(sublayer_repo, param_repo)
local crit = get_criterion_layer(sublayer_repo)
local network = get_network(layer_repo)
- local iterative_trainer = function (ofname, scp_file, bp)
+ local iterative_trainer = function (prefix, scp_file, bp)
gconf.randomize = bp
-- build buffer
local buffer = make_buffer(make_reader(scp_file, layer_repo))
@@ -18,7 +18,7 @@ function build_trainer(ifname)
print_stat(crit)
gconf.cnt = 0
end
- if gconf.cnt == 100 then break end
+-- if gconf.cnt == 100 then break end
input = {data.main_scp, data.phone_state}
output = {}
@@ -33,9 +33,12 @@ function build_trainer(ifname)
collectgarbage("collect")
end
print_stat(crit)
- if bp then
+ nerv.CuMatrix.print_profile()
+ if (not bp) and prefix ~= nil then
nerv.info("writing back...")
- cf = nerv.ChunkFile(ofname, "w")
+ local fname = string.format("%s_cv%.3f.nerv",
+ prefix, get_accuracy(crit))
+ cf = nerv.ChunkFile(fname, "w")
for i, p in ipairs(network:get_params()) do
cf:write_chunk(p)
end
@@ -52,7 +55,7 @@ halving_factor = 0.6
end_halving_inc = 0.1
min_iter = 1
max_iter = 20
-min_halving = 6
+min_halving = 5
gconf.batch_size = 256
gconf.buffer_size = 81920
@@ -64,10 +67,18 @@ local do_halving = false
nerv.info("initial cross validation: %.3f", accu_best)
for i = 1, max_iter do
- nerv.info("iteration %d with lrate = %.6f", i, gconf.lrate)
- local accu_tr = trainer(pf0 .. "_iter" .. i .. ".nerv", gconf.tr_scp, true)
+ nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
+ local accu_tr = trainer(nil, gconf.tr_scp, true)
nerv.info("[TR] training set %d: %.3f", i, accu_tr)
- local accu_new = trainer(nil, gconf.cv_scp, false)
+ local accu_new = trainer(
+ string.format("%s_%s_iter_%d_lr%f_tr%.3f",
+ string.gsub(
+ (string.gsub(pf0, "(.*/)(.*)", "%2")),
+ "(.*)%..*", "%1"),
+ os.date("%Y%m%d%H%M%S"),
+ i, gconf.lrate,
+ accu_tr),
+ gconf.cv_scp, false)
nerv.info("[CV] cross validation %d: %.3f", i, accu_new)
-- TODO: revert the weights
local accu_diff = accu_new - accu_best
@@ -83,5 +94,5 @@ for i = 1, max_iter do
if accu_new > accu_best then
accu_best = accu_new
end
+-- nerv.Matrix.print_profile()
end
-nerv.Matrix.print_profile()