From 3101d1f9c1b2e31fbde75c1c9de5f6872340f5f7 Mon Sep 17 00:00:00 2001 From: Determinant Date: Sun, 8 May 2016 11:40:13 +0800 Subject: change decoder API (adapted to `trainer.lua`); remove redundant options in kaldi_io --- kaldi_decode/src/asr_propagator.lua | 41 ++++++++++++------------------------- 1 file changed, 13 insertions(+), 28 deletions(-) (limited to 'kaldi_decode/src') diff --git a/kaldi_decode/src/asr_propagator.lua b/kaldi_decode/src/asr_propagator.lua index a3c5eb1..ab18d6d 100644 --- a/kaldi_decode/src/asr_propagator.lua +++ b/kaldi_decode/src/asr_propagator.lua @@ -16,34 +16,33 @@ end _add_profile_method(nerv.MMatrix) function build_propagator(ifname, feature) + -- FIXME: this is still a hack + local trainer = nerv.Trainer + ---- local param_repo = nerv.ParamRepo() param_repo:import(ifname, gconf) - local layer_repo = make_layer_repo(param_repo) - local network = get_decode_network(layer_repo) - local global_transf = get_global_transf(layer_repo) - local input_order = get_decode_input_order() + local layer_repo = trainer.make_layer_repo(nil, param_repo) + local network = trainer.get_decode_network(nil, layer_repo) + local input_order = trainer.get_decode_input_order(nil) local input_name = gconf.decode_input_name or "main_scp" - local readers = make_decode_readers(feature, layer_repo) - --nerv.info("prepare") + local readers = trainer.make_decode_readers(nil, feature) + -- nerv.info("prepare") local buffer = nerv.SeqBuffer(gconf, { buffer_size = gconf.buffer_size, batch_size = gconf.batch_size, chunk_size = gconf.chunk_size, randomize = gconf.randomize, readers = readers, - use_gpu = true }) network = nerv.Network("nt", gconf, {network = network}) network:init(gconf.batch_size, gconf.chunk_size) - global_transf = nerv.Network("gt", gconf, {network = global_transf}) - global_transf:init(gconf.batch_size, gconf.chunk_size) local prev_data = buffer:get_data() or nerv.error("no data in buffer") local terminate = false local input_pos = nil for i, v in ipairs(input_order) do - if v.id == input_name then + if v == input_name then input_pos = i end end @@ -54,7 +53,6 @@ function build_propagator(ifname, feature) if terminate then return "", nil end - global_transf:epoch_init() network:epoch_init() local accu_output = {} local utt_id = readers[input_pos].reader.key @@ -79,24 +77,11 @@ function build_propagator(ifname, feature) local input = {} local output = {{}} - for i, e in ipairs(input_order) do - local id = e.id + for i, id in ipairs(input_order) do if d.data[id] == nil then nerv.error("input data %s not found", id) end - local transformed = {} - if e.global_transf then - for _, mini_batch in ipairs(d.data[id]) do - table.insert(transformed, - nerv.speech_utils.global_transf(mini_batch, - global_transf, - gconf.frm_ext or 0, 0, - gconf)) - end - else - transformed = d.data[id] - end - table.insert(input, transformed) + table.insert(input, d.data[id]) for i = 1, gconf.chunk_size do table.insert(output[1], gconf.mmat_type(gconf.batch_size, network.dim_out[1])) end @@ -137,10 +122,10 @@ function init(config, feature) gconf.mmat_type = nerv.MMatrixFloat gconf.use_cpu = true -- use CPU to decode gconf.batch_size = 1 - trainer = build_propagator(gconf.decode_param, feature) + propagator = build_propagator(gconf.decode_param, feature) end function feed() - local utt, mat = trainer() + local utt, mat = propagator() return utt, mat end -- cgit v1.2.3-70-g09d2