-- Build a training closure around the parameters stored in `ifname`.
-- Relies on globals that the config file loaded via dofile(arg[1]) is
-- expected to define: gconf, make_layer_repo, get_network,
-- get_global_transf, get_input_order, make_readers, print_stat and
-- get_accuracy.
function build_trainer(ifname)
    local param_repo = nerv.ParamRepo()
    param_repo:import(ifname, gconf)
    local layer_repo = make_layer_repo(param_repo)
    local network = get_network(layer_repo)
    local global_transf = get_global_transf(layer_repo)
    local input_order = get_input_order()

    -- Run one pass over every reader built from `scp_file`.
    --   prefix   : when non-nil, the network parameters are exported to
    --              "<prefix>_tr<accuracy>.nerv" after the pass
    --   scp_file : handed to make_readers to build the data readers
    --   bp       : when truthy, back-propagate and update after each batch;
    --              otherwise only the forward pass is run
    -- Returns get_accuracy(layer_repo) after the pass.
    local iterative_trainer = function (prefix, scp_file, bp)
        local readers = make_readers(scp_file, layer_repo)
        -- initialize the network
        network:init(1)
        gconf.cnt = 0
        for ri = 1, #readers do
            while true do
                local data = readers[ri].reader:get_data()
                if data == nil then break end
                -- print stats (and the CuMatrix profile) every 1000 batches
                gconf.cnt = gconf.cnt + 1
                if gconf.cnt == 1000 then
                    print_stat(layer_repo)
                    nerv.CuMatrix.print_profile()
                    nerv.CuMatrix.clear_profile()
                    gconf.cnt = 0
                end
                -- gather the network inputs in the order given by input_order
                local input = {}
                for _, e in ipairs(input_order) do
                    local id = e.id
                    if data[id] == nil then
                        nerv.error("input data %s not found", id)
                    end
                    local transformed
                    if e.global_transf then
                        -- copy the batch into a gconf.cumat_type matrix before
                        -- applying the global transform
                        local batch = gconf.cumat_type(data[id]:nrow(), data[id]:ncol())
                        batch:copy_fromh(data[id])
                        transformed = nerv.speech_utils.global_transf(batch,
                            global_transf, gconf.frm_ext or 0, 0, gconf)
                    else
                        transformed = data[id]
                    end
                    table.insert(input, transformed)
                end
                -- declared local: previously this leaked into the global table
                local err_output = {input[1]:create()}
                network:batch_resize(input[1]:nrow())
                -- the forward pass always runs; backward/update only when bp is set
                local propagated = network:propagate(input, {{}})
                if propagated == true and bp then
                    network:back_propagate({{}}, err_output, input, {{}})
                    gconf.batch_size = 1.0 - gconf.momentum -- important!!!
                    network:update({{}}, input, {{}})
                end
                -- collect garbage in-time to save GPU memory
                collectgarbage("collect")
            end
        end
        print_stat(layer_repo)
        nerv.CuMatrix.print_profile()
        nerv.CuMatrix.clear_profile()
        -- NOTE(review): assumes get_accuracy only reads statistics, so one
        -- call serves both the file name and the return value
        local accuracy = get_accuracy(layer_repo)
        if prefix ~= nil then
            nerv.info("writing back...")
            local fname = string.format("%s_tr%.3f.nerv", prefix, accuracy)
            network:get_params():export(fname, nil)
        end
        return accuracy
    end
    return iterative_trainer
end

-- Driver: load the config given as the first argument, then run a single
-- training iteration on gconf.tr_scp.
dofile(arg[1])
local pf0 = gconf.initialized_param
local trainer = build_trainer(pf0)
local i = 1
nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
-- output prefix: basename of the first initial parameter file with its last
-- extension stripped, followed by a timestamp, the iteration and the lrate
local base_name = (string.gsub(pf0[1], "(.*/)(.*)", "%2"))
base_name = (string.gsub(base_name, "(.*)%..*", "%1"))
local accu_tr = trainer(string.format("%s_%s_iter_%d_lr%f", base_name,
                                      os.date("%Y%m%d%H%M%S"), i, gconf.lrate),
                        gconf.tr_scp, true)
nerv.info("[TR] training set %d: %.3f", i, accu_tr)