summaryrefslogtreecommitdiff
path: root/kaldi_decode
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_decode')
-rw-r--r--kaldi_decode/src/asr_propagator.lua132
1 files changed, 90 insertions, 42 deletions
diff --git a/kaldi_decode/src/asr_propagator.lua b/kaldi_decode/src/asr_propagator.lua
index ff9b8a2..a3c5eb1 100644
--- a/kaldi_decode/src/asr_propagator.lua
+++ b/kaldi_decode/src/asr_propagator.lua
@@ -22,64 +22,111 @@ function build_propagator(ifname, feature)
local network = get_decode_network(layer_repo)
local global_transf = get_global_transf(layer_repo)
local input_order = get_decode_input_order()
+ local input_name = gconf.decode_input_name or "main_scp"
local readers = make_decode_readers(feature, layer_repo)
+ --nerv.info("prepare")
+ local buffer = nerv.SeqBuffer(gconf, {
+ buffer_size = gconf.buffer_size,
+ batch_size = gconf.batch_size,
+ chunk_size = gconf.chunk_size,
+ randomize = gconf.randomize,
+ readers = readers,
+ use_gpu = true
+ })
network = nerv.Network("nt", gconf, {network = network})
+ network:init(gconf.batch_size, gconf.chunk_size)
global_transf = nerv.Network("gt", gconf, {network = global_transf})
+ global_transf:init(gconf.batch_size, gconf.chunk_size)
- local batch_propagator = function()
- local data = nil
- for ri = 1, #readers do
- data = readers[ri].reader:get_data()
- if data ~= nil then
- break
- end
+ local prev_data = buffer:get_data() or nerv.error("no data in buffer")
+ local terminate = false
+ local input_pos = nil
+ for i, v in ipairs(input_order) do
+ if v.id == input_name then
+ input_pos = i
end
-
- if data == nil then
+ end
+ if input_pos == nil then
+ nerv.error("input name %s not found in the input order list", input_name)
+ end
+ local batch_propagator = function()
+ if terminate then
return "", nil
end
-
- gconf.batch_size = data[input_order[1].id]:nrow()
- global_transf:init(gconf.batch_size, 1)
global_transf:epoch_init()
- network:init(gconf.batch_size, 1)
network:epoch_init()
+ local accu_output = {}
+ local utt_id = readers[input_pos].reader.key
+ if utt_id == nil then
+ nerv.error("no key found.")
+ end
+ while true do
+ local d
+ if prev_data then
+ d = prev_data
+ prev_data = nil
+ else
+ d = buffer:get_data()
+ if d == nil then
+ terminate = true
+ break
+ elseif #d.new_seq > 0 then
+ prev_data = d -- the first data of the next utterance
+ break
+ end
+ end
- local input = {}
- for i, e in ipairs(input_order) do
- local id = e.id
- if data[id] == nil then
- nerv.error("input data %s not found", id)
+ local input = {}
+ local output = {{}}
+ for i, e in ipairs(input_order) do
+ local id = e.id
+ if d.data[id] == nil then
+ nerv.error("input data %s not found", id)
+ end
+ local transformed = {}
+ if e.global_transf then
+ for _, mini_batch in ipairs(d.data[id]) do
+ table.insert(transformed,
+ nerv.speech_utils.global_transf(mini_batch,
+ global_transf,
+ gconf.frm_ext or 0, 0,
+ gconf))
+ end
+ else
+ transformed = d.data[id]
+ end
+ table.insert(input, transformed)
+ for i = 1, gconf.chunk_size do
+ table.insert(output[1], gconf.mmat_type(gconf.batch_size, network.dim_out[1]))
+ end
end
- local transformed
- if e.global_transf then
- transformed = nerv.speech_utils.global_transf(data[id],
- global_transf,
- gconf.frm_ext or 0, 0,
- gconf)
- else
- transformed = data[id]
+ --nerv.info("input num: %d\nmat: %s\n", #input[1], input[1][1])
+ --nerv.info("output num: %d\nmat: %s\n", #output[1], output[1][1])
+ network:mini_batch_init({seq_length = d.seq_length,
+ new_seq = d.new_seq,
+ do_train = false,
+ input = input,
+ output = output,
+ err_input = {},
+ err_output = {}})
+ network:propagate()
+ for i, v in ipairs(output[1]) do
+ --nerv.info(gconf.mask[i])
+ if gconf.mask[i][0][0] > 0 then -- is not a hole
+ table.insert(accu_output, v)
+ --nerv.info("input: %s\noutput: %s\n", input[1][i], output[1][i])
+ end
end
- table.insert(input, {transformed})
end
- local output = {{nerv.MMatrixFloat(input[1][1]:nrow(), network.dim_out[1])}}
- network:mini_batch_init({seq_length = table.vector(gconf.batch_size, 1),
- new_seq = {},
- do_train = false,
- input = input,
- output = output,
- err_input = {},
- err_output = {}})
- network:propagate()
-
- local utt = data["key"]
- if utt == nil then
- nerv.error("no key found.")
+ local utt_matrix = gconf.mmat_type(#accu_output, accu_output[1]:ncol())
+ for i, v in ipairs(accu_output) do
+ utt_matrix:copy_from(v, 0, v:nrow(), i - 1)
end
-
+ --nerv.info(utt_matrix)
collectgarbage("collect")
- return utt, output[1][1]
+ nerv.info("propagated %d features of an utterance", utt_matrix:nrow())
+ return utt_id, utt_matrix
end
return batch_propagator
@@ -89,6 +136,7 @@ function init(config, feature)
dofile(config)
gconf.mmat_type = nerv.MMatrixFloat
gconf.use_cpu = true -- use CPU to decode
+ gconf.batch_size = 1
trainer = build_propagator(gconf.decode_param, feature)
end