require 'lfs'
require 'pl'
--- Build a closure that runs one pass (training or cross-validation) over a data list.
-- Imports the initial parameters from `ifname`, builds the layer repository,
-- the main network and the global feature-transform network via helpers that
-- are not defined in this file (make_layer_repo, get_network, get_global_transf,
-- get_input_order; presumably supplied by the network config), and returns
-- the per-pass closure `iterative_trainer`.
-- @param ifname parameter file list handed to nerv.ParamRepo:import
-- @return iterative_trainer(prefix, scp_file, bp, rebind_param_repo)
local function build_trainer(ifname)
    local host_param_repo = nerv.ParamRepo()
    host_param_repo:import(ifname, gconf)
    -- The host-side copy (used for export and rollback) always lives on the
    -- host; only the working copy used for training moves to the device
    -- when the CPU is not selected.
    local src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
    local mat_type
    local train_loc_type
    if gconf.use_cpu then
        mat_type = gconf.mmat_type
        train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
    else
        mat_type = gconf.cumat_type
        train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
    end
    local param_repo = host_param_repo:copy(train_loc_type, gconf)
    local layer_repo = make_layer_repo(param_repo)
    local network = get_network(layer_repo)
    local global_transf = get_global_transf(layer_repo)
    local input_order = get_input_order()
    network = nerv.Network("nt", gconf, {network = network})
    network:init(gconf.batch_size, gconf.chunk_size)
    global_transf = nerv.Network("gt", gconf, {network = global_transf})
    global_transf:init(gconf.batch_size, gconf.chunk_size)

    --- Run one pass over `scp_file`.
    -- @param prefix path prefix of the parameter file to export after a
    --   non-training pass; nil to skip writing
    -- @param scp_file data list handed to make_readers
    -- @param bp true to train (back-propagate and update), false to only evaluate
    -- @param rebind_param_repo optional host-side param repo to restore before
    --   the pass (rollback); this is only a local here, so the caller is
    --   responsible for not passing it again once it has been applied
    -- @return accuracy (get_accuracy), the host-side param repo, and the
    --   exported file name (nil unless a file was written)
    local iterative_trainer = function (prefix, scp_file, bp, rebind_param_repo)
        -- rebind the params if requested
        if rebind_param_repo then
            host_param_repo = rebind_param_repo
            param_repo = host_param_repo:copy(train_loc_type, gconf)
            layer_repo:rebind(param_repo)
        end
        gconf.randomize = bp
        -- build buffer
        local buffer = make_buffer(make_readers(scp_file, layer_repo))
        -- initialize the network
        gconf.cnt = 0
        -- one output matrix per chunk step, reused for every mini-batch
        local output = {{}}
        for _ = 1, gconf.chunk_size do
            table.insert(output[1], mat_type(gconf.batch_size, 1))
        end
        network:epoch_init()
        global_transf:epoch_init()
        for d in buffer.get_data, buffer do
            -- print stats and the matrix profile every 1000 mini-batches
            gconf.cnt = gconf.cnt + 1
            if gconf.cnt == 1000 then
                print_stat(layer_repo)
                mat_type.print_profile()
                mat_type.clear_profile()
                gconf.cnt = 0
            end
            local input = {}
            local err_output = {}
            for _, e in ipairs(input_order) do
                local id = e.id
                if d.data[id] == nil then
                    nerv.error("input data %s not found", id)
                end
                local transformed = {}
                local err_output_i = {}
                if e.global_transf then
                    for _, mini_batch in ipairs(d.data[id]) do
                        table.insert(transformed,
                                     nerv.speech_utils.global_transf(mini_batch,
                                                                     global_transf,
                                                                     gconf.frm_ext or 0, 0,
                                                                     gconf))
                    end
                else
                    transformed = d.data[id]
                end
                -- one error-output buffer per input mini-batch, made with
                -- mini_batch:create() (presumably same shape — confirm)
                for _, mini_batch in ipairs(transformed) do
                    table.insert(err_output_i, mini_batch:create())
                end
                table.insert(err_output, err_output_i)
                table.insert(input, transformed)
            end
            network:mini_batch_init({seq_length = d.seq_length,
                                     new_seq = d.new_seq,
                                     do_train = bp,
                                     input = input,
                                     output = output,
                                     err_input = {gconf.mask},
                                     err_output = err_output})
            network:propagate()
            if bp then
                network:back_propagate()
                network:update()
            end
            -- collect garbage in-time to save GPU memory
            collectgarbage("collect")
        end
        print_stat(layer_repo)
        mat_type.print_profile()
        mat_type.clear_profile()
        local fname
        if not bp then
            -- snapshot the current network and transform params back to the host
            local merged = nerv.ParamRepo.merge({network:get_params(),
                                                 global_transf:get_params()},
                                                train_loc_type)
            host_param_repo = merged:copy(src_loc_type, gconf)
            if prefix ~= nil then
                nerv.info("writing back...")
                fname = string.format("%s_cv%.3f.nerv",
                                      prefix, get_accuracy(layer_repo))
                host_param_repo:export(fname, nil)
            end
        end
        return get_accuracy(layer_repo), host_param_repo, fname
    end
    return iterative_trainer
