From f3f4e74eb4dbb8829e5ee136ba4b0c0a7938b551 Mon Sep 17 00:00:00 2001 From: Determinant Date: Sat, 20 Jun 2015 20:00:25 +0800 Subject: change concept of ParamRepo; provide generalized param update; code clean-up; #25 #26 #27 #29 --- examples/asr_trainer.lua | 42 +++++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 17 deletions(-) (limited to 'examples/asr_trainer.lua') diff --git a/examples/asr_trainer.lua b/examples/asr_trainer.lua index 05d770f..a5727be 100644 --- a/examples/asr_trainer.lua +++ b/examples/asr_trainer.lua @@ -1,50 +1,58 @@ function build_trainer(ifname) - local param_repo = make_param_repo(ifname) + local param_repo = nerv.ParamRepo() + param_repo:import(ifname, nil, gconf) local sublayer_repo = make_sublayer_repo(param_repo) local layer_repo = make_layer_repo(sublayer_repo, param_repo) local crit = get_criterion_layer(sublayer_repo) local network = get_network(layer_repo) + local input_order = get_input_order() local iterative_trainer = function (prefix, scp_file, bp) gconf.randomize = bp -- build buffer - local buffer = make_buffer(make_reader(scp_file, layer_repo)) + local buffer = make_buffer(make_readers(scp_file, layer_repo)) -- initialize the network network:init(gconf.batch_size) gconf.cnt = 0 + err_input = {nerv.CuMatrixFloat(256, 1)} + err_input[1]:fill(1) for data in buffer.get_data, buffer do -- prine stat periodically gconf.cnt = gconf.cnt + 1 if gconf.cnt == 1000 then - print_stat(crit) + print_stat(sublayer_repo) + nerv.CuMatrix.print_profile() + nerv.CuMatrix.clear_profile() gconf.cnt = 0 + -- break end + local input = {} -- if gconf.cnt == 100 then break end - - input = {data.main_scp, data.phone_state} - output = {} - err_input = {} + for i, id in ipairs(input_order) do + if data[id] == nil then + nerv.error("input data %s not found", id) + end + table.insert(input, data[id]) + end + local output = {nerv.CuMatrixFloat(256, 1)} err_output = {input[1]:create()} network:propagate(input, output) if bp then - network:back_propagate(err_output, err_input, input, output) + network:back_propagate(err_input, err_output, input, output) network:update(err_input, input, output) end -- collect garbage in-time to save GPU memory collectgarbage("collect") end - print_stat(crit) + print_stat(sublayer_repo) nerv.CuMatrix.print_profile() + nerv.CuMatrix.clear_profile() if (not bp) and prefix ~= nil then nerv.info("writing back...") local fname = string.format("%s_cv%.3f.nerv", - prefix, get_accuracy(crit)) - cf = nerv.ChunkFile(fname, "w") - for i, p in ipairs(network:get_params()) do - cf:write_chunk(p) - end - cf:close() + prefix, get_accuracy(sublayer_repo)) + network:get_params():export(fname, nil) end - return get_accuracy(crit) + return get_accuracy(sublayer_repo) end return iterative_trainer end @@ -73,7 +81,7 @@ for i = 1, max_iter do local accu_new = trainer( string.format("%s_%s_iter_%d_lr%f_tr%.3f", string.gsub( - (string.gsub(pf0, "(.*/)(.*)", "%2")), + (string.gsub(pf0[1], "(.*/)(.*)", "%2")), "(.*)%..*", "%1"), os.date("%Y%m%d%H%M%S"), i, gconf.lrate, -- cgit v1.2.3-70-g09d2