print = function(...) io.write(table.concat({...}, "\t")) end
io.output('/dev/null')
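-- the two lines above route print and the default io output to /dev/null,
-- presumably so that library chatter cannot pollute the host's stdout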
-- path and cpath are correctly set by `path.sh`
local ok, loader = pcall(require, "luarocks.loader")
if ok then loader.add_context("nerv", "scm-1") end
require 'nerv'
nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n")
nerv.info("automatically initialize a default MContext...")
nerv.MMatrix._default_context = nerv.MContext()
nerv.info("the default MContext is ok")
-- only for backward compatibility, will be removed in the future
local function _add_profile_method(cls)
local c = cls._default_context
cls.print_profile = function () c:print_profile() end
cls.clear_profile = function () c:clear_profile() end
end
_add_profile_method(nerv.MMatrix)
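-- build a closure that forward-propagates one utterance at a time:
-- `ifname` is the parameter file to import and `feature` describes the
-- feature input; each call to the returned closure yields an
-- (utt_id, output matrix) pair, or ("", nil) once the input is exhausted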
function build_propagator(ifname, feature)
-- FIXME: this is still a hack: Trainer class methods are invoked
-- statically (with a nil self) instead of on a trainer instance
local trainer = nerv.Trainer
local param_repo = nerv.ParamRepo()
param_repo:import(ifname, gconf)
local layer_repo = trainer.make_layer_repo(nil, param_repo)
local network = trainer.get_decode_network(nil, layer_repo)
local input_order = trainer.get_decode_input_order(nil)
local input_name = gconf.decode_input_name or "main_scp"
local readers = trainer.make_decode_readers(nil, feature)
-- nerv.info("prepare")
local buffer = nerv.SeqBuffer(gconf, {
buffer_size = gconf.buffer_size,
batch_size = gconf.batch_size,
chunk_size = gconf.chunk_size,
randomize = gconf.randomize,
readers = readers,
})
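-- wrap the decode network in a nerv.Network and set up its geometry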
network = nerv.Network("nt", gconf, {network = network})
network:init(gconf.batch_size, gconf.chunk_size)
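-- fetch the first block up front; an empty input stream is a fatal error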
local prev_data = buffer:get_data() or nerv.error("no data in buffer")
local terminate = false
local input_pos = nil
for i, v in ipairs(input_order) do
if v == input_name then
input_pos = i
break
end
end
if input_pos == nil then
nerv.error("input name %s not found in the input order list", input_name)
end
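-- the closure below accumulates per-chunk outputs until the buffer yields
-- the first block of the next utterance, then stitches them into one matrix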
local batch_propagator = function()
if terminate then
return "", nil
end
network:epoch_init()
local accu_output = {}
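-- the current utterance id is the key the feature reader is positioned at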
local utt_id = readers[input_pos].reader.key
if utt_id == nil then
nerv.error("no key found.")
end
while true do
local d
if prev_data then
d = prev_data
prev_data = nil
else
d = buffer:get_data()
if d == nil then
terminate = true
break
elseif #d.new_seq > 0 then
prev_data = d -- the first data of the next utterance
break
end
end
local input = {}
local output = {{}}
for i, id in ipairs(input_order) do
if d.data[id] == nil then
nerv.error("input data %s not found", id)
end
table.insert(input, d.data[id])
end
-- allocate one output matrix per chunk position
for i = 1, gconf.chunk_size do
table.insert(output[1], gconf.mmat_type(gconf.batch_size, network.dim_out[1]))
end
--nerv.info("input num: %d\nmat: %s\n", #input[1], input[1][1])
--nerv.info("output num: %d\nmat: %s\n", #output[1], output[1][1])
network:mini_batch_init({seq_length = d.seq_length,
new_seq = d.new_seq,
do_train = false,
input = input,
output = output,
err_input = {},
err_output = {}})
network:propagate()
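-- keep only frames whose mask is set; padded positions ("holes") at
-- sequence boundaries carry no real data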
for i, v in ipairs(output[1]) do
if gconf.mask[i][0][0] > 0 then -- is not a hole
table.insert(accu_output, v)
--nerv.info("input: %s\noutput: %s\n", input[1][i], output[1][i])
end
end
end
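-- concatenate the per-chunk outputs row-wise into one utterance matrix
-- (batch_size is 1 during decoding, so each block contributes one row)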
local utt_matrix = gconf.mmat_type(#accu_output, accu_output[1]:ncol())
for i, v in ipairs(accu_output) do
utt_matrix:copy_from(v, 0, v:nrow(), i - 1)
end
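-- reclaim the temporaries accumulated during propagation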
collectgarbage("collect")
nerv.info("propagated %d features of an utterance", utt_matrix:nrow())
return utt_id, utt_matrix
end
return batch_propagator
end
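-- entry points for the embedding host: init() loads the configuration and
-- builds the propagator; feed() returns one (utt_id, matrix) pair per call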
function init(config, feature)
dofile(config)
gconf.mmat_type = nerv.MMatrixFloat
gconf.use_cpu = true -- use CPU to decode
gconf.batch_size = 1
propagator = build_propagator(gconf.decode_param, feature)
end
function feed()
return propagator()
end
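-- minimal usage sketch (the file names are illustrative, not from this repo):
--   init("decode_config.lua", "feats.scp")
--   local utt, mat = feed()   -- repeat until utt == ""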