about summary refs log tree commit diff
path: root/nerv/examples/asr_trainer.lua
diff options
context:
space:
mode:
authorDeterminant <ted.sybil@gmail.com>2015-08-06 14:08:26 +0800
committerDeterminant <ted.sybil@gmail.com>2015-08-06 14:08:26 +0800
commitb4d9cfa8e3a4735687311577dded97d889340134 (patch)
tree49e0f000719705c563a357e0b89d62a66c84ce75 /nerv/examples/asr_trainer.lua
parent2dc87bc02a1242dd5e029d0baaf4e0ae7173184f (diff)
make network configuration example file clearer
Diffstat (limited to 'nerv/examples/asr_trainer.lua')
-rw-r--r--  nerv/examples/asr_trainer.lua  11
1 file changed, 5 insertions, 6 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 8dfb2ac..dcadfa3 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -1,8 +1,7 @@
function build_trainer(ifname)
local param_repo = nerv.ParamRepo()
param_repo:import(ifname, nil, gconf)
- local sublayer_repo = make_sublayer_repo(param_repo)
- local layer_repo = make_layer_repo(sublayer_repo, param_repo)
+ local layer_repo = make_layer_repo(param_repo)
local network = get_network(layer_repo)
local input_order = get_input_order()
local iterative_trainer = function (prefix, scp_file, bp)
@@ -18,7 +17,7 @@ function build_trainer(ifname)
-- print stat periodically
gconf.cnt = gconf.cnt + 1
if gconf.cnt == 1000 then
- print_stat(sublayer_repo)
+ print_stat(layer_repo)
nerv.CuMatrix.print_profile()
nerv.CuMatrix.clear_profile()
gconf.cnt = 0
@@ -42,16 +41,16 @@ function build_trainer(ifname)
-- collect garbage in-time to save GPU memory
collectgarbage("collect")
end
- print_stat(sublayer_repo)
+ print_stat(layer_repo)
nerv.CuMatrix.print_profile()
nerv.CuMatrix.clear_profile()
if (not bp) and prefix ~= nil then
nerv.info("writing back...")
local fname = string.format("%s_cv%.3f.nerv",
- prefix, get_accuracy(sublayer_repo))
+ prefix, get_accuracy(layer_repo))
network:get_params():export(fname, nil)
end
- return get_accuracy(sublayer_repo)
+ return get_accuracy(layer_repo)
end
return iterative_trainer
end