require 'lfs'
require 'pl'

--- Build a closure that runs one pass over a data set with the network.
--
-- Parameters are imported from `ifname` into a host-side repo and copied to
-- the training location (host when `gconf.use_cpu` is set, device otherwise).
-- The helpers `make_layer_repo`, `get_network`, `get_global_transf`,
-- `get_input_order`, `make_buffer`, `make_readers`, `print_stat` and
-- `get_accuracy` are globals that are not defined in this file (presumably
-- supplied by the network config -- confirm against the config in use).
--
-- @param ifname parameter file specification handed to `ParamRepo:import`
-- @return function `iterative_trainer(prefix, scp_file, bp, rebind_param_repo)`
local function build_trainer(ifname)
    local host_param_repo = nerv.ParamRepo()
    host_param_repo:import(ifname, nil, gconf)

    -- parameters are always kept on the host when written back; only the
    -- training copy moves depending on the CPU/GPU setting
    local src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
    local mat_type
    local train_loc_type
    if gconf.use_cpu then
        mat_type = gconf.mmat_type
        train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
    else
        mat_type = gconf.cumat_type
        train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
    end

    local param_repo = host_param_repo:copy(train_loc_type)
    local layer_repo = make_layer_repo(param_repo)
    local network = get_network(layer_repo)
    local global_transf = get_global_transf(layer_repo)
    local input_order = get_input_order()

    network = nerv.Network("nt", gconf, {network = network})
    network:init(gconf.batch_size, 1)
    global_transf = nerv.Network("gt", gconf, {network = global_transf})
    global_transf:init(gconf.batch_size, 1)

    --- Run one pass over `scp_file`.
    -- @param prefix if non-nil and `bp` is false, the parameters are exported
    --        to "<prefix>_cv<accuracy>.nerv"
    -- @param scp_file data list handed to `make_readers`
    -- @param bp true to back-propagate and update (training), false to only
    --        propagate (cross validation); also stored in `gconf.randomize`
    -- @param rebind_param_repo optional host param repo to roll back to
    --        before running the pass
    -- @return the value of `get_accuracy(layer_repo)`, the current host param
    --         repo, and the exported file name (nil when nothing was written)
    local iterative_trainer = function (prefix, scp_file, bp, rebind_param_repo)
        -- rebind the params if necessary
        if rebind_param_repo then
            host_param_repo = rebind_param_repo
            param_repo = host_param_repo:copy(train_loc_type)
            layer_repo:rebind(param_repo)
            rebind_param_repo = nil
        end
        gconf.randomize = bp
        -- build buffer
        local buffer = make_buffer(make_readers(scp_file, layer_repo))
        -- initialize the network
        gconf.cnt = 0
        -- kept local: these used to leak into the global environment
        local err_input = {mat_type(gconf.batch_size, 1)}
        err_input[1]:fill(1)
        network:epoch_init()
        global_transf:epoch_init()
        for data in buffer.get_data, buffer do
            -- print statistics periodically
            gconf.cnt = gconf.cnt + 1
            if gconf.cnt == 1000 then
                print_stat(layer_repo)
                mat_type.print_profile()
                mat_type.clear_profile()
                gconf.cnt = 0
                -- break
            end
            local input = {}
            -- if gconf.cnt == 1000 then break end
            for _, e in ipairs(input_order) do
                local id = e.id
                if data[id] == nil then
                    nerv.error("input data %s not found", id)
                end
                local transformed
                if e.global_transf then
                    transformed = nerv.speech_utils.global_transf(
                        data[id], global_transf,
                        gconf.frm_ext or 0, 0, gconf)
                else
                    transformed = data[id]
                end
                table.insert(input, transformed)
            end
            local output = {mat_type(gconf.batch_size, 1)}
            -- one error buffer per input, created from the input itself
            local err_output = {}
            for i = 1, #input do
                table.insert(err_output, input[i]:create())
            end
            network:mini_batch_init({
                seq_length = table.vector(gconf.batch_size, 1),
                new_seq = {},
                do_train = bp,
                input = {input},
                output = {output},
                err_input = {err_input},
                err_output = {err_output},
            })
            network:propagate()
            if bp then
                network:back_propagate()
                network:update()
            end
            -- collect garbage in-time to save GPU memory
            collectgarbage("collect")
        end
        print_stat(layer_repo)
        mat_type.print_profile()
        mat_type.clear_profile()
        local fname
        if not bp then
            -- refresh the host copy so the caller can keep it for rollback
            host_param_repo = param_repo:copy(src_loc_type)
            if prefix ~= nil then
                nerv.info("writing back...")
                fname = string.format("%s_cv%.3f.nerv",
                                      prefix, get_accuracy(layer_repo))
                host_param_repo:export(fname, nil)
            end
        end
        return get_accuracy(layer_repo), host_param_repo, fname
    end
    return iterative_trainer