end
--- Merge trainer defaults, network-config settings and command-line options into gconf.
-- Precedence: an explicitly specified command-line option wins, then a value
-- already present in gconf (from the network config), then the option's
-- default value. With --resume-from, gconf is instead replaced wholesale by
-- the table returned from that file.
-- @param spec table of config keys (underscore spelling) to consider
-- @param opts parsed options keyed by dash spelling, each with .val and .specified
local function check_and_add_defaults(spec, opts)
    -- command-line names use dashes where config keys use underscores
    local function lookup(key)
        local entry = opts[(string.gsub(key, '_', '-'))]
        return entry.val, entry.specified
    end

    local resume_file = lookup("resume_from")
    if resume_file then
        nerv.info("resuming from previous training state")
        gconf = dofile(resume_file)
        return
    end

    for key in pairs(spec) do
        local val, specified = lookup(key)
        local from_config = (not specified) and gconf[key] ~= nil
        if from_config then
            nerv.info("using setting in network config file: %s = %s", key, gconf[key])
        elseif val ~= nil then
            nerv.info("using setting in options: %s = %s", key, val)
            gconf[key] = val
        end
    end
end
--- Turn a table of config defaults into an option spec list for nerv.parse_args.
-- Each key becomes an option named with dashes instead of underscores; its
-- type is the Lua type of the default, which is also the option's default.
-- Keys are emitted in sorted order so that the option list (and therefore
-- the help text) is deterministic rather than dependent on pairs() order.
-- @param spec table mapping config keys (strings) to default values
-- @return array of option specs {name, nil, type, default = value}
local function make_options(spec)
    local keys = {}
    for k in pairs(spec) do
        keys[#keys + 1] = k
    end
    table.sort(keys)

    local options = {}
    for _, k in ipairs(keys) do
        local v = spec[k]
        -- parenthesized-free local assignment keeps only gsub's first result
        local name = string.gsub(k, '_', '-')
        options[#options + 1] = {name, nil, type(v), default = v}
    end
    return options
end
--- Print the usage banner followed by the option table.
-- @param options option spec list as passed to nerv.parse_args
local function print_help(options)
    local banner = "Usage: <asr_trainer.lua> [options] network_config.lua\n"
    nerv.printf(banner)
    nerv.print_usage(options)
end
--- Print every gconf setting as an aligned "key = value" table.
-- Keys and values go through tostring(), so numeric keys no longer raise
-- (the old `#k` failed on them) and `false` values print as "false"
-- instead of an empty string. Rows are sorted by key for stable output.
local function print_gconf()
    local keys = {}
    local key_width = #"Key"
    for k in pairs(gconf) do
        keys[#keys + 1] = k
        key_width = math.max(key_width, #tostring(k))
    end
    table.sort(keys, function(a, b) return tostring(a) < tostring(b) end)

    -- e.g. "%-12s = %s\n": left-align keys to the widest one
    local row_fmt = string.format("%%-%ds = %%s\n", key_width)
    nerv.info("ready to train with the following gconf settings:")
    nerv.printf(row_fmt, "Key", "Value")
    for _, k in ipairs(keys) do
        nerv.printf(row_fmt, tostring(k), tostring(gconf[k]))
    end
end
--- Write gconf to `fname` as a Lua chunk ("return <table>") loadable with dofile.
-- The table is serialized before the file is opened, so a failing serializer
-- cannot leave a truncated file behind. An unopenable path is reported through
-- nerv.error instead of failing later with "attempt to index a nil value".
-- @param fname destination path
local function dump_gconf(fname)
    local content = "return " .. table.tostring(gconf)
    local f, err = io.open(fname, "w")
    if not f then
        nerv.error("[asr_trainer] cannot write %s: %s", fname, tostring(err))
    end
    f:write(content)
    f:close()
end
-- Trainer defaults. Each key can be overridden by the network config file or
-- on the command line as --key-name (underscores become dashes); see
-- check_and_add_defaults for the precedence rules.
local trainer_defaults = {
-- learning rate; logged each iteration and multiplied by halving_factor
-- while halving is active (presumably also read by the layers' update step)
lrate = 0.8,
-- passed to network:init and used to size the per-chunk output matrices
batch_size = 256,
chunk_size = 1,
-- not read in this script; presumably consumed by make_buffer — TODO confirm
buffer_size = 81920,
-- not read in this script; presumably weight decay and momentum for the
-- parameter updates — TODO confirm
wcost = 1e-6,
momentum = 0.9,
-- halving starts when the CV accuracy gain drops below this (and the
-- iteration number is at least min_halving)
start_halving_inc = 0.5,
-- lrate multiplier applied each iteration while halving is active
halving_factor = 0.6,
-- training stops when, while halving, the CV gain drops below this and the
-- iteration number exceeds min_iter
end_halving_inc = 0.1,
-- first iteration of the main loop; updated every iteration and dumped to
-- iter_N.meta so a resumed run continues from there
cur_iter = 1,
min_iter = 1,
max_iter = 20,
min_halving = 5,
-- current halving state (may be switched on/off by the schedule)
do_halving = false,
-- once halving starts, keep it on even if the CV gain recovers
keep_halving = false,
-- matrix class names resolved with nerv.get_type (GPU and host variants)
cumat_tname = "nerv.CuMatrixFloat",
mmat_tname = "nerv.MMatrixFloat",
-- not read in this script
debug = false,
}
-- Every trainer default doubles as a command-line option ...
local options = make_options(trainer_defaults)
-- ... plus a few options that exist only on the command line.
local extra_opt_spec = {
    {"tr-scp", nil, "string"},
    {"cv-scp", nil, "string"},
    {"resume-from", nil, "string"},
    {"help", "h", "boolean", default = false, desc = "show this help information"},
    {"dir", nil, "string", desc = "specify the working directory"},
}
table.extend(options, extra_opt_spec)
-- `arg` and `opts` stay global: the rest of the script reads both.
arg, opts = nerv.parse_args(arg, options)
-- without a network config (or when asked) just show usage and quit
local need_help = #arg < 1 or opts["help"].val
if need_help then
    print_help(options)
    return
end
dofile(arg[1])
--[[
Rule: command-line option overrides network config overrides trainer default.
Note: config key like aaa_bbbb_cc could be overridden by specifying
--aaa-bbbb-cc to command-line arguments.
]]--
check_and_add_defaults(trainer_defaults, opts)
-- resolve the matrix classes named by the (possibly overridden) settings
gconf.mmat_type = nerv.get_type(gconf.mmat_tname)
gconf.cumat_type = nerv.get_type(gconf.cumat_tname)
-- econf is not defined in this script (presumably provided by the nerv
-- launcher); treat a missing econf as use_cpu = false instead of crashing.
gconf.use_cpu = (econf ~= nil and econf.use_cpu) or false
local pf0 = gconf.initialized_param
-- Fail before creating any directory: the trainer and the parameter-file
-- naming both need at least one initial parameter file.
if pf0 == nil or pf0[1] == nil then
    nerv.error("[asr_trainer] gconf.initialized_param must list at least one parameter file")
end
local date_pattern = "%Y%m%d%H%M%S"
local logfile_name = "log"
local working_dir = opts["dir"].val or string.format("nerv_%s", os.date(date_pattern))
local rebind_param_repo = nil
print_gconf()
-- lfs.mkdir returns nil plus a message on any failure (existing directory,
-- missing parent, permissions), so report the actual reason.
local mkdir_ok, mkdir_err = lfs.mkdir(working_dir)
if not mkdir_ok then
    nerv.error("[asr_trainer] cannot create working directory %s: %s",
               working_dir, tostring(mkdir_err))
end
-- copy the network config
dir.copyfile(arg[1], working_dir)
-- set logfile path
nerv.set_logfile(path.join(working_dir, logfile_name))
--path.chdir(working_dir)
-- start the training
local trainer = build_trainer(pf0)
-- Host-side params of the most recently accepted model; used to roll the
-- network back after an iteration is rejected.
local pr_prev
-- Base name shared by every parameter file written below: the file name of
-- the first initial-parameter file without its directory and last extension.
local param_base = string.gsub(
    (string.gsub(pf0[1], "(.*/)(.*)", "%2")),
    "(.*)%..*", "%1")
-- initial cross-validation
local param_prefix = string.format("%s_%s", param_base, os.date(date_pattern))
gconf.accu_best, pr_prev = trainer(path.join(working_dir, param_prefix), gconf.cv_scp, false)
nerv.info("initial cross validation: %.3f", gconf.accu_best)
-- main loop
for i = gconf.cur_iter, gconf.max_iter do
    local stop = false
    gconf.cur_iter = i
    -- snapshot gconf so a later run can --resume-from this iteration
    dump_gconf(path.join(working_dir, string.format("iter_%d.meta", i)))
    repeat -- trick to implement `continue` statement
        nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
        local accu_tr = trainer(nil, gconf.tr_scp, true, rebind_param_repo)
        -- The rollback (if any) has now been applied. Clear the request so a
        -- later accepted iteration is not reverted to these stale params.
        rebind_param_repo = nil
        nerv.info("[TR] training set %d: %.3f", i, accu_tr)
        param_prefix = string.format("%s_%s_iter_%d_lr%f_tr%.3f",
                                     param_base, os.date(date_pattern),
                                     i, gconf.lrate, accu_tr)
        local accu_new, pr_new, param_fname =
            trainer(path.join(working_dir, param_prefix), gconf.cv_scp, false)
        nerv.info("[CV] cross validation %d: %.3f", i, accu_new)
        local accu_prev = gconf.accu_best
        if accu_new < gconf.accu_best then
            nerv.info("rejecting the trained params, rollback to the previous one")
            file.move(param_fname, param_fname .. ".rejected")
            -- restore the last accepted params at the start of the next pass
            rebind_param_repo = pr_prev
            break -- `continue` equivalent
        end
        nerv.info("accepting the trained params")
        gconf.accu_best = accu_new
        pr_prev = pr_new
        gconf.initialized_param = {path.join(path.currentdir(), param_fname)}
        -- stop once, while halving, the CV gain falls below end_halving_inc
        if gconf.do_halving and
            gconf.accu_best - accu_prev < gconf.end_halving_inc and
            i > gconf.min_iter then
            stop = true
            break
        end
        if gconf.accu_best - accu_prev < gconf.start_halving_inc and
            i >= gconf.min_halving then
            gconf.do_halving = true
        elseif not gconf.keep_halving then
            gconf.do_halving = false
        end
        if gconf.do_halving then
            gconf.lrate = gconf.lrate * gconf.halving_factor
        end
    until true
    if stop then break end
end