diff options
Diffstat (limited to 'kaldi_decode/src/asr_propagator.lua')
-rw-r--r-- | kaldi_decode/src/asr_propagator.lua | 84 |
1 files changed, 84 insertions, 0 deletions
diff --git a/kaldi_decode/src/asr_propagator.lua b/kaldi_decode/src/asr_propagator.lua new file mode 100644 index 0000000..5d0ad7c --- /dev/null +++ b/kaldi_decode/src/asr_propagator.lua @@ -0,0 +1,84 @@ +print = function(...) io.write(table.concat({...}, "\t")) end +io.output('/dev/null') +-- path and cpath are correctly set by `path.sh` +local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1") +require 'nerv' +nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n") +nerv.info("automatically initialize a default MContext...") +nerv.MMatrix._default_context = nerv.MContext() +nerv.info("the default MContext is ok") +-- only for backward compatibilty, will be removed in the future +local function _add_profile_method(cls) + local c = cls._default_context + cls.print_profile = function () c:print_profile() end + cls.clear_profile = function () c:clear_profile() end +end +_add_profile_method(nerv.MMatrix) + +function build_propagator(ifname, feature) + local param_repo = nerv.ParamRepo() + param_repo:import(ifname, nil, gconf) + local layer_repo = make_layer_repo(param_repo) + local network = get_decode_network(layer_repo) + local global_transf = get_global_transf(layer_repo) + local input_order = get_decode_input_order() + local readers = make_decode_readers(feature, layer_repo) + + local batch_propagator = function() + local data = nil + for ri = 1, #readers do + data = readers[ri].reader:get_data() + if data ~= nil then + break + end + end + + if data == nil then + return "", nil + end + + gconf.batch_size = data[input_order[1].id]:nrow() + network:init(gconf.batch_size) + + local input = {} + for i, e in ipairs(input_order) do + local id = e.id + if data[id] == nil then + nerv.error("input data %s not found", id) + end + local transformed + if e.global_transf then + transformed = nerv.speech_utils.global_transf(data[id], + global_transf, + gconf.frm_ext or 0, 0, + gconf) + else + transformed = data[id] + end + table.insert(input, transformed) + end + local output = {nerv.MMatrixFloat(input[1]:nrow(), network.dim_out[1])} + network:propagate(input, output) + + local utt = data["key"] + if utt == nil then + nerv.error("no key found.") + end + + collectgarbage("collect") + return utt, output[1] + end + + return batch_propagator +end + +function init(config, feature) + dofile(config) + gconf.use_cpu = true -- use CPU to decode + trainer = build_propagator(gconf.decode_param, feature) +end + +function feed() + local utt, mat = trainer() + return utt, mat +end |