diff options
Diffstat (limited to 'kaldi_decode/src/nerv4decode.lua')
-rw-r--r-- | kaldi_decode/src/nerv4decode.lua | 79 |
1 files changed, 79 insertions, 0 deletions
diff --git a/kaldi_decode/src/nerv4decode.lua b/kaldi_decode/src/nerv4decode.lua new file mode 100644 index 0000000..b2ff344 --- /dev/null +++ b/kaldi_decode/src/nerv4decode.lua @@ -0,0 +1,79 @@ +package.path="/home/slhome/ymz09/.luarocks/share/lua/5.1/?.lua;/home/slhome/ymz09/.luarocks/share/lua/5.1/?/init.lua;/slfs6/users/ymz09/nerv-project/nerv/install/share/lua/5.1/?.lua;/slfs6/users/ymz09/nerv-project/nerv/install/share/lua/5.1/?/init.lua;"..package.path; +package.cpath="/home/slhome/ymz09/.luarocks/lib/lua/5.1/?.so;/slfs6/users/ymz09/nerv-project/nerv/install/lib/lua/5.1/?.so;"..package.cpath; +local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1") +require 'nerv' + +function build_trainer(ifname, feature) + local param_repo = nerv.ParamRepo() + param_repo:import(ifname, nil, gconf) + local layer_repo = make_layer_repo(param_repo) + local network = get_decode_network(layer_repo) + local global_transf = get_global_transf(layer_repo) + local input_order = get_input_order() + local readers = make_readers(feature, layer_repo) + network:init(1) + + local iterative_trainer = function() + local data = nil + for ri = 1, #readers, 1 do + data = readers[ri].reader:get_data() + if data ~= nil then + break + end + end + + if data == nil then + return "", nil + end + + local input = {} + for i, e in ipairs(input_order) do + local id = e.id + if data[id] == nil then + nerv.error("input data %s not found", id) + end + local transformed + if e.global_transf then + local batch = gconf.cumat_type(data[id]:nrow(), data[id]:ncol()) + batch:copy_fromh(data[id]) + transformed = nerv.speech_utils.global_transf(batch, + global_transf, + gconf.frm_ext or 0, 0, + gconf) + else + transformed = data[id] + end + table.insert(input, transformed) + end + local output = {nerv.CuMatrixFloat(input[1]:nrow(), network.dim_out[1])} + network:batch_resize(input[1]:nrow()) + network:propagate(input, output) + + local utt = data["key"] + if utt == nil then + nerv.error("no key found.") + end + + local mat = nerv.MMatrixFloat(output[1]:nrow(), output[1]:ncol()) + output[1]:copy_toh(mat) + + collectgarbage("collect") + return utt, mat + end + + return iterative_trainer +end + +function init(config, feature) + local tmp = io.write + io.write = function(...) + end + dofile(config) + trainer = build_trainer(gconf.decode_param, feature) + io.write = tmp +end + +function feed() + local utt, mat = trainer() + return utt, mat +end |