summaryrefslogtreecommitdiff
path: root/kaldi_decode/src/nerv4decode.lua
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_decode/src/nerv4decode.lua')
-rw-r--r--kaldi_decode/src/nerv4decode.lua86
1 files changed, 0 insertions, 86 deletions
diff --git a/kaldi_decode/src/nerv4decode.lua b/kaldi_decode/src/nerv4decode.lua
deleted file mode 100644
index 898b5a8..0000000
--- a/kaldi_decode/src/nerv4decode.lua
+++ /dev/null
@@ -1,86 +0,0 @@
-print = function(...) io.write(table.concat({...}, "\t")) end
-io.output('/dev/null')
--- path and cpath are correctly set by `path.sh`
-local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1")
-require 'nerv'
-nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n")
-nerv.info("automatically initialize a default MContext...")
-nerv.MMatrix._default_context = nerv.MContext()
-nerv.info("the default MContext is ok")
--- only for backward compatibilty, will be removed in the future
-local function _add_profile_method(cls)
- local c = cls._default_context
- cls.print_profile = function () c:print_profile() end
- cls.clear_profile = function () c:clear_profile() end
-end
-_add_profile_method(nerv.MMatrix)
-
-function build_trainer(ifname, feature)
- local param_repo = nerv.ParamRepo()
- param_repo:import(ifname, nil, gconf)
- local layer_repo = make_layer_repo(param_repo)
- local network = get_decode_network(layer_repo)
- local global_transf = get_global_transf(layer_repo)
- local input_order = get_input_order()
- local readers = make_readers(feature, layer_repo)
- network:init(1)
-
- local iterative_trainer = function()
- local data = nil
- for ri = 1, #readers, 1 do
- data = readers[ri].reader:get_data()
- if data ~= nil then
- break
- end
- end
-
- if data == nil then
- return "", nil
- end
-
- local input = {}
- for i, e in ipairs(input_order) do
- local id = e.id
- if data[id] == nil then
- nerv.error("input data %s not found", id)
- end
- local transformed
- if e.global_transf then
- transformed = nerv.speech_utils.global_transf(data[id],
- global_transf,
- gconf.frm_ext or 0, 0,
- gconf)
- else
- transformed = data[id]
- end
- table.insert(input, transformed)
- end
- local output = {nerv.MMatrixFloat(input[1]:nrow(), network.dim_out[1])}
- network:batch_resize(input[1]:nrow())
- network:propagate(input, output)
-
- local utt = data["key"]
- if utt == nil then
- nerv.error("no key found.")
- end
-
- local mat = nerv.MMatrixFloat(output[1]:nrow(), output[1]:ncol())
- output[1]:copy_toh(mat)
-
- collectgarbage("collect")
- return utt, mat
- end
-
- return iterative_trainer
-end
-
-function init(config, feature)
- dofile(config)
- gconf.use_cpu = true -- use CPU to decode
- trainer = build_trainer(gconf.decode_param, feature)
-end
-
-function feed()
- local utt, mat = trainer()
- return utt, mat
-end