diff options
author | Determinant <[email protected]> | 2016-03-12 13:36:59 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2016-03-12 13:36:59 +0800 |
commit | ddc4545050b41d12cfdc19cea9ba31c940d3d537 (patch) | |
tree | b47b54949885a11de97c1406c3a61ab7b0ffeb56 /kaldi_decode/src/asr_propagator.lua | |
parent | 54b33aa3a95f5a7a023e9ea453094ae081c91f64 (diff) |
adapt kaldi_decode propagator to the new arch
Diffstat (limited to 'kaldi_decode/src/asr_propagator.lua')
-rw-r--r-- | kaldi_decode/src/asr_propagator.lua | 18 |
1 files changed, 16 insertions, 2 deletions
diff --git a/kaldi_decode/src/asr_propagator.lua b/kaldi_decode/src/asr_propagator.lua index 5d0ad7c..4005875 100644 --- a/kaldi_decode/src/asr_propagator.lua +++ b/kaldi_decode/src/asr_propagator.lua @@ -24,6 +24,9 @@ function build_propagator(ifname, feature) local input_order = get_decode_input_order() local readers = make_decode_readers(feature, layer_repo) + network = nerv.Network("nt", gconf, {network = network}) + global_transf = nerv.Network("gt", gconf, {network = global_transf}) + local batch_propagator = function() local data = nil for ri = 1, #readers do @@ -38,7 +41,10 @@ function build_propagator(ifname, feature) end gconf.batch_size = data[input_order[1].id]:nrow() - network:init(gconf.batch_size) + global_transf:init(gconf.batch_size, 1) + global_transf:epoch_init() + network:init(gconf.batch_size, 1) + network:epoch_init() local input = {} for i, e in ipairs(input_order) do @@ -58,7 +64,14 @@ function build_propagator(ifname, feature) table.insert(input, transformed) end local output = {nerv.MMatrixFloat(input[1]:nrow(), network.dim_out[1])} - network:propagate(input, output) + network:mini_batch_init({seq_length = table.vector(gconf.batch_size, 1), + new_seq = {}, + do_train = false, + input = {input}, + output = {output}, + err_input = {}, + err_output = {}}) + network:propagate() local utt = data["key"] if utt == nil then @@ -74,6 +87,7 @@ end function init(config, feature) dofile(config) + gconf.mmat_type = nerv.MMatrixFloat gconf.use_cpu = true -- use CPU to decode trainer = build_propagator(gconf.decode_param, feature) end |