diff options
Diffstat (limited to 'kaldi_decode/src')
-rw-r--r-- | kaldi_decode/src/asr_propagator.lua | 18 |
1 files changed, 16 insertions, 2 deletions
diff --git a/kaldi_decode/src/asr_propagator.lua b/kaldi_decode/src/asr_propagator.lua index 5d0ad7c..4005875 100644 --- a/kaldi_decode/src/asr_propagator.lua +++ b/kaldi_decode/src/asr_propagator.lua @@ -24,6 +24,9 @@ function build_propagator(ifname, feature) local input_order = get_decode_input_order() local readers = make_decode_readers(feature, layer_repo) + network = nerv.Network("nt", gconf, {network = network}) + global_transf = nerv.Network("gt", gconf, {network = global_transf}) + local batch_propagator = function() local data = nil for ri = 1, #readers do @@ -38,7 +41,10 @@ function build_propagator(ifname, feature) end gconf.batch_size = data[input_order[1].id]:nrow() - network:init(gconf.batch_size) + global_transf:init(gconf.batch_size, 1) + global_transf:epoch_init() + network:init(gconf.batch_size, 1) + network:epoch_init() local input = {} for i, e in ipairs(input_order) do @@ -58,7 +64,14 @@ function build_propagator(ifname, feature) table.insert(input, transformed) end local output = {nerv.MMatrixFloat(input[1]:nrow(), network.dim_out[1])} - network:propagate(input, output) + network:mini_batch_init({seq_length = table.vector(gconf.batch_size, 1), + new_seq = {}, + do_train = false, + input = {input}, + output = {output}, + err_input = {}, + err_output = {}}) + network:propagate() local utt = data["key"] if utt == nil then @@ -74,6 +87,7 @@ end function init(config, feature) dofile(config) + gconf.mmat_type = nerv.MMatrixFloat gconf.use_cpu = true -- use CPU to decode trainer = build_propagator(gconf.decode_param, feature) end |