require 'lfs'
require 'pl'
local function build_trainer(ifname)
local host_param_repo = nerv.ParamRepo()
local mat_type
local src_loc_type
local train_loc_type
host_param_repo:import(ifname, nil, gconf)
if gconf.use_cpu then
mat_type = gconf.mmat_type
src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
else
mat_type = gconf.cumat_type
src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
end
local param_repo = host_param_repo:copy(train_loc_type)
local layer_repo = make_layer_repo(param_repo)
local network = get_network(layer_repo)
local global_transf = get_global_transf(layer_repo)
local input_order = get_input_order()
network = nerv.Network("nt", gconf, {network = network})
network:init(gconf.batch_size, 1)
global_transf = nerv.Network("gt", gconf, {network = global_transf})
global_transf:init(gconf.batch_size, 1)
-- Closure returned by build_trainer: runs one full pass (training or
-- cross-validation) over an scp file using the network/layer_repo/transf
-- upvalues captured from the enclosing build_trainer.
--   prefix:            filename prefix for exporting params (CV pass only);
--                      pass nil to skip writing back
--   scp_file:          scp list of utterances to feed through the network
--   bp:                true  => back-propagate and update (training pass)
--                      false => forward-only (cross-validation pass)
--   rebind_param_repo: optional host-side ParamRepo to swap in before the pass
-- Returns: accuracy, the host-side ParamRepo, and the exported filename
-- (fname is nil unless a CV pass with a non-nil prefix wrote one out).
local iterative_trainer = function (prefix, scp_file, bp, rebind_param_repo)
    -- rebind the params if necessary (updates the enclosing function's
    -- host_param_repo/param_repo upvalues so later passes see the new repo)
    if rebind_param_repo then
        host_param_repo = rebind_param_repo
        param_repo = host_param_repo:copy(train_loc_type)
        layer_repo:rebind(param_repo)
        rebind_param_repo = nil -- drop the reference so the old repo can be GC'd
    end
    -- frame-level shuffling only makes sense while training
    gconf.randomize = bp
    -- build buffer
    local buffer = make_buffer(make_readers(scp_file, layer_repo))
    -- initialize the network
    gconf.cnt = 0
    -- all-ones error signal injected at the network output
    -- (fix: was an accidental global; now properly scoped with `local`)
    local err_input = {mat_type(gconf.batch_size, 1)}
    err_input[1]:fill(1)
    network:epoch_init()
    global_transf:epoch_init()
    for data in buffer.get_data, buffer do
        -- print stat periodically
        gconf.cnt = gconf.cnt + 1
        if gconf.cnt == 1000 then
            print_stat(layer_repo)
            mat_type.print_profile()
            mat_type.clear_profile()
            gconf.cnt = 0
            -- break
        end
        local input = {}
        -- if gconf.cnt == 1000 then break end
        -- gather the inputs in the order the network expects, applying the
        -- global transform to those entries that request it
        for i, e in ipairs(input_order) do
            local id = e.id
            if data[id] == nil then
                nerv.error("input data %s not found", id)
            end
            local transformed
            if e.global_transf then
                transformed = nerv.speech_utils.global_transf(data[id],
                                                              global_transf,
                                                              gconf.frm_ext or 0, 0,
                                                              gconf)
            else
                transformed = data[id]
            end
            table.insert(input, transformed)
        end
        local output = {mat_type(gconf.batch_size, 1)}
        -- per-input error buffers, one matrix shaped like each input
        -- (fix: was an accidental global; now properly scoped with `local`)
        local err_output = {}
        for i = 1, #input do
            table.insert(err_output, input[i]:create())
        end
        network:mini_batch_init({seq_length = table.vector(gconf.batch_size, 1),
                                 new_seq = {},
                                 do_train = bp,
                                 input = {input},
                                 output = {output},
                                 err_input = {err_input},
                                 err_output = {err_output}})
        network:propagate()
        if bp then
            network:back_propagate()
            network:update()
        end
        -- collect garbage in-time to save GPU memory
        collectgarbage("collect")
    end
    print_stat(layer_repo)
    mat_type.print_profile()
    mat_type.clear_profile()
    local fname
    if (not bp) then
        -- CV pass: copy params back to the host side for export/return
        host_param_repo = param_repo:copy(src_loc_type)
        if prefix ~= nil then
            nerv.info("writing back...")
            fname = string.format("%s_cv%.3f.nerv",
                                  prefix, get_accuracy(layer_repo))
            host_param_repo:export(fname, nil)
        end
    end
    return get_accuracy(layer_repo), host_param_repo, fname
end
return<