summaryrefslogtreecommitdiff
path: root/kaldi_decode/src/nerv4decode.lua
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_decode/src/nerv4decode.lua')
-rw-r--r--kaldi_decode/src/nerv4decode.lua79
1 files changed, 0 insertions, 79 deletions
diff --git a/kaldi_decode/src/nerv4decode.lua b/kaldi_decode/src/nerv4decode.lua
deleted file mode 100644
index b2ff344..0000000
--- a/kaldi_decode/src/nerv4decode.lua
+++ /dev/null
@@ -1,79 +0,0 @@
-package.path="/home/slhome/ymz09/.luarocks/share/lua/5.1/?.lua;/home/slhome/ymz09/.luarocks/share/lua/5.1/?/init.lua;/slfs6/users/ymz09/nerv-project/nerv/install/share/lua/5.1/?.lua;/slfs6/users/ymz09/nerv-project/nerv/install/share/lua/5.1/?/init.lua;"..package.path;
-package.cpath="/home/slhome/ymz09/.luarocks/lib/lua/5.1/?.so;/slfs6/users/ymz09/nerv-project/nerv/install/lib/lua/5.1/?.so;"..package.cpath;
-local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1")
-require 'nerv'
-
-function build_trainer(ifname, feature)
- local param_repo = nerv.ParamRepo()
- param_repo:import(ifname, nil, gconf)
- local layer_repo = make_layer_repo(param_repo)
- local network = get_decode_network(layer_repo)
- local global_transf = get_global_transf(layer_repo)
- local input_order = get_input_order()
- local readers = make_readers(feature, layer_repo)
- network:init(1)
-
- local iterative_trainer = function()
- local data = nil
- for ri = 1, #readers, 1 do
- data = readers[ri].reader:get_data()
- if data ~= nil then
- break
- end
- end
-
- if data == nil then
- return "", nil
- end
-
- local input = {}
- for i, e in ipairs(input_order) do
- local id = e.id
- if data[id] == nil then
- nerv.error("input data %s not found", id)
- end
- local transformed
- if e.global_transf then
- local batch = gconf.cumat_type(data[id]:nrow(), data[id]:ncol())
- batch:copy_fromh(data[id])
- transformed = nerv.speech_utils.global_transf(batch,
- global_transf,
- gconf.frm_ext or 0, 0,
- gconf)
- else
- transformed = data[id]
- end
- table.insert(input, transformed)
- end
- local output = {nerv.CuMatrixFloat(input[1]:nrow(), network.dim_out[1])}
- network:batch_resize(input[1]:nrow())
- network:propagate(input, output)
-
- local utt = data["key"]
- if utt == nil then
- nerv.error("no key found.")
- end
-
- local mat = nerv.MMatrixFloat(output[1]:nrow(), output[1]:ncol())
- output[1]:copy_toh(mat)
-
- collectgarbage("collect")
- return utt, mat
- end
-
- return iterative_trainer
-end
-
-function init(config, feature)
- local tmp = io.write
- io.write = function(...)
- end
- dofile(config)
- trainer = build_trainer(gconf.decode_param, feature)
- io.write = tmp
-end
-
-function feed()
- local utt, mat = trainer()
- return utt, mat
-end