end

--- Fill `gconf` according to the precedence
-- command-line option > value already in `gconf` > entry in `spec`.
-- When `--resume-from` is given, `gconf` is instead replaced wholesale by the
-- table returned from that file and no defaults are applied.
-- @param spec table mapping config keys (with underscores) to default values
-- @param opts parsed options, indexed by the dashed option name, each entry
--        carrying its value in `.val`
local function check_and_add_defaults(spec, opts)
    local function get_opt_val(k)
        -- the extra parentheses drop gsub's second result (the match count)
        return opts[(string.gsub(k, '_', '-'))].val
    end
    local resume_file = get_opt_val("resume_from")
    if resume_file then
        gconf = dofile(resume_file)
        return
    end
    for k, v in pairs(spec) do
        local opt_v = get_opt_val(k)
        if opt_v ~= nil then
            gconf[k] = opt_v
        elseif gconf[k] == nil then
            gconf[k] = v
        end
    end
end

--- Turn a table of defaults into an option list.
-- Each entry is `{dashed-name, nil, type-name, default = value}`.
-- @param spec table mapping config keys to default values
-- @return array of option descriptions
local function make_options(spec)
    local options = {}
    for k, v in pairs(spec) do
        -- gsub is not the last item of the constructor, so only the
        -- substituted string is stored
        table.insert(options,
                     {string.gsub(k, '_', '-'), nil, type(v), default = v})
    end
    return options
end

--- Print the usage line followed by the description of every option.
-- @param options option list passed on to `nerv.print_usage`
local function print_help(options)
    nerv.printf("Usage: <asr_trainer.lua> [options] network_config.lua\n")
    nerv.print_usage(options)
end

