summary refs log tree commit diff
path: root/kaldi_decode/src/asr_propagator.lua
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_decode/src/asr_propagator.lua')
-rw-r--r-- kaldi_decode/src/asr_propagator.lua 18
1 file changed, 16 insertions(+), 2 deletions(-)
diff --git a/kaldi_decode/src/asr_propagator.lua b/kaldi_decode/src/asr_propagator.lua
index 5d0ad7c..4005875 100644
--- a/kaldi_decode/src/asr_propagator.lua
+++ b/kaldi_decode/src/asr_propagator.lua
@@ -24,6 +24,9 @@ function build_propagator(ifname, feature)
local input_order = get_decode_input_order()
local readers = make_decode_readers(feature, layer_repo)
+ network = nerv.Network("nt", gconf, {network = network})
+ global_transf = nerv.Network("gt", gconf, {network = global_transf})
+
local batch_propagator = function()
local data = nil
for ri = 1, #readers do
@@ -38,7 +41,10 @@ function build_propagator(ifname, feature)
end
gconf.batch_size = data[input_order[1].id]:nrow()
- network:init(gconf.batch_size)
+ global_transf:init(gconf.batch_size, 1)
+ global_transf:epoch_init()
+ network:init(gconf.batch_size, 1)
+ network:epoch_init()
local input = {}
for i, e in ipairs(input_order) do
@@ -58,7 +64,14 @@ function build_propagator(ifname, feature)
table.insert(input, transformed)
end
local output = {nerv.MMatrixFloat(input[1]:nrow(), network.dim_out[1])}
- network:propagate(input, output)
+ network:mini_batch_init({seq_length = table.vector(gconf.batch_size, 1),
+ new_seq = {},
+ do_train = false,
+ input = {input},
+ output = {output},
+ err_input = {},
+ err_output = {}})
+ network:propagate()
local utt = data["key"]
if utt == nil then
@@ -74,6 +87,7 @@ end
function init(config, feature)
dofile(config)
+ gconf.mmat_type = nerv.MMatrixFloat
gconf.use_cpu = true -- use CPU to decode
trainer = build_propagator(gconf.decode_param, feature)
end