summaryrefslogtreecommitdiff
path: root/kaldi_decode/src/nerv4decode.lua
diff options
context:
space:
mode:
authorYimmon Zhuang <[email protected]>2015-10-14 15:37:20 +0800
committerYimmon Zhuang <[email protected]>2015-10-14 15:37:20 +0800
commitb33b3a6732c6b6a66bd5c44c615be56d66f4ed67 (patch)
tree47501412a3324e4c13b1238eeb913aae02b2024a /kaldi_decode/src/nerv4decode.lua
parente39fb231f64ddc8b79a6eb5434f529aadb3165fe (diff)
support kaldi decoder
Diffstat (limited to 'kaldi_decode/src/nerv4decode.lua')
-rw-r--r--kaldi_decode/src/nerv4decode.lua79
1 files changed, 79 insertions, 0 deletions
diff --git a/kaldi_decode/src/nerv4decode.lua b/kaldi_decode/src/nerv4decode.lua
new file mode 100644
index 0000000..b2ff344
--- /dev/null
+++ b/kaldi_decode/src/nerv4decode.lua
@@ -0,0 +1,79 @@
+package.path="/home/slhome/ymz09/.luarocks/share/lua/5.1/?.lua;/home/slhome/ymz09/.luarocks/share/lua/5.1/?/init.lua;/slfs6/users/ymz09/nerv-project/nerv/install/share/lua/5.1/?.lua;/slfs6/users/ymz09/nerv-project/nerv/install/share/lua/5.1/?/init.lua;"..package.path;
+package.cpath="/home/slhome/ymz09/.luarocks/lib/lua/5.1/?.so;/slfs6/users/ymz09/nerv-project/nerv/install/lib/lua/5.1/?.so;"..package.cpath;
+local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1")
+require 'nerv'
+
+function build_trainer(ifname, feature)
+ local param_repo = nerv.ParamRepo()
+ param_repo:import(ifname, nil, gconf)
+ local layer_repo = make_layer_repo(param_repo)
+ local network = get_decode_network(layer_repo)
+ local global_transf = get_global_transf(layer_repo)
+ local input_order = get_input_order()
+ local readers = make_readers(feature, layer_repo)
+ network:init(1)
+
+ local iterative_trainer = function()
+ local data = nil
+ for ri = 1, #readers, 1 do
+ data = readers[ri].reader:get_data()
+ if data ~= nil then
+ break
+ end
+ end
+
+ if data == nil then
+ return "", nil
+ end
+
+ local input = {}
+ for i, e in ipairs(input_order) do
+ local id = e.id
+ if data[id] == nil then
+ nerv.error("input data %s not found", id)
+ end
+ local transformed
+ if e.global_transf then
+ local batch = gconf.cumat_type(data[id]:nrow(), data[id]:ncol())
+ batch:copy_fromh(data[id])
+ transformed = nerv.speech_utils.global_transf(batch,
+ global_transf,
+ gconf.frm_ext or 0, 0,
+ gconf)
+ else
+ transformed = data[id]
+ end
+ table.insert(input, transformed)
+ end
+ local output = {nerv.CuMatrixFloat(input[1]:nrow(), network.dim_out[1])}
+ network:batch_resize(input[1]:nrow())
+ network:propagate(input, output)
+
+ local utt = data["key"]
+ if utt == nil then
+ nerv.error("no key found.")
+ end
+
+ local mat = nerv.MMatrixFloat(output[1]:nrow(), output[1]:ncol())
+ output[1]:copy_toh(mat)
+
+ collectgarbage("collect")
+ return utt, mat
+ end
+
+ return iterative_trainer
+end
+
+function init(config, feature)
+ local tmp = io.write
+ io.write = function(...)
+ end
+ dofile(config)
+ trainer = build_trainer(gconf.decode_param, feature)
+ io.write = tmp
+end
+
+function feed()
+ local utt, mat = trainer()
+ return utt, mat
+end