-- Replace the global `print` so it writes through `io.write`, i.e. to the
-- default output file, which is redirected to /dev/null right below. The
-- effect is that stray `print` calls produce no visible output (presumably to
-- keep stdout clean for the host process -- confirm).
-- Every argument is converted with `tostring`, like the standard `print`, so
-- printing nil, booleans or tables can no longer make `table.concat` raise an
-- error. `select("#", ...)` counts trailing nils that `{...}` would lose.
print = function(...)
    local parts = {}
    for i = 1, select("#", ...) do
        parts[i] = tostring((select(i, ...)))
    end
    io.write(table.concat(parts, "\t"), "\n")
end
io.output('/dev/null')
-- `package.path` and `package.cpath` are expected to be configured by `path.sh`.
-- If LuaRocks is available, activate its loader and register the ("nerv",
-- "scm-1") context with it; if not, continue without it.
local rocks_ok, rocks_loader = pcall(require, "luarocks.loader")
if rocks_ok then
    rocks_loader.add_context("nerv", "scm-1")
end

require 'nerv'
nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n")

-- Install a fresh MContext as the default context of nerv.MMatrix.
nerv.info("automatically initialize a default MContext...")
nerv.MMatrix._default_context = nerv.MContext()
nerv.info("the default MContext is ok")
-- Backward-compatibility shim, to be removed in the future: exposes
-- `print_profile` and `clear_profile` directly on a matrix class. Both forward
-- to the class's `_default_context`, which is looked up once, when the shim
-- is installed.
local function _add_profile_method(cls)
    local ctx = cls._default_context
    function cls.print_profile()
        ctx:print_profile()
    end
    function cls.clear_profile()
        ctx:clear_profile()
    end
end

_add_profile_method(nerv.MMatrix)
--- Build a closure that decodes one batch per call.
-- Imports the parameters from `ifname` into a fresh `nerv.ParamRepo`. It then
-- calls the global hooks `make_layer_repo`, `get_decode_network`,
-- `get_global_transf`, `get_decode_input_order` and `make_decode_readers`
-- (presumably defined by the config file -- confirm) to build the network,
-- the global feature transform, the input order and the readers.
-- Reads and writes the global `gconf`.
-- @param ifname parameter file(s) forwarded to `ParamRepo:import`
-- @param feature forwarded unchanged to `make_decode_readers`
-- @treturn function closure returning `utt, output`: the batch's `key` and the
--   network's first output (an `nerv.MMatrixFloat`). Returns `"", nil` once
--   every reader yields nil. Raises via `nerv.error` if a batch lacks a
--   declared input or its `key`.
function build_propagator(ifname, feature)
    local param_repo = nerv.ParamRepo()
    param_repo:import(ifname, nil, gconf)
    local layer_repo = make_layer_repo(param_repo)
    local network = get_decode_network(layer_repo)
    local global_transf = get_global_transf(layer_repo)
    local input_order = get_decode_input_order()
    local readers = make_decode_readers(feature, layer_repo)

    -- Ask the readers in order for data; the first non-nil batch wins.
    -- Returns nil when every reader yields nil.
    local function next_batch()
        for ri = 1, #readers do
            local data = readers[ri].reader:get_data()
            if data ~= nil then
                return data
            end
        end
        return nil
    end

    -- Check, before any work is done, that the batch carries every declared
    -- input and an utterance key. The batch size is read from the first input,
    -- so without this check a missing first input fails with an opaque
    -- index-nil error. A missing key would only surface after a full
    -- propagation.
    local function validate(data)
        for _, e in ipairs(input_order) do
            if data[e.id] == nil then
                nerv.error("input data %s not found", e.id)
            end
        end
        if data["key"] == nil then
            nerv.error("no key found.")
        end
    end

    local batch_propagator = function()
        local data = next_batch()
        if data == nil then
            return "", nil
        end
        validate(data)

        -- The batch size is taken from each batch, so the network is
        -- re-initialised on every call.
        gconf.batch_size = data[input_order[1].id]:nrow()
        network:init(gconf.batch_size)

        local input = {}
        for i, e in ipairs(input_order) do
            local mat = data[e.id]
            if e.global_transf then
                mat = nerv.speech_utils.global_transf(mat, global_transf,
                                                      gconf.frm_ext or 0, 0,
                                                      gconf)
            end
            input[i] = mat
        end

        local output = {nerv.MMatrixFloat(input[1]:nrow(), network.dim_out[1])}
        network:propagate(input, output)
        -- Force a full collection after each batch (presumably so the matrices
        -- of this batch are released before the next call -- confirm).
        collectgarbage("collect")
        return data["key"], output[1]
    end
    return batch_propagator
end
--- Load the decoding config and build the propagator.
-- Runs the Lua file `config`, which must define the global table `gconf`
-- with a `decode_param` entry. Forces CPU decoding, then stores the
-- propagator returned by `build_propagator` in the global `trainer`, which
-- `feed` calls.
-- @param config path of the Lua config file to run
-- @param feature forwarded unchanged to `build_propagator`
function init(config, feature)
    dofile(config)
    -- Fail with a clear message instead of an index-nil error when the
    -- config does not provide what this function needs.
    if type(gconf) ~= "table" then
        nerv.error("config %s does not define the global table gconf", tostring(config))
    end
    gconf.use_cpu = true -- use CPU to decode
    if gconf.decode_param == nil then
        nerv.error("config %s does not set gconf.decode_param", tostring(config))
    end
    trainer = build_propagator(gconf.decode_param, feature)
end
--- Decode the next batch with the propagator built by `init`.
-- @return utterance key from the propagator (`build_propagator` returns "" once
--   its readers are exhausted)
-- @return output matrix for that batch, or nil at the end
function feed()
    if trainer == nil then
        nerv.error("feed() called before init()")
    end
    -- Deliberately forward exactly two values, whatever the propagator returns.
    local utt, mat = trainer()
    return utt, mat
end