diff options
author | Determinant <[email protected]> | 2016-02-29 20:03:52 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2016-02-29 20:03:52 +0800 |
commit | 1e0ac0fb5c9f517e7325deb16004de1054454da7 (patch) | |
tree | c75a6f0fc9aa50caa9fb9dccec7a56b41d3b63fd /kaldi_decode/src | |
parent | fda1c8cf07c5130aff53775454a5f2cfc8f5d2e0 (diff) |
refactor kaldi_decode
Diffstat (limited to 'kaldi_decode/src')
-rw-r--r-- | kaldi_decode/src/Makefile | 12 | ||||
-rw-r--r-- | kaldi_decode/src/asr_propagator.lua (renamed from kaldi_decode/src/nerv4decode.lua) | 30 | ||||
-rw-r--r-- | kaldi_decode/src/nnet-forward.cc | 12 |
3 files changed, 20 insertions, 34 deletions
diff --git a/kaldi_decode/src/Makefile b/kaldi_decode/src/Makefile deleted file mode 100644 index 0897798..0000000 --- a/kaldi_decode/src/Makefile +++ /dev/null @@ -1,12 +0,0 @@ -# Change KDIR to `kaldi-trunk' path (Kaldi must be compiled with --share) -KDIR := /speechlab/tools/KALDI/kaldi-master/ -NERVDIR := /speechlab/users/mfy43/nerv/ -CUDADIR := /usr/local/cuda/ - -nnet-forward: - g++ -msse -msse2 -Wall -I $(KDIR)/src/ -pthread -DKALDI_DOUBLEPRECISION=0 -DHAVE_POSIX_MEMALIGN -Wno-sign-compare -Wno-unused-local-typedefs -Winit-self -DHAVE_EXECINFO_H=1 -rdynamic -DHAVE_CXXABI_H -DHAVE_ATLAS -I $(KDIR)/tools/ATLAS/include -I $(KDIR)/tools/openfst/include -Wno-sign-compare -g -fPIC -DHAVE_CUDA -I $(CUDADIR)/include -DKALDI_NO_EXPF -I $(NERVDIR)/install//include/luajit-2.0/ -I $(NERVDIR)/install/include/ -DLUA_USE_APICHECK -c -o nnet-forward.o nnet-forward.cc - g++ -rdynamic -Wl,-rpath=$(KDIR)/tools/openfst/lib -L$(CUDADIR)/lib64 -Wl,-rpath,$(CUDADIR)/lib64 -Wl,-rpath=$(KDIR)/src/lib -L. -L$(KDIR)/src/nnet/ -L$(KDIR)/src/cudamatrix/ -L$(KDIR)/src/lat/ -L$(KDIR)/src/hmm/ -L$(KDIR)/src/tree/ -L$(KDIR)/src/matrix/ -L$(KDIR)/src/util/ -L$(KDIR)/src/base/ nnet-forward.o $(KDIR)/src/nnet//libkaldi-nnet.so $(KDIR)/src/cudamatrix//libkaldi-cudamatrix.so $(KDIR)/src/lat//libkaldi-lat.so $(KDIR)/src/hmm//libkaldi-hmm.so $(KDIR)/src/tree//libkaldi-tree.so $(KDIR)/src/matrix//libkaldi-matrix.so $(KDIR)/src/util//libkaldi-util.so $(KDIR)/src/base//libkaldi-base.so -L$(KDIR)/tools/openfst/lib -lfst -lm -lpthread -ldl -lkaldi-nnet -lkaldi-cudamatrix -lkaldi-lat -lkaldi-hmm -lkaldi-tree -lkaldi-matrix -lkaldi-util -lkaldi-base -lstdc++ -L$(NERVDIR)/install/lib -Wl,-rpath=$(NERVDIR)/install/lib -lnervcore -lluaT -rdynamic -Wl,-rpath=$(KDIR)//tools/openfst/lib -L$(DUDADIR)/lib64 -Wl,-rpath,$(CUDADIR)/lib64 -Wl,-rpath=$(KDIR)//src/lib -lfst -lm -lpthread -ldl -L $(NERVDIR)/luajit-2.0/src/ -lluajit -o nnet-forward -L/home/intel/mkl/lib/intel64/ -Wl,-rpath=/home/intel/mkl/lib/intel64/ -lmkl_rt - -clean: - -rm nnet-forward.o nnet-forward - diff --git a/kaldi_decode/src/nerv4decode.lua b/kaldi_decode/src/asr_propagator.lua index 898b5a8..5d0ad7c 100644 --- a/kaldi_decode/src/nerv4decode.lua +++ b/kaldi_decode/src/asr_propagator.lua @@ -15,19 +15,18 @@ local function _add_profile_method(cls) end _add_profile_method(nerv.MMatrix) -function build_trainer(ifname, feature) +function build_propagator(ifname, feature) local param_repo = nerv.ParamRepo() param_repo:import(ifname, nil, gconf) local layer_repo = make_layer_repo(param_repo) local network = get_decode_network(layer_repo) local global_transf = get_global_transf(layer_repo) - local input_order = get_input_order() - local readers = make_readers(feature, layer_repo) - network:init(1) + local input_order = get_decode_input_order() + local readers = make_decode_readers(feature, layer_repo) - local iterative_trainer = function() + local batch_propagator = function() local data = nil - for ri = 1, #readers, 1 do + for ri = 1, #readers do data = readers[ri].reader:get_data() if data ~= nil then break @@ -38,6 +37,9 @@ function build_trainer(ifname, feature) return "", nil end + gconf.batch_size = data[input_order[1].id]:nrow() + network:init(gconf.batch_size) + local input = {} for i, e in ipairs(input_order) do local id = e.id @@ -47,16 +49,15 @@ function build_trainer(ifname, feature) local transformed if e.global_transf then transformed = nerv.speech_utils.global_transf(data[id], - global_transf, - gconf.frm_ext or 0, 0, - gconf) + global_transf, + gconf.frm_ext or 0, 0, + gconf) else transformed = data[id] end table.insert(input, transformed) end local output = {nerv.MMatrixFloat(input[1]:nrow(), network.dim_out[1])} - network:batch_resize(input[1]:nrow()) network:propagate(input, output) local utt = data["key"] @@ -64,20 +65,17 @@ function build_trainer(ifname, feature) nerv.error("no key found.") end - local mat = nerv.MMatrixFloat(output[1]:nrow(), output[1]:ncol()) - output[1]:copy_toh(mat) - collectgarbage("collect") - return utt, mat + return utt, output[1] end - return iterative_trainer + return batch_propagator end function init(config, feature) dofile(config) gconf.use_cpu = true -- use CPU to decode - trainer = build_trainer(gconf.decode_param, feature) + trainer = build_propagator(gconf.decode_param, feature) end function feed() diff --git a/kaldi_decode/src/nnet-forward.cc b/kaldi_decode/src/nnet-forward.cc index 4911791..8781705 100644 --- a/kaldi_decode/src/nnet-forward.cc +++ b/kaldi_decode/src/nnet-forward.cc @@ -46,7 +46,7 @@ int main(int argc, char *argv[]) { const char *usage = "Perform forward pass through Neural Network.\n" "\n" - "Usage: nnet-forward [options] <nerv-config> <feature-rspecifier> <feature-wspecifier> [nerv4decode.lua]\n" + "Usage: nnet-forward [options] <nerv-config> <feature-rspecifier> <feature-wspecifier> [asr_propagator.lua]\n" "e.g.: \n" " nnet-forward config.lua ark:features.ark ark:mlpoutput.ark\n"; @@ -78,9 +78,9 @@ int main(int argc, char *argv[]) { std::string config = po.GetArg(1), feature_rspecifier = po.GetArg(2), feature_wspecifier = po.GetArg(3), - nerv4decode = "src/nerv4decode.lua"; - if(po.NumArgs() >= 4) - nerv4decode = po.GetArg(4); + propagator = "src/asr_propagator.lua"; + if(po.NumArgs() >= 4) + propagator = po.GetArg(4); //Select the GPU #if HAVE_CUDA==1 @@ -99,8 +99,8 @@ int main(int argc, char *argv[]) { lua_State *L = lua_open(); luaL_openlibs(L); - if(luaL_loadfile(L, nerv4decode.c_str())) - KALDI_ERR << "luaL_loadfile() " << nerv4decode << " failed " << lua_tostring(L, -1); + if(luaL_loadfile(L, propagator.c_str())) + KALDI_ERR << "luaL_loadfile() " << propagator << " failed " << lua_tostring(L, -1); if(lua_pcall(L, 0, 0, 0)) KALDI_ERR << "lua_pall failed " << lua_tostring(L, -1); |