-- Silence library chatter. `print` is rerouted through io.write, and the
-- default output file is redirected to /dev/null, so anything emitted via
-- print or io.write is discarded (presumably so it cannot mix with the host
-- program's own stdout -- confirm with the embedding code).
print = function(...)
    -- Behave like the standard print: tostring() every argument
    -- (table.concat raises on booleans/tables), count arguments with
    -- select('#') so nil arguments are kept, and end with a newline.
    local n = select("#", ...)
    local parts = {}
    for i = 1, n do
        parts[i] = tostring((select(i, ...)))
    end
    io.write(table.concat(parts, "\t"), "\n")
end
io.output('/dev/null')
-- path and cpath are correctly set by `path.sh`
-- Optional: when LuaRocks is available, let its loader resolve the nerv rock.
local has_loader, loader = pcall(require, "luarocks.loader")
if has_loader then
    loader.add_context("nerv", "scm-1")
end
require 'nerv'
nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n")
nerv.info("automatically initialize a default MContext...")
nerv.MMatrix._default_context = nerv.MContext()
nerv.info("the default MContext is ok")
-- Only for backward compatibility; will be removed in the future.
--- Attach class-level `print_profile` / `clear_profile` shortcuts to `cls`.
-- The shortcuts forward to the context stored in `cls._default_context`
-- at the moment this helper runs; replacing the default context later does
-- not rebind them.
-- @param cls class table exposing a `_default_context` field
local function _add_profile_method(cls)
    local ctx = cls._default_context
    function cls.print_profile()
        ctx:print_profile()
    end
    function cls.clear_profile()
        ctx:clear_profile()
    end
end

_add_profile_method(nerv.MMatrix)
--- Poll the readers in order and return the first non-nil batch.
-- @param readers array of descriptors, each holding a `reader` object
-- @return the data table from the first reader whose get_data() is not nil,
--   or nil once every reader is exhausted
local function _next_batch(readers)
    for ri = 1, #readers do
        local data = readers[ri].reader:get_data()
        if data ~= nil then
            return data
        end
    end
    return nil
end

--- Build a closure that runs the decoding network on one batch per call.
--
-- Imports parameters from `ifname` into a fresh ParamRepo, then assembles the
-- decoding network and the global feature transform using helpers that are
-- not defined in this file (`make_layer_repo`, `get_decode_network`,
-- `get_global_transf`, `get_decode_input_order`, `make_decode_readers`;
-- presumably provided by the config file -- confirm). Reads and mutates the
-- global `gconf`.
--
-- @param ifname path of the parameter file to import
-- @param feature forwarded to `make_decode_readers`
-- @return a function that, on each call, returns `utt, mat`: the batch's
--   `"key"` entry and the first output matrix of the network, or `"", nil`
--   when all readers are exhausted
function build_propagator(ifname, feature)
    local param_repo = nerv.ParamRepo()
    param_repo:import(ifname, gconf)
    local layer_repo = make_layer_repo(param_repo)
    local input_order = get_decode_input_order()
    local readers = make_decode_readers(feature, layer_repo)
    local network = nerv.Network("nt", gconf,
                                 {network = get_decode_network(layer_repo)})
    local global_transf = nerv.Network("gt", gconf,
                                       {network = get_global_transf(layer_repo)})

    local batch_propagator = function()
        local data = _next_batch(readers)
        if data == nil then
            -- an empty key tells the caller there is no more input
            return "", nil
        end

        -- Validate before doing any work: a missing first input used to
        -- crash on :nrow() before this descriptive error could fire, and a
        -- missing key was only noticed after the forward pass.
        for _, e in ipairs(input_order) do
            if data[e.id] == nil then
                nerv.error("input data %s not found", e.id)
            end
        end
        local utt = data["key"]
        if utt == nil then
            nerv.error("no key found.")
        end

        -- The batch size is the row count of the first input, which can
        -- change between calls, so both networks are re-initialised each time.
        gconf.batch_size = data[input_order[1].id]:nrow()
        global_transf:init(gconf.batch_size, 1)
        global_transf:epoch_init()
        network:init(gconf.batch_size, 1)
        network:epoch_init()

        local input = {}
        for _, e in ipairs(input_order) do
            local transformed = data[e.id]
            if e.global_transf then
                transformed = nerv.speech_utils.global_transf(data[e.id],
                                                              global_transf,
                                                              gconf.frm_ext or 0, 0,
                                                              gconf)
            end
            input[#input + 1] = {transformed}
        end

        local output = {{nerv.MMatrixFloat(input[1][1]:nrow(), network.dim_out[1])}}
        network:mini_batch_init({seq_length = table.vector(gconf.batch_size, 1),
                                 new_seq = {},
                                 do_train = false,
                                 input = input,
                                 output = output,
                                 err_input = {},
                                 err_output = {}})
        network:propagate()
        -- force a full GC cycle so unreferenced objects from this batch can
        -- be reclaimed before control returns to the caller
        collectgarbage("collect")
        return utt, output[1][1]
    end
    return batch_propagator
end
--- Entry point: load the decoding config and build the propagator.
-- @param config path of a Lua file run via dofile; it is expected to define
--   the global `gconf` (presumably along with the helpers used by
--   build_propagator -- confirm)
-- @param feature forwarded to build_propagator
function init(config, feature)
    dofile(config)
    -- decoding always runs on the CPU with single-precision host matrices
    gconf.use_cpu = true
    gconf.mmat_type = nerv.MMatrixFloat
    -- stored as a global because feed() invokes it on every call
    trainer = build_propagator(gconf.decode_param, feature)
end
--- Run the propagator built by init() on the next batch.
-- @return exactly two values: the utterance key and its output matrix
--   (or `""` and nil once input is exhausted)
function feed()
    local key, matrix = trainer()
    return key, matrix
end