summaryrefslogtreecommitdiff
path: root/examples/asr_trainer.lua
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2015-06-06 15:48:04 +0800
committerDeterminant <[email protected]>2015-06-06 15:48:04 +0800
commit5bcd5d79875587b08d598cc08bd5f8b1f5e14a23 (patch)
tree279cc3546d816d175dcff85c48f67f62468f97ed /examples/asr_trainer.lua
parent3959b99c4853c4deb20324ad0c54906f8ed1348a (diff)
...
Diffstat (limited to 'examples/asr_trainer.lua')
-rw-r--r--examples/asr_trainer.lua14
1 files changed, 8 insertions, 6 deletions
diff --git a/examples/asr_trainer.lua b/examples/asr_trainer.lua
index b43a547..d72b763 100644
--- a/examples/asr_trainer.lua
+++ b/examples/asr_trainer.lua
@@ -4,7 +4,7 @@ function build_trainer(ifname)
local layer_repo = make_layer_repo(sublayer_repo, param_repo)
local crit = get_criterion_layer(sublayer_repo)
local network = get_network(layer_repo)
- local iterative_trainer = function (ofname, scp_file, bp)
+ local iterative_trainer = function (prefix, scp_file, bp)
gconf.randomize = bp
-- build buffer
local buffer = make_buffer(make_reader(scp_file, layer_repo))
@@ -18,7 +18,7 @@ function build_trainer(ifname)
print_stat(crit)
gconf.cnt = 0
end
- if gconf.cnt == 100 then break end
+-- if gconf.cnt == 100 then break end
input = {data.main_scp, data.phone_state}
output = {}
@@ -33,9 +33,10 @@ function build_trainer(ifname)
collectgarbage("collect")
end
print_stat(crit)
- if bp then
+ if (not bp) and prefix ~= nil then
nerv.info("writing back...")
- cf = nerv.ChunkFile(ofname, "w")
+ local accu_cv = get_accuracy(crit)
+ cf = nerv.ChunkFile(prefix .. "_cv" .. accu_cv .. ".nerv", "w")
for i, p in ipairs(network:get_params()) do
cf:write_chunk(p)
end
@@ -65,9 +66,10 @@ local do_halving = false
nerv.info("initial cross validation: %.3f", accu_best)
for i = 1, max_iter do
nerv.info("iteration %d with lrate = %.6f", i, gconf.lrate)
- local accu_tr = trainer(pf0 .. "_iter" .. i .. ".nerv", gconf.tr_scp, true)
+ local accu_tr = trainer(nil, gconf.tr_scp, true)
nerv.info("[TR] training set %d: %.3f", i, accu_tr)
- local accu_new = trainer(nil, gconf.cv_scp, false)
+ local accu_new = trainer(pf0 .. "_iter" .. i .. "_tr" .. accu_tr,
+ gconf.cv_scp, false)
nerv.info("[CV] cross validation %d: %.3f", i, accu_new)
-- TODO: revert the weights
local accu_diff = accu_new - accu_best