diff options
Diffstat (limited to 'kaldi_decode/src')
-rw-r--r-- | kaldi_decode/src/asr_propagator.lua | 132 |
1 files changed, 90 insertions, 42 deletions
diff --git a/kaldi_decode/src/asr_propagator.lua b/kaldi_decode/src/asr_propagator.lua index ff9b8a2..a3c5eb1 100644 --- a/kaldi_decode/src/asr_propagator.lua +++ b/kaldi_decode/src/asr_propagator.lua @@ -22,64 +22,111 @@ function build_propagator(ifname, feature) local network = get_decode_network(layer_repo) local global_transf = get_global_transf(layer_repo) local input_order = get_decode_input_order() + local input_name = gconf.decode_input_name or "main_scp" local readers = make_decode_readers(feature, layer_repo) + --nerv.info("prepare") + local buffer = nerv.SeqBuffer(gconf, { + buffer_size = gconf.buffer_size, + batch_size = gconf.batch_size, + chunk_size = gconf.chunk_size, + randomize = gconf.randomize, + readers = readers, + use_gpu = true + }) network = nerv.Network("nt", gconf, {network = network}) + network:init(gconf.batch_size, gconf.chunk_size) global_transf = nerv.Network("gt", gconf, {network = global_transf}) + global_transf:init(gconf.batch_size, gconf.chunk_size) - local batch_propagator = function() - local data = nil - for ri = 1, #readers do - data = readers[ri].reader:get_data() - if data ~= nil then - break - end + local prev_data = buffer:get_data() or nerv.error("no data in buffer") + local terminate = false + local input_pos = nil + for i, v in ipairs(input_order) do + if v.id == input_name then + input_pos = i end - - if data == nil then + end + if input_pos == nil then + nerv.error("input name %s not found in the input order list", input_name) + end + local batch_propagator = function() + if terminate then return "", nil end - - gconf.batch_size = data[input_order[1].id]:nrow() - global_transf:init(gconf.batch_size, 1) global_transf:epoch_init() - network:init(gconf.batch_size, 1) network:epoch_init() + local accu_output = {} + local utt_id = readers[input_pos].reader.key + if utt_id == nil then + nerv.error("no key found.") + end + while true do + local d + if prev_data then + d = prev_data + prev_data = nil + else + d = buffer:get_data() + if d == nil then + terminate = true + break + elseif #d.new_seq > 0 then + prev_data = d -- the first data of the next utterance + break + end + end - local input = {} - for i, e in ipairs(input_order) do - local id = e.id - if data[id] == nil then - nerv.error("input data %s not found", id) + local input = {} + local output = {{}} + for i, e in ipairs(input_order) do + local id = e.id + if d.data[id] == nil then + nerv.error("input data %s not found", id) + end + local transformed = {} + if e.global_transf then + for _, mini_batch in ipairs(d.data[id]) do + table.insert(transformed, + nerv.speech_utils.global_transf(mini_batch, + global_transf, + gconf.frm_ext or 0, 0, + gconf)) + end + else + transformed = d.data[id] + end + table.insert(input, transformed) + for i = 1, gconf.chunk_size do + table.insert(output[1], gconf.mmat_type(gconf.batch_size, network.dim_out[1])) + end end - local transformed - if e.global_transf then - transformed = nerv.speech_utils.global_transf(data[id], - global_transf, - gconf.frm_ext or 0, 0, - gconf) - else - transformed = data[id] + --nerv.info("input num: %d\nmat: %s\n", #input[1], input[1][1]) + --nerv.info("output num: %d\nmat: %s\n", #output[1], output[1][1]) + network:mini_batch_init({seq_length = d.seq_length, + new_seq = d.new_seq, + do_train = false, + input = input, + output = output, + err_input = {}, + err_output = {}}) + network:propagate() + for i, v in ipairs(output[1]) do + --nerv.info(gconf.mask[i]) + if gconf.mask[i][0][0] > 0 then -- is not a hole + table.insert(accu_output, v) + --nerv.info("input: %s\noutput: %s\n", input[1][i], output[1][i]) + end end - table.insert(input, {transformed}) end - local output = {{nerv.MMatrixFloat(input[1][1]:nrow(), network.dim_out[1])}} - network:mini_batch_init({seq_length = table.vector(gconf.batch_size, 1), - new_seq = {}, - do_train = false, - input = input, - output = output, - err_input = {}, - err_output = {}}) - network:propagate() - - local utt = data["key"] - if utt == nil then - nerv.error("no key found.") + local utt_matrix = gconf.mmat_type(#accu_output, accu_output[1]:ncol()) + for i, v in ipairs(accu_output) do + utt_matrix:copy_from(v, 0, v:nrow(), i - 1) end - + --nerv.info(utt_matrix) collectgarbage("collect") - return utt, output[1][1] + nerv.info("propagated %d features of an utterance", utt_matrix:nrow()) + return utt_id, utt_matrix end return batch_propagator @@ -89,6 +136,7 @@ function init(config, feature) dofile(config) gconf.mmat_type = nerv.MMatrixFloat gconf.use_cpu = true -- use CPU to decode + gconf.batch_size = 1 trainer = build_propagator(gconf.decode_param, feature) end |