aboutsummaryrefslogtreecommitdiff
path: root/examples/asr_trainer.lua
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2015-06-20 20:00:25 +0800
committerDeterminant <[email protected]>2015-06-20 20:00:25 +0800
commitf3f4e74eb4dbb8829e5ee136ba4b0c0a7938b551 (patch)
tree8beb12182020267ce32904d646ad0c736c27dcd2 /examples/asr_trainer.lua
parent2ab9610a4fff798c1668cdc041515256fa813865 (diff)
change concept of ParamRepo; provide generalized param update; code clean-up; #25 #26 #27 #29
Diffstat (limited to 'examples/asr_trainer.lua')
-rw-r--r--examples/asr_trainer.lua42
1 files changed, 25 insertions, 17 deletions
diff --git a/examples/asr_trainer.lua b/examples/asr_trainer.lua
index 05d770f..a5727be 100644
--- a/examples/asr_trainer.lua
+++ b/examples/asr_trainer.lua
@@ -1,50 +1,58 @@
function build_trainer(ifname)
- local param_repo = make_param_repo(ifname)
+ local param_repo = nerv.ParamRepo()
+ param_repo:import(ifname, nil, gconf)
local sublayer_repo = make_sublayer_repo(param_repo)
local layer_repo = make_layer_repo(sublayer_repo, param_repo)
local crit = get_criterion_layer(sublayer_repo)
local network = get_network(layer_repo)
+ local input_order = get_input_order()
local iterative_trainer = function (prefix, scp_file, bp)
gconf.randomize = bp
-- build buffer
- local buffer = make_buffer(make_reader(scp_file, layer_repo))
+ local buffer = make_buffer(make_readers(scp_file, layer_repo))
-- initialize the network
network:init(gconf.batch_size)
gconf.cnt = 0
+ err_input = {nerv.CuMatrixFloat(256, 1)}
+ err_input[1]:fill(1)
for data in buffer.get_data, buffer do
-- prine stat periodically
gconf.cnt = gconf.cnt + 1
if gconf.cnt == 1000 then
- print_stat(crit)
+ print_stat(sublayer_repo)
+ nerv.CuMatrix.print_profile()
+ nerv.CuMatrix.clear_profile()
gconf.cnt = 0
+ -- break
end
+ local input = {}
-- if gconf.cnt == 100 then break end
-
- input = {data.main_scp, data.phone_state}
- output = {}
- err_input = {}
+ for i, id in ipairs(input_order) do
+ if data[id] == nil then
+ nerv.error("input data %s not found", id)
+ end
+ table.insert(input, data[id])
+ end
+ local output = {nerv.CuMatrixFloat(256, 1)}
err_output = {input[1]:create()}
network:propagate(input, output)
if bp then
- network:back_propagate(err_output, err_input, input, output)
+ network:back_propagate(err_input, err_output, input, output)
network:update(err_input, input, output)
end
-- collect garbage in-time to save GPU memory
collectgarbage("collect")
end
- print_stat(crit)
+ print_stat(sublayer_repo)
nerv.CuMatrix.print_profile()
+ nerv.CuMatrix.clear_profile()
if (not bp) and prefix ~= nil then
nerv.info("writing back...")
local fname = string.format("%s_cv%.3f.nerv",
- prefix, get_accuracy(crit))
- cf = nerv.ChunkFile(fname, "w")
- for i, p in ipairs(network:get_params()) do
- cf:write_chunk(p)
- end
- cf:close()
+ prefix, get_accuracy(sublayer_repo))
+ network:get_params():export(fname, nil)
end
- return get_accuracy(crit)
+ return get_accuracy(sublayer_repo)
end
return iterative_trainer
end
@@ -73,7 +81,7 @@ for i = 1, max_iter do
local accu_new = trainer(
string.format("%s_%s_iter_%d_lr%f_tr%.3f",
string.gsub(
- (string.gsub(pf0, "(.*/)(.*)", "%2")),
+ (string.gsub(pf0[1], "(.*/)(.*)", "%2")),
"(.*)%..*", "%1"),
os.date("%Y%m%d%H%M%S"),
i, gconf.lrate,