diff options
author | Determinant <[email protected]> | 2016-05-08 11:40:13 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2016-05-08 11:40:13 +0800 |
commit | 3101d1f9c1b2e31fbde75c1c9de5f6872340f5f7 (patch) | |
tree | 2f6bcf926ab3ebdedb5e4920a884ac5031e698b7 | |
parent | 2da71705cab5a583c642441f8321ddbaf0c7cb42 (diff) |
change decoder API (adapted to `trainer.lua`); remove redundant options in kaldi_io
-rw-r--r-- | kaldi_decode/src/asr_propagator.lua | 41 | ||||
-rw-r--r-- | kaldi_io/init.lua | 10 | ||||
-rw-r--r-- | kaldi_io/src/cwrapper_kaldi.cpp | 2 | ||||
-rw-r--r-- | kaldi_io/src/cwrapper_kaldi.h | 2 | ||||
-rw-r--r-- | kaldi_io/src/init.c | 3 |
5 files changed, 22 insertions, 36 deletions
diff --git a/kaldi_decode/src/asr_propagator.lua b/kaldi_decode/src/asr_propagator.lua index a3c5eb1..ab18d6d 100644 --- a/kaldi_decode/src/asr_propagator.lua +++ b/kaldi_decode/src/asr_propagator.lua @@ -16,34 +16,33 @@ end _add_profile_method(nerv.MMatrix) function build_propagator(ifname, feature) + -- FIXME: this is still a hack + local trainer = nerv.Trainer + ---- local param_repo = nerv.ParamRepo() param_repo:import(ifname, gconf) - local layer_repo = make_layer_repo(param_repo) - local network = get_decode_network(layer_repo) - local global_transf = get_global_transf(layer_repo) - local input_order = get_decode_input_order() + local layer_repo = trainer.make_layer_repo(nil, param_repo) + local network = trainer.get_decode_network(nil, layer_repo) + local input_order = trainer.get_decode_input_order(nil) local input_name = gconf.decode_input_name or "main_scp" - local readers = make_decode_readers(feature, layer_repo) - --nerv.info("prepare") + local readers = trainer.make_decode_readers(nil, feature) + -- nerv.info("prepare") local buffer = nerv.SeqBuffer(gconf, { buffer_size = gconf.buffer_size, batch_size = gconf.batch_size, chunk_size = gconf.chunk_size, randomize = gconf.randomize, readers = readers, - use_gpu = true }) network = nerv.Network("nt", gconf, {network = network}) network:init(gconf.batch_size, gconf.chunk_size) - global_transf = nerv.Network("gt", gconf, {network = global_transf}) - global_transf:init(gconf.batch_size, gconf.chunk_size) local prev_data = buffer:get_data() or nerv.error("no data in buffer") local terminate = false local input_pos = nil for i, v in ipairs(input_order) do - if v.id == input_name then + if v == input_name then input_pos = i end end @@ -54,7 +53,6 @@ function build_propagator(ifname, feature) if terminate then return "", nil end - global_transf:epoch_init() network:epoch_init() local accu_output = {} local utt_id = readers[input_pos].reader.key @@ -79,24 +77,11 @@ function build_propagator(ifname, feature) local input = {} local output = {{}} - for i, e in ipairs(input_order) do - local id = e.id + for i, id in ipairs(input_order) do if d.data[id] == nil then nerv.error("input data %s not found", id) end - local transformed = {} - if e.global_transf then - for _, mini_batch in ipairs(d.data[id]) do - table.insert(transformed, - nerv.speech_utils.global_transf(mini_batch, - global_transf, - gconf.frm_ext or 0, 0, - gconf)) - end - else - transformed = d.data[id] - end - table.insert(input, transformed) + table.insert(input, d.data[id]) for i = 1, gconf.chunk_size do table.insert(output[1], gconf.mmat_type(gconf.batch_size, network.dim_out[1])) end @@ -137,10 +122,10 @@ function init(config, feature) gconf.mmat_type = nerv.MMatrixFloat gconf.use_cpu = true -- use CPU to decode gconf.batch_size = 1 - trainer = build_propagator(gconf.decode_param, feature) + propagator = build_propagator(gconf.decode_param, feature) end function feed() - local utt, mat = trainer() + local utt, mat = propagator() return utt, mat end diff --git a/kaldi_io/init.lua b/kaldi_io/init.lua index bec2589..5325630 100644 --- a/kaldi_io/init.lua +++ b/kaldi_io/init.lua @@ -21,16 +21,18 @@ function KaldiReader:__init(global_conf, reader_conf) self.lab_repo = {} if reader_conf.mlfs then for id, mlf_spec in pairs(reader_conf.mlfs) do - if mlf_spec.format == nil then - nerv.error("format spec is expected for label %s", id) + if mlf_spec.targets_rspecifier == nil then + nerv.error("target spec is expected for label %s", id) end - self.lab_repo[id] = nerv.KaldiLabelRepo(mlf_spec.targets_rspecifier, - mlf_spec.format) + self.lab_repo[id] = nerv.KaldiLabelRepo(mlf_spec.targets_rspecifier) end end self.lookup_repo = {} if reader_conf.lookup then for id, lookup_spec in pairs(reader_conf.lookup) do + if lookup_spec.targets_rspecifier == nil then + nerv.error("target spec is expected for label %s", id) + end if lookup_spec.map_rspecifier == nil then nerv.error("map spec is expected for lookup %s", id) end diff --git a/kaldi_io/src/cwrapper_kaldi.cpp b/kaldi_io/src/cwrapper_kaldi.cpp index 788128b..9cff12e 100644 --- a/kaldi_io/src/cwrapper_kaldi.cpp +++ b/kaldi_io/src/cwrapper_kaldi.cpp @@ -150,7 +150,7 @@ extern "C" { kaldi::RandomAccessPosteriorReader *targets_reader; }; - KaldiLabelRepo *kaldi_label_repo_new(const char *targets_rspecifier, const char *fmt) { + KaldiLabelRepo *kaldi_label_repo_new(const char *targets_rspecifier) { KaldiLabelRepo *repo = new KaldiLabelRepo(); repo->targets_reader = new kaldi::RandomAccessPosteriorReader(string(targets_rspecifier)); return repo; diff --git a/kaldi_io/src/cwrapper_kaldi.h b/kaldi_io/src/cwrapper_kaldi.h index db20087..3b01ee1 100644 --- a/kaldi_io/src/cwrapper_kaldi.h +++ b/kaldi_io/src/cwrapper_kaldi.h @@ -18,7 +18,7 @@ extern "C" { typedef struct KaldiLabelRepo KaldiLabelRepo; - KaldiLabelRepo *kaldi_label_repo_new(const char *, const char *fmt); + KaldiLabelRepo *kaldi_label_repo_new(const char *); Matrix *kaldi_label_repo_read_utterance(KaldiLabelRepo *repo, KaldiFeatureRepo *, int, lua_State *L, diff --git a/kaldi_io/src/init.c b/kaldi_io/src/init.c index e8b4ea6..efe3ff7 100644 --- a/kaldi_io/src/init.c +++ b/kaldi_io/src/init.c @@ -99,8 +99,7 @@ static const luaL_Reg lookup_feat_repo_methods[] = { static int label_repo_new(lua_State *L) { const char *targets_rspecifier = luaL_checkstring(L, 1); - const char *fmt = luaL_checkstring(L, 2); - KaldiLabelRepo *repo = kaldi_label_repo_new(targets_rspecifier, fmt); + KaldiLabelRepo *repo = kaldi_label_repo_new(targets_rspecifier); luaT_pushudata(L, repo, nerv_kaldi_label_repo_tname); return 1; } |