summaryrefslogtreecommitdiff
path: root/kaldi_decode/src/asr_propagator.lua
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2016-04-04 01:17:49 +0800
committerDeterminant <[email protected]>2016-04-04 01:17:49 +0800
commite6d9e562fa42ddafac601be12da4c4faee85dd4d (patch)
tree3755f6c38bdf4e2247b2448d1c69e67918b13b48 /kaldi_decode/src/asr_propagator.lua
parentdca8f2e7373cb12216a50108ead3a9ed10c4e49b (diff)
catch up to the latest interface of nerv.Networkalpha-3.4
Diffstat (limited to 'kaldi_decode/src/asr_propagator.lua')
-rw-r--r--kaldi_decode/src/asr_propagator.lua10
1 files changed, 5 insertions, 5 deletions
diff --git a/kaldi_decode/src/asr_propagator.lua b/kaldi_decode/src/asr_propagator.lua
index 6a95647..ff9b8a2 100644
--- a/kaldi_decode/src/asr_propagator.lua
+++ b/kaldi_decode/src/asr_propagator.lua
@@ -61,14 +61,14 @@ function build_propagator(ifname, feature)
else
transformed = data[id]
end
- table.insert(input, transformed)
+ table.insert(input, {transformed})
end
- local output = {nerv.MMatrixFloat(input[1]:nrow(), network.dim_out[1])}
+ local output = {{nerv.MMatrixFloat(input[1][1]:nrow(), network.dim_out[1])}}
network:mini_batch_init({seq_length = table.vector(gconf.batch_size, 1),
new_seq = {},
do_train = false,
- input = {input},
- output = {output},
+ input = input,
+ output = output,
err_input = {},
err_output = {}})
network:propagate()
@@ -79,7 +79,7 @@ function build_propagator(ifname, feature)
end
collectgarbage("collect")
- return utt, output[1]
+ return utt, output[1][1]
end
return batch_propagator