diff options
Diffstat (limited to 'kaldi_decode')
-rwxr-xr-x | kaldi_decode/README.timit | 4 | ||||
-rw-r--r-- | kaldi_decode/src/asr_propagator.lua | 18 |
2 files changed, 18 insertions, 4 deletions
diff --git a/kaldi_decode/README.timit b/kaldi_decode/README.timit index 0a3e33a..4c4e310 100755 --- a/kaldi_decode/README.timit +++ b/kaldi_decode/README.timit @@ -5,8 +5,8 @@ source cmd.sh gmmdir=/speechlab/users/mfy43/timit/s5/exp/tri3/ data_fmllr=/speechlab/users/mfy43/timit/s5/data-fmllr-tri3/ dir=/speechlab/users/mfy43/timit/s5/exp/dnn4_nerv_dnn/ -nerv_config=/speechlab/users/mfy43/nerv/nerv/examples/timit_baseline2.lua -decode=/speechlab/users/mfy43/nerv/install/bin/decode_with_nerv.sh +nerv_config=/speechlab/users/mfy43/timit/s5/timit_baseline2.lua +decode=/speechlab/users/mfy43/timit/s5/nerv/install/bin/decode_with_nerv.sh # Decode (reuse HCLG graph) $decode --nj 20 --cmd "$decode_cmd" --acwt 0.2 \ diff --git a/kaldi_decode/src/asr_propagator.lua b/kaldi_decode/src/asr_propagator.lua index 5d0ad7c..4005875 100644 --- a/kaldi_decode/src/asr_propagator.lua +++ b/kaldi_decode/src/asr_propagator.lua @@ -24,6 +24,9 @@ function build_propagator(ifname, feature) local input_order = get_decode_input_order() local readers = make_decode_readers(feature, layer_repo) + network = nerv.Network("nt", gconf, {network = network}) + global_transf = nerv.Network("gt", gconf, {network = global_transf}) + local batch_propagator = function() local data = nil for ri = 1, #readers do @@ -38,7 +41,10 @@ function build_propagator(ifname, feature) end gconf.batch_size = data[input_order[1].id]:nrow() - network:init(gconf.batch_size) + global_transf:init(gconf.batch_size, 1) + global_transf:epoch_init() + network:init(gconf.batch_size, 1) + network:epoch_init() local input = {} for i, e in ipairs(input_order) do @@ -58,7 +64,14 @@ function build_propagator(ifname, feature) table.insert(input, transformed) end local output = {nerv.MMatrixFloat(input[1]:nrow(), network.dim_out[1])} - network:propagate(input, output) + network:mini_batch_init({seq_length = table.vector(gconf.batch_size, 1), + new_seq = {}, + do_train = false, + input = {input}, + output = {output}, + err_input = {}, + err_output = {}}) + network:propagate() local utt = data["key"] if utt == nil then @@ -74,6 +87,7 @@ end function init(config, feature) dofile(config) + gconf.mmat_type = nerv.MMatrixFloat gconf.use_cpu = true -- use CPU to decode trainer = build_propagator(gconf.decode_param, feature) end |