--- Print every `gconf` entry as an aligned "key = value" table.
local function print_gconf()
    local key_maxlen = 0
    for k in pairs(gconf) do
        key_maxlen = math.max(key_maxlen, #tostring(k))
    end
    -- left-align keys in a column as wide as the longest key
    local pattern = string.format("%%-%ds = %%s\n", key_maxlen)
    nerv.info("ready to train with the following gconf settings:")
    nerv.printf(pattern, "Key", "Value")
    for k, v in pairs(gconf) do
        -- tostring keeps `false` settings visible instead of printing ""
        nerv.printf(pattern, tostring(k), tostring(v))
    end
end
--- Serialize `gconf` to `fname` as a Lua chunk ("return <table>"), so that
-- it can later be read back with `dofile`.
-- @param fname path of the file to write
local function dump_gconf(fname)
    -- fail with the OS error message instead of indexing a nil handle
    local f = assert(io.open(fname, "w"))
    f:write("return ")
    f:write(table.tostring(gconf))
    f:close()
end

-- Defaults used for every setting that is given neither on the command line
-- nor in the network config.
local trainer_defaults = {
    lrate = 0.8,
    batch_size = 256,
    buffer_size = 81920,
    wcost = 1e-6,
    momentum = 0.9,
    start_halving_inc = 0.5,
    halving_factor = 0.6,
    end_halving_inc = 0.1,
    cur_iter = 1,
    min_iter = 1,
    max_iter = 20,
    min_halving = 5,
    do_halving = false,
    cumat_tname = "nerv.CuMatrixFloat",
    mmat_tname = "nerv.MMatrixFloat",
    debug = false,
}

local options = make_options(trainer_defaults)
local extra_opt_spec = {
    {"tr-scp", nil, "string"},
    {"cv-scp", nil, "string"},
    {"resume-from", nil, "string"},
    {"help", "h", "boolean", default = false,
     desc = "show this help information"},
    {"dir", nil, "string", desc = "specify the working directory"},
}
table.extend(options, extra_opt_spec)

arg, opts = nerv.parse_args(arg, options)

if #arg < 1 or opts["help"].val then
    print_help(options)
    return
end

-- load the network config given as the first positional argument
dofile(arg[1])

--[[
Rule: command-line option overrides network config overrides trainer default.
Note: a config key like aaa_bbbb_cc can be overridden by passing
      --aaa-bbbb-cc on the command line.
]]--
check_and_add_defaults(trainer_defaults, opts)
gconf.mmat_type = nerv.get_type(gconf.mmat_tname)
gconf.cumat_type = nerv.get_type(gconf.cumat_tname)
-- NOTE(review): `econf` is not defined anywhere in this file; it must be a
-- global provided elsewhere (e.g. by the config) -- confirm.
gconf.use_cpu = econf.use_cpu or false

local pf0 = gconf.initialized_param
local date_pattern = "%Y%m%d%H%M%S"
local logfile_name = "log"
local working_dir = opts["dir"].val or
                    string.format("nerv_%s", os.date(date_pattern))
local rebind_param_repo = nil

print_gconf()
if not lfs.mkdir(working_dir) then
    nerv.error("[asr_trainer] working directory already exists")
end
-- copy the network config
dir.copyfile(arg[1], working_dir)
-- set logfile path
nerv.set_logfile(path.join(working_dir, logfile_name))
path.chdir(working_dir)

-- start the training
local trainer = build_trainer(pf0)
local pr_prev
gconf.accu_best, pr_prev = trainer(nil, gconf.cv_scp, false)
nerv.info("initial cross validation: %.3f", gconf.accu_best)
for i = gconf.cur_iter, gconf.max_iter do
    local stop = false
    gconf.cur_iter = i
    dump_gconf(string.format("iter_%d.meta", i))
    repeat -- trick to implement `continue` statement
        nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
        local accu_tr = trainer(nil, gconf.tr_scp, true, rebind_param_repo)
        -- The rollback request has been consumed. Without clearing it here,
        -- every later iteration would re-bind the same stale parameters and
        -- throw away training that had already been accepted.
        rebind_param_repo = nil
        nerv.info("[TR] training set %d: %.3f", i, accu_tr)
        -- base name of the initial parameter file: strip the directory part,
        -- then everything from the last dot on
        local base_name = string.gsub(
            (string.gsub(pf0[1], "(.*/)(.*)", "%2")), "(.*)%..*", "%1")
        local param_prefix = string.format("%s_%s_iter_%d_lr%f_tr%.3f",
                                           base_name, os.date(date_pattern),
                                           i, gconf.lrate, accu_tr)
        local accu_new, pr_new, param_fname =
            trainer(param_prefix, gconf.cv_scp, false)
        nerv.info("[CV] cross validation %d: %.3f", i, accu_new)
        local accu_prev = gconf.accu_best
        if accu_new < gconf.accu_best then
            nerv.info("rejecting the trained params, rollback to the previous one")
            file.move(param_fname, param_fname .. ".rejected")
            rebind_param_repo = pr_prev
            break -- `continue` equivalent
        else
            nerv.info("accepting the trained params")
            gconf.accu_best = accu_new
            pr_prev = pr_new
            gconf.initialized_param = {path.join(path.currentdir(), param_fname)}
        end
        -- stop once halving has started and the gain falls below the
        -- threshold
        if gconf.do_halving and
            gconf.accu_best - accu_prev < gconf.end_halving_inc and
            i > gconf.min_iter then
            stop = true
            break
        end
        -- start halving the learning rate when the gain becomes small
        if gconf.accu_best - accu_prev < gconf.start_halving_inc and
            i >= gconf.min_halving then
            gconf.do_halving = true
        end
        if gconf.do_halving then
            gconf.lrate = gconf.lrate * gconf.halving_factor
        end
    until true
    if stop then break end
    -- nerv.Matrix.print_profile()
end