diff options
Diffstat (limited to 'nerv/examples/seq_trainer.lua')
-rw-r--r-- | nerv/examples/seq_trainer.lua | 86 |
1 files changed, 86 insertions, 0 deletions
diff --git a/nerv/examples/seq_trainer.lua b/nerv/examples/seq_trainer.lua new file mode 100644 index 0000000..df96e68 --- /dev/null +++ b/nerv/examples/seq_trainer.lua @@ -0,0 +1,86 @@ +function build_trainer(ifname) + local param_repo = nerv.ParamRepo() + param_repo:import(ifname, nil, gconf) + local layer_repo = make_layer_repo(param_repo) + local network = get_network(layer_repo) + local global_transf = get_global_transf(layer_repo) + local input_order = get_input_order() + local iterative_trainer = function (prefix, scp_file, bp) + local readers = make_readers(scp_file, layer_repo) + -- initialize the network + network:init(1) + gconf.cnt = 0 + for ri = 1, #readers, 1 do + while true do + local data = readers[ri].reader:get_data() + if data == nil then + break + end + -- prine stat periodically + gconf.cnt = gconf.cnt + 1 + if gconf.cnt == 1000 then + print_stat(layer_repo) + nerv.CuMatrix.print_profile() + nerv.CuMatrix.clear_profile() + gconf.cnt = 0 + -- break + end + local input = {} + -- if gconf.cnt == 1000 then break end + for i, e in ipairs(input_order) do + local id = e.id + if data[id] == nil then + nerv.error("input data %s not found", id) + end + local transformed + if e.global_transf then + local batch = gconf.cumat_type(data[id]:nrow(), data[id]:ncol()) + batch:copy_fromh(data[id]) + transformed = nerv.speech_utils.global_transf(batch, + global_transf, + gconf.frm_ext or 0, 0, + gconf) + else + transformed = data[id] + end + table.insert(input, transformed) + end + err_output = {input[1]:create()} + network:batch_resize(input[1]:nrow()) + if network:propagate(input, {{}}) == true then + network:back_propagate({{}}, err_output, input, {{}}) + network:update({{}}, input, {{}}) + end + -- collect garbage in-time to save GPU memory + collectgarbage("collect") + end + end + print_stat(layer_repo) + nerv.CuMatrix.print_profile() + nerv.CuMatrix.clear_profile() + if prefix ~= nil then + nerv.info("writing back...") + local fname = string.format("%s_tr%.3f.nerv", + prefix, get_accuracy(layer_repo)) + network:get_params():export(fname, nil) + end + return get_accuracy(layer_repo) + end + return iterative_trainer +end + +dofile(arg[1]) + +local pf0 = gconf.initialized_param +local trainer = build_trainer(pf0) + +local i = 1 +nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate) +local accu_tr = trainer(string.format("%s_%s_iter_%d_lr%f", +string.gsub( +(string.gsub(pf0[1], "(.*/)(.*)", "%2")), +"(.*)%..*", "%1"), +os.date("%Y%m%d%H%M%S"), +i, gconf.lrate), gconf.tr_scp, true) +nerv.info("[TR] training set %d: %.3f", i, accu_tr) + |