diff options
Diffstat (limited to 'kaldi_decode/src/nerv4decode.lua')
-rw-r--r-- | kaldi_decode/src/nerv4decode.lua | 27 |
1 file changed, 17 insertions, 10 deletions
diff --git a/kaldi_decode/src/nerv4decode.lua b/kaldi_decode/src/nerv4decode.lua index b2ff344..898b5a8 100644 --- a/kaldi_decode/src/nerv4decode.lua +++ b/kaldi_decode/src/nerv4decode.lua @@ -1,7 +1,19 @@ -package.path="/home/slhome/ymz09/.luarocks/share/lua/5.1/?.lua;/home/slhome/ymz09/.luarocks/share/lua/5.1/?/init.lua;/slfs6/users/ymz09/nerv-project/nerv/install/share/lua/5.1/?.lua;/slfs6/users/ymz09/nerv-project/nerv/install/share/lua/5.1/?/init.lua;"..package.path; -package.cpath="/home/slhome/ymz09/.luarocks/lib/lua/5.1/?.so;/slfs6/users/ymz09/nerv-project/nerv/install/lib/lua/5.1/?.so;"..package.cpath; +print = function(...) io.write(table.concat({...}, "\t")) end +io.output('/dev/null') +-- path and cpath are correctly set by `path.sh` local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1") require 'nerv' +nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n") +nerv.info("automatically initialize a default MContext...") +nerv.MMatrix._default_context = nerv.MContext() +nerv.info("the default MContext is ok") +-- only for backward compatibility, will be removed in the future +local function _add_profile_method(cls) + local c = cls._default_context + cls.print_profile = function () c:print_profile() end + cls.clear_profile = function () c:clear_profile() end +end +_add_profile_method(nerv.MMatrix) function build_trainer(ifname, feature) local param_repo = nerv.ParamRepo() @@ -34,9 +46,7 @@ function build_trainer(ifname, feature) end local transformed if e.global_transf then - local batch = gconf.cumat_type(data[id]:nrow(), data[id]:ncol()) - batch:copy_fromh(data[id]) - transformed = nerv.speech_utils.global_transf(batch, + transformed = nerv.speech_utils.global_transf(data[id], global_transf, gconf.frm_ext or 0, 0, gconf) @@ -45,7 +55,7 @@ function build_trainer(ifname, feature) end table.insert(input, transformed) end - local output = {nerv.CuMatrixFloat(input[1]:nrow(), network.dim_out[1])} 
+ local output = {nerv.MMatrixFloat(input[1]:nrow(), network.dim_out[1])} network:batch_resize(input[1]:nrow()) network:propagate(input, output) @@ -65,12 +75,9 @@ function build_trainer(ifname, feature) end function init(config, feature) - local tmp = io.write - io.write = function(...) - end dofile(config) + gconf.use_cpu = true -- use CPU to decode trainer = build_trainer(gconf.decode_param, feature) - io.write = tmp end function feed() |