summaryrefslogtreecommitdiff
path: root/kaldi_decode/src/asr_propagator.lua
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_decode/src/asr_propagator.lua')
-rw-r--r--kaldi_decode/src/asr_propagator.lua84
1 files changed, 84 insertions, 0 deletions
diff --git a/kaldi_decode/src/asr_propagator.lua b/kaldi_decode/src/asr_propagator.lua
new file mode 100644
index 0000000..5d0ad7c
--- /dev/null
+++ b/kaldi_decode/src/asr_propagator.lua
@@ -0,0 +1,84 @@
+print = function(...) io.write(table.concat({...}, "\t")) end
+io.output('/dev/null')
+-- path and cpath are correctly set by `path.sh`
+local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1")
+require 'nerv'
+nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n")
+nerv.info("automatically initialize a default MContext...")
+nerv.MMatrix._default_context = nerv.MContext()
+nerv.info("the default MContext is ok")
+-- only for backward compatibilty, will be removed in the future
+local function _add_profile_method(cls)
+ local c = cls._default_context
+ cls.print_profile = function () c:print_profile() end
+ cls.clear_profile = function () c:clear_profile() end
+end
+_add_profile_method(nerv.MMatrix)
+
+function build_propagator(ifname, feature)
+ local param_repo = nerv.ParamRepo()
+ param_repo:import(ifname, nil, gconf)
+ local layer_repo = make_layer_repo(param_repo)
+ local network = get_decode_network(layer_repo)
+ local global_transf = get_global_transf(layer_repo)
+ local input_order = get_decode_input_order()
+ local readers = make_decode_readers(feature, layer_repo)
+
+ local batch_propagator = function()
+ local data = nil
+ for ri = 1, #readers do
+ data = readers[ri].reader:get_data()
+ if data ~= nil then
+ break
+ end
+ end
+
+ if data == nil then
+ return "", nil
+ end
+
+ gconf.batch_size = data[input_order[1].id]:nrow()
+ network:init(gconf.batch_size)
+
+ local input = {}
+ for i, e in ipairs(input_order) do
+ local id = e.id
+ if data[id] == nil then
+ nerv.error("input data %s not found", id)
+ end
+ local transformed
+ if e.global_transf then
+ transformed = nerv.speech_utils.global_transf(data[id],
+ global_transf,
+ gconf.frm_ext or 0, 0,
+ gconf)
+ else
+ transformed = data[id]
+ end
+ table.insert(input, transformed)
+ end
+ local output = {nerv.MMatrixFloat(input[1]:nrow(), network.dim_out[1])}
+ network:propagate(input, output)
+
+ local utt = data["key"]
+ if utt == nil then
+ nerv.error("no key found.")
+ end
+
+ collectgarbage("collect")
+ return utt, output[1]
+ end
+
+ return batch_propagator
+end
+
+function init(config, feature)
+ dofile(config)
+ gconf.use_cpu = true -- use CPU to decode
+ trainer = build_propagator(gconf.decode_param, feature)
+end
+
+function feed()
+ local utt, mat = trainer()
+ return utt, mat
+end