summaryrefslogtreecommitdiff
path: root/kaldi_decode/src
diff options
context:
space:
mode:
authorDeterminant <ted.sybil@gmail.com>2016-02-29 20:03:52 +0800
committerDeterminant <ted.sybil@gmail.com>2016-02-29 20:03:52 +0800
commit1e0ac0fb5c9f517e7325deb16004de1054454da7 (patch)
treec75a6f0fc9aa50caa9fb9dccec7a56b41d3b63fd /kaldi_decode/src
parentfda1c8cf07c5130aff53775454a5f2cfc8f5d2e0 (diff)
refactor kaldi_decode
Diffstat (limited to 'kaldi_decode/src')
-rw-r--r--kaldi_decode/src/Makefile12
-rw-r--r--kaldi_decode/src/asr_propagator.lua (renamed from kaldi_decode/src/nerv4decode.lua)30
-rw-r--r--kaldi_decode/src/nnet-forward.cc12
3 files changed, 20 insertions, 34 deletions
diff --git a/kaldi_decode/src/Makefile b/kaldi_decode/src/Makefile
deleted file mode 100644
index 0897798..0000000
--- a/kaldi_decode/src/Makefile
+++ /dev/null
@@ -1,12 +0,0 @@
-# Change KDIR to `kaldi-trunk' path (Kaldi must be compiled with --share)
-KDIR := /speechlab/tools/KALDI/kaldi-master/
-NERVDIR := /speechlab/users/mfy43/nerv/
-CUDADIR := /usr/local/cuda/
-
-nnet-forward:
- g++ -msse -msse2 -Wall -I $(KDIR)/src/ -pthread -DKALDI_DOUBLEPRECISION=0 -DHAVE_POSIX_MEMALIGN -Wno-sign-compare -Wno-unused-local-typedefs -Winit-self -DHAVE_EXECINFO_H=1 -rdynamic -DHAVE_CXXABI_H -DHAVE_ATLAS -I $(KDIR)/tools/ATLAS/include -I $(KDIR)/tools/openfst/include -Wno-sign-compare -g -fPIC -DHAVE_CUDA -I $(CUDADIR)/include -DKALDI_NO_EXPF -I $(NERVDIR)/install//include/luajit-2.0/ -I $(NERVDIR)/install/include/ -DLUA_USE_APICHECK -c -o nnet-forward.o nnet-forward.cc
- g++ -rdynamic -Wl,-rpath=$(KDIR)/tools/openfst/lib -L$(CUDADIR)/lib64 -Wl,-rpath,$(CUDADIR)/lib64 -Wl,-rpath=$(KDIR)/src/lib -L. -L$(KDIR)/src/nnet/ -L$(KDIR)/src/cudamatrix/ -L$(KDIR)/src/lat/ -L$(KDIR)/src/hmm/ -L$(KDIR)/src/tree/ -L$(KDIR)/src/matrix/ -L$(KDIR)/src/util/ -L$(KDIR)/src/base/ nnet-forward.o $(KDIR)/src/nnet//libkaldi-nnet.so $(KDIR)/src/cudamatrix//libkaldi-cudamatrix.so $(KDIR)/src/lat//libkaldi-lat.so $(KDIR)/src/hmm//libkaldi-hmm.so $(KDIR)/src/tree//libkaldi-tree.so $(KDIR)/src/matrix//libkaldi-matrix.so $(KDIR)/src/util//libkaldi-util.so $(KDIR)/src/base//libkaldi-base.so -L$(KDIR)/tools/openfst/lib -lfst -lm -lpthread -ldl -lkaldi-nnet -lkaldi-cudamatrix -lkaldi-lat -lkaldi-hmm -lkaldi-tree -lkaldi-matrix -lkaldi-util -lkaldi-base -lstdc++ -L$(NERVDIR)/install/lib -Wl,-rpath=$(NERVDIR)/install/lib -lnervcore -lluaT -rdynamic -Wl,-rpath=$(KDIR)//tools/openfst/lib -L$(DUDADIR)/lib64 -Wl,-rpath,$(CUDADIR)/lib64 -Wl,-rpath=$(KDIR)//src/lib -lfst -lm -lpthread -ldl -L $(NERVDIR)/luajit-2.0/src/ -lluajit -o nnet-forward -L/home/intel/mkl/lib/intel64/ -Wl,-rpath=/home/intel/mkl/lib/intel64/ -lmkl_rt
-
-clean:
- -rm nnet-forward.o nnet-forward
-
diff --git a/kaldi_decode/src/nerv4decode.lua b/kaldi_decode/src/asr_propagator.lua
index 898b5a8..5d0ad7c 100644
--- a/kaldi_decode/src/nerv4decode.lua
+++ b/kaldi_decode/src/asr_propagator.lua
@@ -15,19 +15,18 @@ local function _add_profile_method(cls)
end
_add_profile_method(nerv.MMatrix)
-function build_trainer(ifname, feature)
+function build_propagator(ifname, feature)
local param_repo = nerv.ParamRepo()
param_repo:import(ifname, nil, gconf)
local layer_repo = make_layer_repo(param_repo)
local network = get_decode_network(layer_repo)
local global_transf = get_global_transf(layer_repo)
- local input_order = get_input_order()
- local readers = make_readers(feature, layer_repo)
- network:init(1)
+ local input_order = get_decode_input_order()
+ local readers = make_decode_readers(feature, layer_repo)
- local iterative_trainer = function()
+ local batch_propagator = function()
local data = nil
- for ri = 1, #readers, 1 do
+ for ri = 1, #readers do
data = readers[ri].reader:get_data()
if data ~= nil then
break
@@ -38,6 +37,9 @@ function build_trainer(ifname, feature)
return "", nil
end
+ gconf.batch_size = data[input_order[1].id]:nrow()
+ network:init(gconf.batch_size)
+
local input = {}
for i, e in ipairs(input_order) do
local id = e.id
@@ -47,16 +49,15 @@ function build_trainer(ifname, feature)
local transformed
if e.global_transf then
transformed = nerv.speech_utils.global_transf(data[id],
- global_transf,
- gconf.frm_ext or 0, 0,
- gconf)
+ global_transf,
+ gconf.frm_ext or 0, 0,
+ gconf)
else
transformed = data[id]
end
table.insert(input, transformed)
end
local output = {nerv.MMatrixFloat(input[1]:nrow(), network.dim_out[1])}
- network:batch_resize(input[1]:nrow())
network:propagate(input, output)
local utt = data["key"]
@@ -64,20 +65,17 @@ function build_trainer(ifname, feature)
nerv.error("no key found.")
end
- local mat = nerv.MMatrixFloat(output[1]:nrow(), output[1]:ncol())
- output[1]:copy_toh(mat)
-
collectgarbage("collect")
- return utt, mat
+ return utt, output[1]
end
- return iterative_trainer
+ return batch_propagator
end
function init(config, feature)
dofile(config)
gconf.use_cpu = true -- use CPU to decode
- trainer = build_trainer(gconf.decode_param, feature)
+ trainer = build_propagator(gconf.decode_param, feature)
end
function feed()
diff --git a/kaldi_decode/src/nnet-forward.cc b/kaldi_decode/src/nnet-forward.cc
index 4911791..8781705 100644
--- a/kaldi_decode/src/nnet-forward.cc
+++ b/kaldi_decode/src/nnet-forward.cc
@@ -46,7 +46,7 @@ int main(int argc, char *argv[]) {
const char *usage =
"Perform forward pass through Neural Network.\n"
"\n"
- "Usage: nnet-forward [options] <nerv-config> <feature-rspecifier> <feature-wspecifier> [nerv4decode.lua]\n"
+ "Usage: nnet-forward [options] <nerv-config> <feature-rspecifier> <feature-wspecifier> [asr_propagator.lua]\n"
"e.g.: \n"
" nnet-forward config.lua ark:features.ark ark:mlpoutput.ark\n";
@@ -78,9 +78,9 @@ int main(int argc, char *argv[]) {
std::string config = po.GetArg(1),
feature_rspecifier = po.GetArg(2),
feature_wspecifier = po.GetArg(3),
- nerv4decode = "src/nerv4decode.lua";
- if(po.NumArgs() >= 4)
- nerv4decode = po.GetArg(4);
+ propagator = "src/asr_propagator.lua";
+ if(po.NumArgs() >= 4)
+ propagator = po.GetArg(4);
//Select the GPU
#if HAVE_CUDA==1
@@ -99,8 +99,8 @@ int main(int argc, char *argv[]) {
lua_State *L = lua_open();
luaL_openlibs(L);
- if(luaL_loadfile(L, nerv4decode.c_str()))
- KALDI_ERR << "luaL_loadfile() " << nerv4decode << " failed " << lua_tostring(L, -1);
+ if(luaL_loadfile(L, propagator.c_str()))
+ KALDI_ERR << "luaL_loadfile() " << propagator << " failed " << lua_tostring(L, -1);
if(lua_pcall(L, 0, 0, 0))
KALDI_ERR << "lua_pall failed " << lua_tostring(L, -1);