--- Build an iterative trainer closure for the network described in `ifname`.
-- Loads the parameter repository from `ifname`, assembles the network and
-- global transform via the file-level helpers, and returns a function
--   iterative_trainer(prefix, scp_file, bp)
-- that runs one full pass over `scp_file` (training when `bp` is true,
-- cross-validation otherwise) and returns the resulting accuracy.
-- When `bp` is false and `prefix` is non-nil, the updated parameters are
-- exported to "<prefix>_cv<accuracy>.nerv".
-- NOTE(review): relies on file/module globals `gconf`, `make_layer_repo`,
-- `get_network`, `get_global_transf`, `get_input_order`, `make_buffer`,
-- `make_readers`, `print_stat` and `get_accuracy` defined elsewhere.
local function build_trainer(ifname)
    local param_repo = nerv.ParamRepo()
    param_repo:import(ifname, nil, gconf)
    local layer_repo = make_layer_repo(param_repo)
    local network = get_network(layer_repo)
    local global_transf = get_global_transf(layer_repo)
    local input_order = get_input_order()
    -- pick the matrix backend: host matrices on CPU, CUDA matrices otherwise
    local mat_type
    if gconf.use_cpu then
        mat_type = gconf.mmat_type
    else
        mat_type = gconf.cumat_type
    end
    local iterative_trainer = function (prefix, scp_file, bp)
        -- shuffle input only while training (back-propagation enabled)
        gconf.randomize = bp
        -- build buffer
        local buffer = make_buffer(make_readers(scp_file, layer_repo))
        -- initialize the network
        network:init(gconf.batch_size)
        gconf.cnt = 0
        -- constant all-ones error signal fed into the top of the network
        -- (declared `local`: the original leaked err_input/err_output to _G)
        local err_input = {mat_type(gconf.batch_size, 1)}
        err_input[1]:fill(1)
        for data in buffer.get_data, buffer do
            -- print stat periodically
            gconf.cnt = gconf.cnt + 1
            if gconf.cnt == 1000 then
                print_stat(layer_repo)
                mat_type.print_profile()
                mat_type.clear_profile()
                gconf.cnt = 0
                -- break
            end
            local input = {}
            -- if gconf.cnt == 1000 then break end
            for i, e in ipairs(input_order) do
                local id = e.id
                if data[id] == nil then
                    nerv.error("input data %s not found", id)
                end
                -- optionally apply the global feature transform to this input
                local transformed
                if e.global_transf then
                    transformed = nerv.speech_utils.global_transf(data[id],
                                    global_transf,
                                    gconf.frm_ext or 0, 0,
                                    gconf)
                else
                    transformed = data[id]
                end
                table.insert(input, transformed)
            end
            local output = {mat_type(gconf.batch_size, 1)}
            local err_output = {}
            for i = 1, #input do
                table.insert(err_output, input[i]:create())
            end
            network:propagate(input, output)
            if bp then
                network:back_propagate(err_input, err_output, input, output)
                network:update(err_input, input, output)
            end
            -- collect garbage in-time to save GPU memory
            collectgarbage("collect")
        end
        print_stat(layer_repo)
        mat_type.print_profile()
        mat_type.clear_profile()
        -- after a cross-validation pass, write the parameters back to disk
        if (not bp) and prefix ~= nil then
            nerv.info("writing back...")
            local fname = string.format("%s_cv%.3f.nerv",
                            prefix, get_accuracy(layer_repo))
            network:get_params():export(fname, nil)
        end
        return get_accuracy(layer_repo)
    end
    return iterative_trainer
end
--- Merge trainer settings into the global `gconf` table.
-- For each key in `spec`, precedence is:
--   command-line option value > existing gconf entry > spec default.
-- Explicit `== nil` checks are used instead of an `or` chain so that a
-- legitimate `false` value (e.g. do_halving) is not silently discarded,
-- and a missing `opts` entry no longer raises an index-nil error.
-- NOTE(review): reads the globals `opts` (parsed command-line options,
-- keyed by dashed names) and `gconf`, both defined elsewhere in the file.
local function check_and_add_defaults(spec)
    for k, default in pairs(spec) do
        local opt = opts[string.gsub(k, '_', '-')]
        local val = opt and opt.val
        if val == nil and gconf[k] ~= nil then
            val = gconf[k]
        end
        if val == nil then
            val = default
        end
        gconf[k] = val
    end
end
--- Convert a defaults spec into an option list for the argument parser.
-- Each `key = default` pair in `spec` becomes one entry
-- {dashed_name, nil, value_type, default = default}, where the dashed
-- name is the key with underscores replaced by hyphens.
-- @param spec table mapping underscore_style keys to default values
-- @return list of option descriptors (one per spec entry, pairs order)
local function make_options(spec)
    local opt_list = {}
    for key, default_val in pairs(spec) do
        local flag = string.gsub(key, '_', '-')
        opt_list[#opt_list + 1] =
            {flag, nil, type(default_val), default = default_val}
    end
    return opt_list
end
--- Print the command-line usage banner followed by the option list.
-- @param options option descriptor list as produced by make_options
local function print_help(options)
    local usage_line = "Usage: <asr_trainer.lua> [options] network_config.lua\n"
    nerv.printf(usage_line)
    nerv.print_usage(options)
end
--- Pretty-print every key/value pair in the global `gconf` table,
-- left-aligning keys to the longest key's width.
-- Values are rendered with tostring() so that booleans (e.g. a `false`
-- do_halving) and non-string values print correctly instead of being
-- swallowed by the old `v or ""` idiom, which also risked a
-- string.format error on non-string arguments to %s.
-- NOTE(review): reads the globals `gconf` and `nerv` defined elsewhere.
local function print_gconf()
    local key_maxlen = 0
    for k, _ in pairs(gconf) do
        key_maxlen = math.max(key_maxlen, #tostring(k))
    end
    -- the format pattern is constant for the whole dump; build it once
    local pattern = string.format("%%-%ds = %%s\n", key_maxlen)
    nerv.info("ready to train with the following gconf settings:")
    nerv.printf(pattern, "Key", "Value")
    for k, v in pairs(gconf) do
        nerv.printf(pattern, tostring(k), tostring(v))
    end
end
local trainer_defaults = {
lrate = 0.8,
batch_size = 256,
buffer_size = 81920,
wcost = 1e-6,
momentum = 0.9,
start_halving_inc = 0.5,
halving_factor = 0.6,
end_halving_inc = 0.1,
min_iter = 1,
max_iter = 20,
min_halving = 5,
do_halving = false,
tr_scp = nil,
cv_scp = nil,
cumat_type = nerv.CuMatrixFloat,
mmat_type = nerv.MMat