aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorcloudygoose <[email protected]>2015-06-08 12:50:02 +0800
committercloudygoose <[email protected]>2015-06-08 12:50:02 +0800
commit155b0c0803f5f7cd3f8780273f6b0bdfbaed5970 (patch)
tree967c6326b83cda2b92eee5f597dde0e74b071dbb /examples
parent31330d6c095b2b11b34f524169f56dc8d18355c3 (diff)
parent0f30b1a4b5e583cb1df7dbb349c1af4378e41369 (diff)
...
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'examples')
-rw-r--r--examples/asr_trainer.lua29
-rw-r--r--examples/swb_baseline.lua14
2 files changed, 27 insertions, 16 deletions
diff --git a/examples/asr_trainer.lua b/examples/asr_trainer.lua
index b43a547..05d770f 100644
--- a/examples/asr_trainer.lua
+++ b/examples/asr_trainer.lua
@@ -4,7 +4,7 @@ function build_trainer(ifname)
local layer_repo = make_layer_repo(sublayer_repo, param_repo)
local crit = get_criterion_layer(sublayer_repo)
local network = get_network(layer_repo)
- local iterative_trainer = function (ofname, scp_file, bp)
+ local iterative_trainer = function (prefix, scp_file, bp)
gconf.randomize = bp
-- build buffer
local buffer = make_buffer(make_reader(scp_file, layer_repo))
@@ -18,7 +18,7 @@ function build_trainer(ifname)
print_stat(crit)
gconf.cnt = 0
end
- if gconf.cnt == 100 then break end
+-- if gconf.cnt == 100 then break end
input = {data.main_scp, data.phone_state}
output = {}
@@ -33,9 +33,12 @@ function build_trainer(ifname)
collectgarbage("collect")
end
print_stat(crit)
- if bp then
+ nerv.CuMatrix.print_profile()
+ if (not bp) and prefix ~= nil then
nerv.info("writing back...")
- cf = nerv.ChunkFile(ofname, "w")
+ local fname = string.format("%s_cv%.3f.nerv",
+ prefix, get_accuracy(crit))
+ cf = nerv.ChunkFile(fname, "w")
for i, p in ipairs(network:get_params()) do
cf:write_chunk(p)
end
@@ -52,7 +55,7 @@ halving_factor = 0.6
end_halving_inc = 0.1
min_iter = 1
max_iter = 20
-min_halving = 6
+min_halving = 5
gconf.batch_size = 256
gconf.buffer_size = 81920
@@ -64,10 +67,18 @@ local do_halving = false
nerv.info("initial cross validation: %.3f", accu_best)
for i = 1, max_iter do
- nerv.info("iteration %d with lrate = %.6f", i, gconf.lrate)
- local accu_tr = trainer(pf0 .. "_iter" .. i .. ".nerv", gconf.tr_scp, true)
+ nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
+ local accu_tr = trainer(nil, gconf.tr_scp, true)
nerv.info("[TR] training set %d: %.3f", i, accu_tr)
- local accu_new = trainer(nil, gconf.cv_scp, false)
+ local accu_new = trainer(
+ string.format("%s_%s_iter_%d_lr%f_tr%.3f",
+ string.gsub(
+ (string.gsub(pf0, "(.*/)(.*)", "%2")),
+ "(.*)%..*", "%1"),
+ os.date("%Y%m%d%H%M%S"),
+ i, gconf.lrate,
+ accu_tr),
+ gconf.cv_scp, false)
nerv.info("[CV] cross validation %d: %.3f", i, accu_new)
-- TODO: revert the weights
local accu_diff = accu_new - accu_best
@@ -83,5 +94,5 @@ for i = 1, max_iter do
if accu_new > accu_best then
accu_best = accu_new
end
+-- nerv.Matrix.print_profile()
end
-nerv.Matrix.print_profile()
diff --git a/examples/swb_baseline.lua b/examples/swb_baseline.lua
index f536777..28cc6d5 100644
--- a/examples/swb_baseline.lua
+++ b/examples/swb_baseline.lua
@@ -6,8 +6,8 @@ gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
tr_scp = "/slfs1/users/mfy43/swb_ivec/train_bp.scp",
cv_scp = "/slfs1/users/mfy43/swb_ivec/train_cv.scp",
htk_conf = "/slfs1/users/mfy43/swb_ivec/plp_0_d_a.conf",
- global_transf = "global_transf.nerv",
- initialized_param = "converted.nerv",
+ global_transf = "/slfs1/users/mfy43/swb_global_transf.nerv",
+ initialized_param = "/slfs1/users/mfy43/swb_init.nerv",
debug = false}
function make_param_repo(param_file)
@@ -154,10 +154,10 @@ end
function print_stat(crit)
nerv.info("*** training stat begin ***")
- nerv.utils.printf("cross entropy:\t%.8f\n", crit.total_ce)
- nerv.utils.printf("correct:\t%d\n", crit.total_correct)
- nerv.utils.printf("frames:\t%d\n", crit.total_frames)
- nerv.utils.printf("err/frm:\t%.8f\n", crit.total_ce / crit.total_frames)
- nerv.utils.printf("accuracy:\t%.3f%%\n", get_accuracy(crit))
+ nerv.utils.printf("cross entropy:\t\t%.8f\n", crit.total_ce)
+ nerv.utils.printf("correct:\t\t%d\n", crit.total_correct)
+ nerv.utils.printf("frames:\t\t\t%d\n", crit.total_frames)
+ nerv.utils.printf("err/frm:\t\t%.8f\n", crit.total_ce / crit.total_frames)
+ nerv.utils.printf("accuracy:\t\t%.3f%%\n", get_accuracy(crit))
nerv.info("*** training stat end ***")
end