From be0e51156e1af9d619160bba1aa7c2eb2df30731 Mon Sep 17 00:00:00 2001
From: Determinant
Date: Fri, 5 Jun 2015 10:58:38 +0800
Subject: add debug flag

---
 init.lua             | 19 +++++++++++++++++--
 tnet_io/cwrapper.cpp | 11 +++++++----
 tnet_io/cwrapper.h   |  5 +++--
 tnet_io/init.c       | 12 ++++++++++--
 tnet_io/test.c       |  5 +++--
 5 files changed, 40 insertions(+), 12 deletions(-)

diff --git a/init.lua b/init.lua
index 39a1e9e..fec7209 100644
--- a/init.lua
+++ b/init.lua
@@ -6,6 +6,10 @@ function TNetReader:__init(global_conf, reader_conf)
     self.frm_ext = reader_conf.frm_ext
     self.gconf = global_conf
     self.global_transf = reader_conf.global_transf
+    self.debug = global_conf.debug
+    if self.debug == nil then
+        self.debug = false
+    end
     self.feat_repo = nerv.TNetFeatureRepo(reader_conf.scp_file,
                                           reader_conf.conf_file,
                                           reader_conf.frm_ext)
@@ -26,22 +30,33 @@ function TNetReader:get_data()
     local res = {}
     local frm_ext = self.frm_ext
     local step = frm_ext * 2 + 1
-    local feat_utter = self.feat_repo:cur_utter()
+    -- read HTK feature
+    local feat_utter = self.feat_repo:cur_utter(self.debug)
+    -- expand the feature
     local expanded = self.gconf.cumat_type(feat_utter:nrow(), feat_utter:ncol() * step)
     expanded:expand_frm(self.gconf.cumat_type.new_from_host(feat_utter), frm_ext)
+    -- rearrange the feature (``transpose'' operation in TNet)
     local rearranged = expanded:create()
     rearranged:rearrange_frm(expanded, step)
+    -- prepare for transf
     local input = {rearranged}
     local output = {rearranged:create()}
+    -- do transf
     self.global_transf:init(input[1]:nrow())
     self.global_transf:propagate(input, output)
+    -- trim frames
     expanded = self.gconf.mmat_type(output[1]:nrow() - frm_ext * 2, output[1]:ncol())
     output[1]:copy_toh(expanded, frm_ext, feat_utter:nrow() - frm_ext)
     res[self.feat_id] = expanded
+    -- add corresponding labels
     for id, repo in pairs(self.lab_repo) do
-        local lab_utter = repo:get_utter(self.feat_repo, expanded:nrow())
+        local lab_utter = repo:get_utter(self.feat_repo,
+                                         expanded:nrow(),
+                                         self.debug)
         res[id] = lab_utter
     end
+    -- move the pointer to next
     self.feat_repo:next()
+    collectgarbage("collect")
     return res
 end

diff --git a/tnet_io/cwrapper.cpp b/tnet_io/cwrapper.cpp
index 4149557..800df2e 100644
--- a/tnet_io/cwrapper.cpp
+++ b/tnet_io/cwrapper.cpp
@@ -53,7 +53,7 @@ extern "C" {
         return repo;
     }

-    Matrix *tnet_feature_repo_read_utterance(TNetFeatureRepo *repo, lua_State *L) {
+    Matrix *tnet_feature_repo_read_utterance(TNetFeatureRepo *repo, lua_State *L, int debug) {
         Matrix *mat; /* nerv implementation */
         repo->feature_repo.ReadFullMatrix(repo->feats_host);
         std::string utter_str = repo->feature_repo.Current().Logical();
@@ -62,7 +62,8 @@ extern "C" {
         int m = repo->feats_host.Cols();
         mat = nerv_matrix_host_float_new_(L, n, m);
         size_t stride = mat->stride;
-        fprintf(stderr, "[tnet] feature: %s %d %d\n", utter_str.c_str(), n, m);
+        if (debug)
+            fprintf(stderr, "[tnet] feature: %s %d %d\n", utter_str.c_str(), n, m);
         for (int i = 0; i < n; i++)
         {
             float *row = repo->feats_host.pRowData(i);
@@ -115,7 +116,8 @@ extern "C" {
                                            size_t frames,
                                            size_t sample_rate,
                                            const char *tag,
-                                           lua_State *L) {
+                                           lua_State *L,
+                                           int debug) {
         std::vector<Matrix<BaseFloat> > labs_hosts; /* KaldiLib implementation */
         Matrix *mat;
         repo->label_repo.GenDesiredMatrixExt(labs_hosts, frames,
@@ -124,7 +126,8 @@ extern "C" {
         int m = labs_hosts[0].Cols();
         mat = nerv_matrix_host_float_new_(L, n, m);
         size_t stride = mat->stride;
-        fprintf(stderr, "[tnet] label: %s %d %d\n", tag, n, m);
+        if (debug)
+            fprintf(stderr, "[tnet] label: %s %d %d\n", tag, n, m);
         for (int i = 0; i < n; i++)
         {
             float *row = labs_hosts[0].pRowData(i);
diff --git a/tnet_io/cwrapper.h b/tnet_io/cwrapper.h
index 54fb69b..c2ca1ba 100644
--- a/tnet_io/cwrapper.h
+++ b/tnet_io/cwrapper.h
@@ -10,7 +10,7 @@ extern "C" {

     TNetFeatureRepo *tnet_feature_repo_new(const char *scp,
                                            const char *config, int context);
-    Matrix *tnet_feature_repo_read_utterance(TNetFeatureRepo *repo, lua_State *L);
+    Matrix *tnet_feature_repo_read_utterance(TNetFeatureRepo *repo, lua_State *L, int debug);
     size_t tnet_feature_repo_current_samplerate(TNetFeatureRepo *repo);
     const char *tnet_feature_repo_current_tag(TNetFeatureRepo *repo);
     void tnet_feature_repo_next(TNetFeatureRepo *repo);
@@ -27,7 +27,8 @@ extern "C" {
                                        size_t frames,
                                        size_t sample_rate,
                                        const char *tag,
-                                       lua_State *L);
+                                       lua_State *L,
+                                       int debug);
     void tnet_label_repo_destroy(TNetLabelRepo *repo);

 #ifdef __cplusplus
diff --git a/tnet_io/init.c b/tnet_io/init.c
index 16f6f37..da93b35 100644
--- a/tnet_io/init.c
+++ b/tnet_io/init.c
@@ -29,7 +29,11 @@ static int feat_repo_current_tag(lua_State *L) {

 static int feat_repo_current_utterance(lua_State *L) {
     TNetFeatureRepo *repo = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
-    Matrix *utter = tnet_feature_repo_read_utterance(repo, L);
+    int debug;
+    if (!lua_isboolean(L, 2))
+        nerv_error(L, "debug flag should be a boolean");
+    debug = lua_toboolean(L, 2);
+    Matrix *utter = tnet_feature_repo_read_utterance(repo, L, debug);
     luaT_pushudata(L, utter, nerv_matrix_host_float_tname);
     return 1;
 }
@@ -71,10 +75,14 @@ static int label_repo_read_utterance(lua_State *L) {
     TNetLabelRepo *repo = luaT_checkudata(L, 1, nerv_tnet_label_repo_tname);
     TNetFeatureRepo *feat_repo = luaT_checkudata(L, 2, nerv_tnet_feat_repo_tname);
     size_t frames = luaL_checkinteger(L, 3);
+    int debug;
+    if (!lua_isboolean(L, 4))
+        nerv_error(L, "debug flag should be a boolean");
+    debug = lua_toboolean(L, 4);
     Matrix *utter = tnet_label_repo_read_utterance(repo,
                         frames,
                         tnet_feature_repo_current_samplerate(feat_repo),
-                        tnet_feature_repo_current_tag(feat_repo), L);
+                        tnet_feature_repo_current_tag(feat_repo), L, debug);
     luaT_pushudata(L, utter, nerv_matrix_host_float_tname);
     return 1;
 }
diff --git a/tnet_io/test.c b/tnet_io/test.c
index 8c06805..6812ef1 100644
--- a/tnet_io/test.c
+++ b/tnet_io/test.c
@@ -22,7 +22,7 @@ int main() {
             "/slfs1/users/mfy43/swb_ivec/train_bp.scp",
             "/slfs1/users/mfy43/swb_ivec/plp_0_d_a.conf", 5);
     Matrix *feat_utter;
-    feat_utter = tnet_feature_repo_read_utterance(feat_repo, NULL);
+    feat_utter = tnet_feature_repo_read_utterance(feat_repo, NULL, 1);

     TNetLabelRepo *lab_repo = tnet_label_repo_new(
             "/slfs1/users/mfy43/swb_ivec/ref.mlf",
@@ -33,7 +33,8 @@ int main() {
     Matrix *lab_utter = tnet_label_repo_read_utterance(lab_repo,
             feat_utter->nrow - 5 * 2,
             tnet_feature_repo_current_samplerate(feat_repo),
-            tnet_feature_repo_current_tag(feat_repo), NULL);
+            tnet_feature_repo_current_tag(feat_repo), NULL,
+            1);
     print_nerv_matrix(lab_utter);
     return 0;
 }
--
cgit v1.2.3-70-g09d2
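
Usage note (not part of the patch): the sketch below shows how the new flag would be switched on from a nerv recipe, assuming the usual global_conf/reader_conf tables that TNetReader:__init() reads. Only the field names visible in the hunks above are taken from the patch; the matrix-type values, the variable names, and the global_transf object are illustrative assumptions, and the scp/conf paths are the ones hard-coded in tnet_io/test.c.

-- hypothetical recipe fragment; only `debug' is introduced by this patch
gconf = {
    debug      = true,               -- prints "[tnet] feature/label: <tag> <rows> <cols>" to stderr
    cumat_type = nerv.CuMatrixFloat, -- assumed device matrix type used by get_data()
    mmat_type  = nerv.MMatrixFloat   -- assumed host matrix type for the trimmed output
}
-- reader_conf fields below are the ones read in the __init hunk; label-repo and
-- feature-id settings from the unchanged parts of __init are omitted here
reader = nerv.TNetReader(gconf,
                         {
                             scp_file = "/slfs1/users/mfy43/swb_ivec/train_bp.scp",
                             conf_file = "/slfs1/users/mfy43/swb_ivec/plp_0_d_a.conf",
                             frm_ext = 5,
                             global_transf = global_transf -- assumed to be built elsewhere
                         })
data = reader:get_data()
-- when gconf.debug is omitted it defaults to false, so the fprintf calls in
-- tnet_io/cwrapper.cpp stay silent; passing a non-boolean down to cur_utter()
-- or get_utter() raises "debug flag should be a boolean" from tnet_io/init.c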