diff options
Diffstat (limited to 'htk_io')
-rw-r--r-- | htk_io/init.lua | 30 | ||||
-rw-r--r-- | htk_io/src/cwrapper.cpp | 49 | ||||
-rw-r--r-- | htk_io/src/cwrapper.h | 9 | ||||
-rw-r--r-- | htk_io/src/init.c | 63 | ||||
-rw-r--r-- | htk_io/tools/tnet_to_nerv.cpp | 78 |
5 files changed, 202 insertions, 27 deletions
diff --git a/htk_io/init.lua b/htk_io/init.lua index b360b67..677d3e9 100644 --- a/htk_io/init.lua +++ b/htk_io/init.lua @@ -1,8 +1,9 @@ require 'libhtkio' require 'speech_utils' +require 'threads' local TNetReader = nerv.class("nerv.TNetReader", "nerv.DataReader") -function TNetReader:__init(global_conf, reader_conf) +function TNetReader:__init(global_conf, reader_conf, feat_repo_shareid, data_mutex_shareid) self.feat_id = reader_conf.id self.frm_ext = reader_conf.frm_ext self.gconf = global_conf @@ -10,9 +11,22 @@ function TNetReader:__init(global_conf, reader_conf) if self.debug == nil then self.debug = false end - self.feat_repo = nerv.TNetFeatureRepo(reader_conf.scp_file, + + if feat_repo_shareid ~= nil then + self.feat_repo = nerv.TNetFeatureRepo(feat_repo_shareid) + else + self.feat_repo = nerv.TNetFeatureRepo(reader_conf.scp_file, reader_conf.conf_file, reader_conf.frm_ext) + end + + --print(self.feat_repo) + self.data_mutex = nil + if data_mutex_shareid ~= nil then + self.data_mutex = threads.Mutex(data_mutex_shareid) + end + + self.lab_repo = {} if reader_conf.mlfs then for id, mlf_spec in pairs(reader_conf.mlfs) do @@ -26,7 +40,14 @@ function TNetReader:__init(global_conf, reader_conf) end function TNetReader:get_data() + if self.data_mutex ~= nil then + self.data_mutex:lock() + end + if self.feat_repo:is_end() then + if self.data_mutex ~= nil then + self.data_mutex:unlock() + end return nil end local res = {} @@ -66,6 +87,11 @@ function TNetReader:get_data() end -- move the pointer to next self.feat_repo:next() + + if self.data_mutex ~= nil then + self.data_mutex:unlock() + end + collectgarbage("collect") return res end diff --git a/htk_io/src/cwrapper.cpp b/htk_io/src/cwrapper.cpp index b7ce2d5..66cde23 100644 --- a/htk_io/src/cwrapper.cpp +++ b/htk_io/src/cwrapper.cpp @@ -8,6 +8,7 @@ extern "C" { #include "cwrapper.h" #include "string.h" +#include "pthread.h" #include "nerv/common.h" extern Matrix *nerv_matrix_host_float_create(long nrow, long ncol, Status *status); @@ -29,6 +30,7 @@ extern "C" { const char* cvn_mask; const char* cvg_file; TNet::Matrix<float> feats_host; /* KaldiLib implementation */ + int refcount; }; TNetFeatureRepo *tnet_feature_repo_new(const char *p_script, const char *config, int context) { @@ -53,6 +55,18 @@ extern "C" { return repo; } + TNetFeatureRepo *tnet_feature_repo_newWithId(long id) + { + TNetFeatureRepo *repo = (TNetFeatureRepo*)id; + __sync_fetch_and_add(&repo->refcount, 1); + return repo; + } + + long tnet_feature_repo_id(TNetFeatureRepo *repo) + { + return (long)(repo); + } + Matrix *tnet_feature_repo_read_utterance(TNetFeatureRepo *repo, lua_State *L, int debug) { Matrix *mat; /* nerv implementation */ repo->feature_repo.ReadFullMatrix(repo->feats_host); @@ -93,12 +107,19 @@ extern "C" { } void tnet_feature_repo_destroy(TNetFeatureRepo *repo) { - if (repo->cmn_mask) - free(repo->cmn_path); - if (repo->cvn_mask) - free(repo->cvn_path); - free(repo->p_deriv_win_lenghts); - delete repo; + if (NULL != repo) + { + if(__sync_fetch_and_add(&repo->refcount, -1) == 1) + { + if (repo->cmn_mask) + free(repo->cmn_path); + if (repo->cvn_mask) + free(repo->cvn_path); + free(repo->p_deriv_win_lenghts); + delete repo; + repo = NULL; + } + } } struct TNetLabelRepo { @@ -114,6 +135,16 @@ extern "C" { return repo; } + TNetLabelRepo *tnet_label_repo_newWithId(long id) + { + return (TNetLabelRepo*)id; + } + + long tnet_label_repo_id(TNetLabelRepo *repo) + { + return (long)(repo); + } + Matrix *tnet_label_repo_read_utterance(TNetLabelRepo *repo, size_t frames, size_t sample_rate, @@ -143,6 +174,10 @@ extern "C" { } void tnet_label_repo_destroy(TNetLabelRepo *repo) { - delete repo; + if (NULL != repo) + { + delete repo; + repo = NULL; + } } } diff --git a/htk_io/src/cwrapper.h b/htk_io/src/cwrapper.h index e1bce6e..44e77cf 100644 --- a/htk_io/src/cwrapper.h +++ b/htk_io/src/cwrapper.h @@ -16,6 +16,10 @@ extern "C" { void tnet_feature_repo_next(TNetFeatureRepo *repo); int tnet_feature_repo_is_end(TNetFeatureRepo *repo); void tnet_feature_repo_destroy(TNetFeatureRepo *repo); + TNetFeatureRepo *tnet_feature_repo_newWithId(long id); + long tnet_feature_repo_id(TNetFeatureRepo *repo); + + typedef struct TNetLabelRepo TNetLabelRepo; @@ -31,6 +35,11 @@ extern "C" { int debug); void tnet_label_repo_destroy(TNetLabelRepo *repo); + + TNetLabelRepo *tnet_label_repo_newWithId(long id); + + long tnet_label_repo_id(TNetLabelRepo *repo); + #ifdef __cplusplus } #endif diff --git a/htk_io/src/init.c b/htk_io/src/init.c index 8a1ec3b..04046e9 100644 --- a/htk_io/src/init.c +++ b/htk_io/src/init.c @@ -7,13 +7,37 @@ const char *nerv_tnet_label_repo_tname = "nerv.TNetLabelRepo"; const char *nerv_matrix_host_float_tname = "nerv.MMatrixFloat"; static int feat_repo_new(lua_State *L) { - const char *scp_file = luaL_checkstring(L, 1); - const char *conf = luaL_checkstring(L, 2); - int frm_ext = luaL_checkinteger(L, 3); - TNetFeatureRepo *repo = tnet_feature_repo_new(scp_file, conf, frm_ext); + TNetFeatureRepo *repo = NULL; + if(lua_gettop(L) == 1) + { + long id = luaL_checkinteger(L, 1); + repo = tnet_feature_repo_newWithId(id); + } + else + { + const char *scp_file = luaL_checkstring(L, 1); + const char *conf = luaL_checkstring(L, 2); + int frm_ext = luaL_checkinteger(L, 3); + repo = tnet_feature_repo_new(scp_file, conf, frm_ext); + } + luaT_pushudata(L, repo, nerv_tnet_feat_repo_tname); return 1; } +static int feat_repo_id(lua_State *L) { + TNetFeatureRepo *repo = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname); + lua_pushinteger(L, tnet_feature_repo_id(repo)); + return 1; +} + +static int feat_repo_tostring(lua_State *L) +{ + char str[128]; + TNetFeatureRepo *repo = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname); + snprintf(str, 128, "%s <%lx>", nerv_tnet_feat_repo_tname, tnet_feature_repo_id(repo)); + lua_pushstring(L, str); + return 1; +} static int feat_repo_destroy(lua_State *L) { TNetFeatureRepo *repo = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname); @@ -55,22 +79,40 @@ static const luaL_Reg feat_repo_methods[] = { {"cur_tag", feat_repo_current_tag}, {"next", feat_repo_next}, {"is_end", feat_repo_is_end}, + {"__tostring", feat_repo_tostring}, + {"id", feat_repo_id}, {NULL, NULL} }; static int label_repo_new(lua_State *L) { - const char *mlf_file = luaL_checkstring(L, 1); - const char *fmt = luaL_checkstring(L, 2); - const char *arg = luaL_checkstring(L, 3); - const char *dir = luaL_checkstring(L, 4); - const char *ext = luaL_checkstring(L, 5); - TNetLabelRepo *repo = tnet_label_repo_new( + TNetLabelRepo *repo = NULL; + if(lua_gettop(L) == 1) + { + long id = luaL_checkinteger(L, 1); + repo = tnet_label_repo_newWithId(id); + } + else + { + const char *mlf_file = luaL_checkstring(L, 1); + const char *fmt = luaL_checkstring(L, 2); + const char *arg = luaL_checkstring(L, 3); + const char *dir = luaL_checkstring(L, 4); + const char *ext = luaL_checkstring(L, 5); + repo = tnet_label_repo_new( mlf_file, fmt, arg, dir, ext); + } luaT_pushudata(L, repo, nerv_tnet_label_repo_tname); return 1; } +static int label_repo_id(lua_State *L) { + TNetLabelRepo *repo = luaT_checkudata(L, 1, nerv_tnet_label_repo_tname); + lua_pushinteger(L, tnet_label_repo_id(repo)); + return 1; +} + + static int label_repo_read_utterance(lua_State *L) { TNetLabelRepo *repo = luaT_checkudata(L, 1, nerv_tnet_label_repo_tname); TNetFeatureRepo *feat_repo = luaT_checkudata(L, 2, nerv_tnet_feat_repo_tname); @@ -95,6 +137,7 @@ static int label_repo_destroy(lua_State *L) { static const luaL_Reg label_repo_methods[] = { {"get_utter", label_repo_read_utterance}, + {"id", label_repo_id}, {NULL, NULL} }; diff --git a/htk_io/tools/tnet_to_nerv.cpp b/htk_io/tools/tnet_to_nerv.cpp index a779a25..f96781a 100644 --- a/htk_io/tools/tnet_to_nerv.cpp +++ b/htk_io/tools/tnet_to_nerv.cpp @@ -2,14 +2,22 @@ #include <fstream> #include <string> #include <cstring> -#include <cstdlib> +#include <stdlib.h> char token[1024]; char output[1024]; double **mat; int main(int argc, char **argv) { + + if (argc != 3) + { + printf("%s tnet.model.in nerv.model.out\n", argv[0]); + } + std::ofstream fout; - fout.open(argv[1]); - int cnt = 0; + freopen(argv[1], "r", stdin); + fout.open(argv[2]); + int cnt = 0, bias = 1, win = 1; + long length = 0, base = 0; while (scanf("%s", token) != EOF) { int nrow, ncol; @@ -19,13 +27,13 @@ int main(int argc, char **argv) { scanf("%d %d", &ncol, &nrow); scanf("%s %d %d", token, &ncol, &nrow); printf("%d %d\n", nrow, ncol); - mat = (double **)malloc(nrow * sizeof(double *)); + mat = (double **)malloc(nrow * sizeof(double *)); for (i = 0; i < nrow; i++) mat[i] = (double *)malloc(ncol * sizeof(double)); for (j = 0; j < ncol; j++) for (i = 0; i < nrow; i++) scanf("%lf", mat[i] + j); - long base = fout.tellp(); + base = fout.tellp(); sprintf(output, "%16d", 0); fout << output; sprintf(output, "{type=\"nerv.LinearTransParam\",id=\"affine%d_ltp\"}\n", @@ -38,10 +46,8 @@ int main(int argc, char **argv) { for (j = 0; j < ncol; j++) fout << mat[i][j] << " "; fout << std::endl; - free(mat[i]); } - free(mat); - long length = fout.tellp() - base; + length = fout.tellp() - base; fout.seekp(base); sprintf(output, "[%13lu]\n", length); fout << output; @@ -69,6 +75,62 @@ int main(int argc, char **argv) { cnt++; } } + else if (strcmp(token, "<bias>") == 0) + { + scanf("%d %d", &ncol, &nrow); + scanf("%s %d", token, &ncol); + base = fout.tellp(); + nrow = 1; + mat = (double **)malloc(nrow * sizeof(double *)); + for (i = 0; i < nrow; i++) + mat[i] = (double *)malloc(ncol * sizeof(double)); + for (j = 0; j < ncol; j++) + for (i = 0; i < nrow; i++) + scanf("%lf", mat[i] + j); + sprintf(output, "%16d", 0); + fout << output; + sprintf(output, "{type=\"nerv.MatrixParam\",id=\"bias%d\"}\n",bias); + fout << output; + sprintf(output, "1 %d\n", ncol); + fout << output; + for (j = 0; j < ncol; j++) + fout << mat[0][j] << " "; + fout << std::endl; + length = fout.tellp() - base; + fout.seekp(base); + sprintf(output, "[%13lu]\n", length); + fout << output; + fout.seekp(0, std::ios_base::end); + bias++; + } + else if (strcmp(token, "<window>") == 0) + { + scanf("%d %d", &ncol, &nrow); + scanf("%s %d", token, &ncol); + base = fout.tellp(); + nrow = 1; + mat = (double **)malloc(nrow * sizeof(double *)); + for (i = 0; i < nrow; i++) + mat[i] = (double *)malloc(ncol * sizeof(double)); + for (j = 0; j < ncol; j++) + for (i = 0; i < nrow; i++) + scanf("%lf", mat[i] + j); + sprintf(output, "%16d", 0); + fout << output; + sprintf(output, "{type=\"nerv.MatrixParam\",id=\"window%d\"}\n",win); + fout << output; + sprintf(output, "1 %d\n", ncol); + fout << output; + for (j = 0; j < ncol; j++) + fout << mat[0][j] << " "; + fout << std::endl; + length = fout.tellp() - base; + fout.seekp(base); + sprintf(output, "[%13lu]\n", length); + fout << output; + fout.seekp(0, std::ios_base::end); + win++; + } } return 0; } |