diff options
author | Determinant <[email protected]> | 2016-02-29 20:03:52 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2016-02-29 20:03:52 +0800 |
commit | 1e0ac0fb5c9f517e7325deb16004de1054454da7 (patch) | |
tree | c75a6f0fc9aa50caa9fb9dccec7a56b41d3b63fd /kaldi_decode/src/nerv4decode.lua | |
parent | fda1c8cf07c5130aff53775454a5f2cfc8f5d2e0 (diff) |
refactor kaldi_decode
Diffstat (limited to 'kaldi_decode/src/nerv4decode.lua')
-rw-r--r-- | kaldi_decode/src/nerv4decode.lua | 86 |
1 files changed, 0 insertions, 86 deletions
diff --git a/kaldi_decode/src/nerv4decode.lua b/kaldi_decode/src/nerv4decode.lua deleted file mode 100644 index 898b5a8..0000000 --- a/kaldi_decode/src/nerv4decode.lua +++ /dev/null @@ -1,86 +0,0 @@ -print = function(...) io.write(table.concat({...}, "\t")) end -io.output('/dev/null') --- path and cpath are correctly set by `path.sh` -local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1") -require 'nerv' -nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n") -nerv.info("automatically initialize a default MContext...") -nerv.MMatrix._default_context = nerv.MContext() -nerv.info("the default MContext is ok") --- only for backward compatibilty, will be removed in the future -local function _add_profile_method(cls) - local c = cls._default_context - cls.print_profile = function () c:print_profile() end - cls.clear_profile = function () c:clear_profile() end -end -_add_profile_method(nerv.MMatrix) - -function build_trainer(ifname, feature) - local param_repo = nerv.ParamRepo() - param_repo:import(ifname, nil, gconf) - local layer_repo = make_layer_repo(param_repo) - local network = get_decode_network(layer_repo) - local global_transf = get_global_transf(layer_repo) - local input_order = get_input_order() - local readers = make_readers(feature, layer_repo) - network:init(1) - - local iterative_trainer = function() - local data = nil - for ri = 1, #readers, 1 do - data = readers[ri].reader:get_data() - if data ~= nil then - break - end - end - - if data == nil then - return "", nil - end - - local input = {} - for i, e in ipairs(input_order) do - local id = e.id - if data[id] == nil then - nerv.error("input data %s not found", id) - end - local transformed - if e.global_transf then - transformed = nerv.speech_utils.global_transf(data[id], - global_transf, - gconf.frm_ext or 0, 0, - gconf) - else - transformed = data[id] - end - table.insert(input, transformed) - end - local output = {nerv.MMatrixFloat(input[1]:nrow(), network.dim_out[1])} - network:batch_resize(input[1]:nrow()) - network:propagate(input, output) - - local utt = data["key"] - if utt == nil then - nerv.error("no key found.") - end - - local mat = nerv.MMatrixFloat(output[1]:nrow(), output[1]:ncol()) - output[1]:copy_toh(mat) - - collectgarbage("collect") - return utt, mat - end - - return iterative_trainer -end - -function init(config, feature) - dofile(config) - gconf.use_cpu = true -- use CPU to decode - trainer = build_trainer(gconf.decode_param, feature) -end - -function feed() - local utt, mat = trainer() - return utt, mat -end |