require 'lfs'
require 'pl'

-- Build a trainer closure over the initial parameters in `ifname`. The
-- returned closure runs one pass over `scp_file` (training when `bp` is
-- true, cross validation otherwise) and returns the accuracy, the updated
-- parameter repo and the name of the parameter file written back (if any).
local function build_trainer(ifname)
    local host_param_repo = nerv.ParamRepo()
    local mat_type
    local src_loc_type
    local train_loc_type
    host_param_repo:import(ifname, gconf)
    if gconf.use_cpu then
        mat_type = gconf.mmat_type
        src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
        train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
    else
        mat_type = gconf.cumat_type
        src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
        train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
    end
    local param_repo = host_param_repo:copy(train_loc_type, gconf)
    local layer_repo = make_layer_repo(param_repo)
    local network = get_network(layer_repo)
    local global_transf = get_global_transf(layer_repo)
    local input_order = get_input_order()
    network = nerv.Network("nt", gconf, {network = network})
    network:init(gconf.batch_size, gconf.chunk_size)
    global_transf = nerv.Network("gt", gconf, {network = global_transf})
    global_transf:init(gconf.batch_size, gconf.chunk_size)
    local iterative_trainer = function (prefix, scp_file, bp, rebind_param_repo)
        -- rebind the params if necessary
        if rebind_param_repo then
            host_param_repo = rebind_param_repo
            param_repo = host_param_repo:copy(train_loc_type, gconf)
            layer_repo:rebind(param_repo)
            rebind_param_repo = nil
        end
        -- shuffle the data only when training
        gconf.randomize = bp
        -- build buffer
        local buffer = make_buffer(make_readers(scp_file, layer_repo))
        -- initialize the network
        gconf.cnt = 0
        local err_input = {{}}
        local output = {{}}
        for i = 1, gconf.chunk_size do
            local mini_batch = mat_type(gconf.batch_size, 1)
            mini_batch:fill(1)
            table.insert(err_input[1], mini_batch)
            table.insert(output[1], mat_type(gconf.batch_size, 1))
        end
        network:epoch_init()
        global_transf:epoch_init()
        for d in buffer.get_data, buffer do
            -- print stat periodically
            gconf.cnt = gconf.cnt + 1
            if gconf.cnt == 1000 then
                print_stat(layer_repo)
                mat_type.print_profile()
                mat_type.clear_profile()
                gconf.cnt = 0
                -- break
            end
            local input = {}
            local err_output = {}
            -- if gconf.cnt == 1000 then break end
            for i, e in ipairs(input_order) do
                local id = e.id
                if d.data[id] == nil then
                    nerv.error("input data %s not found", id)
                end
                local transformed = {}
                local err_output_i = {}
                if e.global_transf then
                    for _, mini_batch in ipairs(d.data[id]) do
                        table.insert(transformed,
                                     nerv.speech_utils.global_transf(mini_batch,
                                                                     global_transf,
                                                                     gconf.frm_ext or 0,
                                                                     0, gconf))
                    end
                else
                    transformed = d.data[id]
                end
                for _, mini_batch in ipairs(transformed) do
                    table.insert(err_output_i, mini_batch:create())
                end
                table.insert(err_output, err_output_i)
                table.insert(input, transformed)
            end
            network:mini_batch_init({seq_length = d.seq_length,
                                     new_seq = d.new_seq,
                                     do_train = bp,
                                     input = input,
                                     output = output,
                                     err_input = err_input,
                                     err_output = err_output})
            network:propagate()
            if bp then
                network:back_propagate()
                network:update()
            end
            -- collect garbage in time to save GPU memory
            collectgarbage("collect")
        end
        print_stat(layer_repo)
        mat_type.print_profile()
        mat_type.clear_profile()
        local fname
        if not bp then
            -- host_param_repo = param_repo:copy(src_loc_type)
            host_param_repo = nerv.ParamRepo.merge({network:get_params(),
                                                    global_transf:get_params()},
                                                   train_loc_type)
                                            :copy(src_loc_type, gconf)
            if prefix ~= nil then
                nerv.info("writing back...")
                fname = string.format("%s_cv%.3f.nerv",
                                      prefix, get_accuracy(layer_repo))
                host_param_repo:export(fname, nil)
            end
        end
        return get_accuracy(layer_repo), host_param_repo, fname
    end
    return iterative_trainer
end
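
-- The network config script (`arg[1]`, loaded via dofile() below) is expected
-- to define `gconf` and the hook functions referenced above: make_layer_repo,
-- get_network, get_global_transf, get_input_order, make_readers, make_buffer,
-- get_accuracy and print_stat. As a sketch (the input ids here are
-- hypothetical, not part of this script), get_input_order should return the
-- list of input entries consumed by the mini-batch loop:
--
--     function get_input_order()
--         return {{id = "main_scp", global_transf = true},
--                 {id = "phone_state"}}
--     end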
local function check_and_add_defaults(spec, opts)
    local function get_opt_val(k)
        local k = string.gsub(k, '_', '-')
        return opts[k].val, opts[k].specified
    end
    local opt_v = get_opt_val("resume_from")
    if opt_v then
        nerv.info("resuming from previous training state")
        gconf = dofile(opt_v)
    else
        for k, v in pairs(spec) do
            local opt_v, specified = get_opt_val(k)
            if (not specified) and gconf[k] ~= nil then
                nerv.info("using setting in network config file: %s = %s",
                          k, gconf[k])
            elseif opt_v ~= nil then
                nerv.info("using setting in options: %s = %s", k, opt_v)
                gconf[k] = opt_v
            end
        end
    end
end

local function make_options(spec)
    local options = {}
    for k, v in pairs(spec) do
        table.insert(options,
                     {string.gsub(k, '_', '-'), nil, type(v), default = v})
    end
    return options
end

local function print_help(options)
    nerv.printf("Usage: [options] network_config.lua\n")
    nerv.print_usage(options)
end

local function print_gconf()
    local key_maxlen = 0
    for k, v in pairs(gconf) do
        key_maxlen = math.max(key_maxlen, #k or 0)
    end
    local function pattern_gen()
        return string.format("%%-%ds = %%s\n", key_maxlen)
    end
    nerv.info("ready to train with the following gconf settings:")
    nerv.printf(pattern_gen(), "Key", "Value")
    for k, v in pairs(gconf) do
        nerv.printf(pattern_gen(), k or "", v or "")
    end
end

local function dump_gconf(fname)
    local f = io.open(fname, "w")
    f:write("return ")
    f:write(table.tostring(gconf))
    f:close()
end

local trainer_defaults = {
    lrate = 0.8,
    batch_size = 256,
    chunk_size = 1,
    buffer_size = 81920,
    wcost = 1e-6,
    momentum = 0.9,
    start_halving_inc = 0.5,
    halving_factor = 0.6,
    end_halving_inc = 0.1,
    cur_iter = 1,
    min_iter = 1,
    max_iter = 20,
    min_halving = 5,
    do_halving = false,
    cumat_tname = "nerv.CuMatrixFloat",
    mmat_tname = "nerv.MMatrixFloat",
    debug = false,
}

local options = make_options(trainer_defaults)
local extra_opt_spec = {
    {"tr-scp", nil, "string"},
    {"cv-scp", nil, "string"},
    {"resume-from", nil, "string"},
    {"help", "h", "boolean", default = false, desc = "show this help information"},
    {"dir", nil, "string", desc = "specify the working directory"},
}

table.extend(options, extra_opt_spec)

arg, opts = nerv.parse_args(arg, options)

if #arg < 1 or opts["help"].val then
    print_help(options)
    return
end

dofile(arg[1])

--[[

Rule: a command-line option overrides the network config, which overrides
the trainer default.

Note: a config key like aaa_bbbb_cc can be overridden by specifying
--aaa-bbbb-cc in the command-line arguments.

]]--
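-- For instance (hypothetical values): with the trainer default
-- batch_size = 256, a `batch_size = 128` assignment in the network config
-- takes precedence over the default, and passing `--batch-size 64` on the
-- command line takes precedence over both.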
check_and_add_defaults(trainer_defaults, opts)
gconf.mmat_type = nerv.get_type(gconf.mmat_tname)
gconf.cumat_type = nerv.get_type(gconf.cumat_tname)
gconf.use_cpu = gconf.use_cpu or false

-- pf0: the list of initial parameter files given by the network config
local pf0 = gconf.initialized_param
local date_pattern = "%Y%m%d%H%M%S"
local logfile_name = "log"
local working_dir = opts["dir"].val or
                        string.format("nerv_%s", os.date(date_pattern))
local rebind_param_repo = nil

print_gconf()
if not lfs.mkdir(working_dir) then
    nerv.error("[asr_trainer] working directory already exists")
end
-- copy the network config
dir.copyfile(arg[1], working_dir)
-- set logfile path
nerv.set_logfile(path.join(working_dir, logfile_name))
--path.chdir(working_dir)

-- start the training
local trainer = build_trainer(pf0)
local pr_prev
gconf.accu_best, pr_prev = trainer(nil, gconf.cv_scp, false)
nerv.info("initial cross validation: %.3f", gconf.accu_best)
for i = gconf.cur_iter, gconf.max_iter do
    local stop = false
    gconf.cur_iter = i
    dump_gconf(path.join(working_dir, string.format("iter_%d.meta", i)))
    repeat -- trick to implement a `continue` statement
        nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
        local accu_tr = trainer(nil, gconf.tr_scp, true, rebind_param_repo)
        nerv.info("[TR] training set %d: %.3f", i, accu_tr)
        local param_prefix = string.format("%s_%s_iter_%d_lr%f_tr%.3f",
                                           string.gsub(
                                               (string.gsub(pf0[1], "(.*/)(.*)", "%2")),
                                               "(.*)%..*", "%1"),
                                           os.date(date_pattern),
                                           i, gconf.lrate, accu_tr)
        local accu_new, pr_new, param_fname =
            trainer(path.join(working_dir, param_prefix), gconf.cv_scp, false)
        nerv.info("[CV] cross validation %d: %.3f", i, accu_new)
        local accu_prev = gconf.accu_best
        if accu_new < gconf.accu_best then
            nerv.info("rejecting the trained params, rollback to the previous one")
            file.move(param_fname, param_fname .. ".rejected")
            rebind_param_repo = pr_prev
            break -- `continue` equivalent
        else
            nerv.info("accepting the trained params")
            gconf.accu_best = accu_new
            pr_prev = pr_new
            gconf.initialized_param = {path.join(path.currentdir(), param_fname)}
        end
        if gconf.do_halving and
                gconf.accu_best - accu_prev < gconf.end_halving_inc and
                i > gconf.min_iter then
            -- the improvement is marginal even with a halved lrate: stop
            stop = true
            break
        end
        if gconf.accu_best - accu_prev < gconf.start_halving_inc and
                i >= gconf.min_halving then
            -- the improvement has slowed down: start halving the lrate
            gconf.do_halving = true
        end
        if gconf.do_halving then
            gconf.lrate = gconf.lrate * gconf.halving_factor
        end
    until true
    if stop then break end
    -- nerv.Matrix.print_profile()
end