diff options
Diffstat (limited to 'kaldi_decode/src/nerv4decode.lua')
-rw-r--r-- | kaldi_decode/src/nerv4decode.lua | 79 |
1 files changed, 0 insertions, 79 deletions
diff --git a/kaldi_decode/src/nerv4decode.lua b/kaldi_decode/src/nerv4decode.lua deleted file mode 100644 index b2ff344..0000000 --- a/kaldi_decode/src/nerv4decode.lua +++ /dev/null @@ -1,79 +0,0 @@ -package.path="/home/slhome/ymz09/.luarocks/share/lua/5.1/?.lua;/home/slhome/ymz09/.luarocks/share/lua/5.1/?/init.lua;/slfs6/users/ymz09/nerv-project/nerv/install/share/lua/5.1/?.lua;/slfs6/users/ymz09/nerv-project/nerv/install/share/lua/5.1/?/init.lua;"..package.path; -package.cpath="/home/slhome/ymz09/.luarocks/lib/lua/5.1/?.so;/slfs6/users/ymz09/nerv-project/nerv/install/lib/lua/5.1/?.so;"..package.cpath; -local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1") -require 'nerv' - -function build_trainer(ifname, feature) - local param_repo = nerv.ParamRepo() - param_repo:import(ifname, nil, gconf) - local layer_repo = make_layer_repo(param_repo) - local network = get_decode_network(layer_repo) - local global_transf = get_global_transf(layer_repo) - local input_order = get_input_order() - local readers = make_readers(feature, layer_repo) - network:init(1) - - local iterative_trainer = function() - local data = nil - for ri = 1, #readers, 1 do - data = readers[ri].reader:get_data() - if data ~= nil then - break - end - end - - if data == nil then - return "", nil - end - - local input = {} - for i, e in ipairs(input_order) do - local id = e.id - if data[id] == nil then - nerv.error("input data %s not found", id) - end - local transformed - if e.global_transf then - local batch = gconf.cumat_type(data[id]:nrow(), data[id]:ncol()) - batch:copy_fromh(data[id]) - transformed = nerv.speech_utils.global_transf(batch, - global_transf, - gconf.frm_ext or 0, 0, - gconf) - else - transformed = data[id] - end - table.insert(input, transformed) - end - local output = {nerv.CuMatrixFloat(input[1]:nrow(), network.dim_out[1])} - network:batch_resize(input[1]:nrow()) - network:propagate(input, output) - - local utt = data["key"] - if utt == nil then - nerv.error("no key found.") - end - - local mat = nerv.MMatrixFloat(output[1]:nrow(), output[1]:ncol()) - output[1]:copy_toh(mat) - - collectgarbage("collect") - return utt, mat - end - - return iterative_trainer -end - -function init(config, feature) - local tmp = io.write - io.write = function(...) - end - dofile(config) - trainer = build_trainer(gconf.decode_param, feature) - io.write = tmp -end - -function feed() - local utt, mat = trainer() - return utt, mat -end |