summaryrefslogtreecommitdiff
path: root/kaldi_decode
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_decode')
-rwxr-xr-xkaldi_decode/README.timit4
-rw-r--r--kaldi_decode/src/asr_propagator.lua18
2 files changed, 18 insertions, 4 deletions
diff --git a/kaldi_decode/README.timit b/kaldi_decode/README.timit
index 0a3e33a..4c4e310 100755
--- a/kaldi_decode/README.timit
+++ b/kaldi_decode/README.timit
@@ -5,8 +5,8 @@ source cmd.sh
gmmdir=/speechlab/users/mfy43/timit/s5/exp/tri3/
data_fmllr=/speechlab/users/mfy43/timit/s5/data-fmllr-tri3/
dir=/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_dnn/
-nerv_config=/speechlab/users/mfy43/nerv/nerv/examples/timit_baseline2.lua
-decode=/speechlab/users/mfy43/nerv/install/bin/decode_with_nerv.sh
+nerv_config=/speechlab/users/mfy43/timit/s5/timit_baseline2.lua
+decode=/speechlab/users/mfy43/timit/s5/nerv/install/bin/decode_with_nerv.sh
# Decode (reuse HCLG graph)
$decode --nj 20 --cmd "$decode_cmd" --acwt 0.2 \
diff --git a/kaldi_decode/src/asr_propagator.lua b/kaldi_decode/src/asr_propagator.lua
index 5d0ad7c..4005875 100644
--- a/kaldi_decode/src/asr_propagator.lua
+++ b/kaldi_decode/src/asr_propagator.lua
@@ -24,6 +24,9 @@ function build_propagator(ifname, feature)
local input_order = get_decode_input_order()
local readers = make_decode_readers(feature, layer_repo)
+ network = nerv.Network("nt", gconf, {network = network})
+ global_transf = nerv.Network("gt", gconf, {network = global_transf})
+
local batch_propagator = function()
local data = nil
for ri = 1, #readers do
@@ -38,7 +41,10 @@ function build_propagator(ifname, feature)
end
gconf.batch_size = data[input_order[1].id]:nrow()
- network:init(gconf.batch_size)
+ global_transf:init(gconf.batch_size, 1)
+ global_transf:epoch_init()
+ network:init(gconf.batch_size, 1)
+ network:epoch_init()
local input = {}
for i, e in ipairs(input_order) do
@@ -58,7 +64,14 @@ function build_propagator(ifname, feature)
table.insert(input, transformed)
end
local output = {nerv.MMatrixFloat(input[1]:nrow(), network.dim_out[1])}
- network:propagate(input, output)
+ network:mini_batch_init({seq_length = table.vector(gconf.batch_size, 1),
+ new_seq = {},
+ do_train = false,
+ input = {input},
+ output = {output},
+ err_input = {},
+ err_output = {}})
+ network:propagate()
local utt = data["key"]
if utt == nil then
@@ -74,6 +87,7 @@ end
function init(config, feature)
dofile(config)
+ gconf.mmat_type = nerv.MMatrixFloat
gconf.use_cpu = true -- use CPU to decode
trainer = build_propagator(gconf.decode_param, feature)
end