summaryrefslogtreecommitdiff
path: root/htk_io/src
diff options
context:
space:
mode:
authorTed Yin <[email protected]>2015-08-31 16:16:00 +0800
committerTed Yin <[email protected]>2015-08-31 16:16:00 +0800
commit014bbfb7e64999a75f9e0dc52267a36741281624 (patch)
tree127888bcefaf8ce4991bb4c173d6538f1172f35f /htk_io/src
parent9e1a0931be43ea80fe7d41154007839b637d4e08 (diff)
parent2196e0a591b9bc254aa95e180adf188fd70ded68 (diff)
Merge pull request #4 from uphantom/master
support fastnn multi-thread TNetReader
Diffstat (limited to 'htk_io/src')
-rw-r--r--htk_io/src/cwrapper.cpp49
-rw-r--r--htk_io/src/cwrapper.h9
-rw-r--r--htk_io/src/init.c63
3 files changed, 104 insertions, 17 deletions
diff --git a/htk_io/src/cwrapper.cpp b/htk_io/src/cwrapper.cpp
index b7ce2d5..66cde23 100644
--- a/htk_io/src/cwrapper.cpp
+++ b/htk_io/src/cwrapper.cpp
@@ -8,6 +8,7 @@
extern "C" {
#include "cwrapper.h"
#include "string.h"
+#include "pthread.h"
#include "nerv/common.h"
extern Matrix *nerv_matrix_host_float_create(long nrow, long ncol, Status *status);
@@ -29,6 +30,7 @@ extern "C" {
const char* cvn_mask;
const char* cvg_file;
TNet::Matrix<float> feats_host; /* KaldiLib implementation */
+ int refcount;
};
TNetFeatureRepo *tnet_feature_repo_new(const char *p_script, const char *config, int context) {
@@ -53,6 +55,18 @@ extern "C" {
return repo;
}
+ TNetFeatureRepo *tnet_feature_repo_newWithId(long id)
+ {
+ TNetFeatureRepo *repo = (TNetFeatureRepo*)id;
+ __sync_fetch_and_add(&repo->refcount, 1);
+ return repo;
+ }
+
+ long tnet_feature_repo_id(TNetFeatureRepo *repo)
+ {
+ return (long)(repo);
+ }
+
Matrix *tnet_feature_repo_read_utterance(TNetFeatureRepo *repo, lua_State *L, int debug) {
Matrix *mat; /* nerv implementation */
repo->feature_repo.ReadFullMatrix(repo->feats_host);
@@ -93,12 +107,19 @@ extern "C" {
}
void tnet_feature_repo_destroy(TNetFeatureRepo *repo) {
- if (repo->cmn_mask)
- free(repo->cmn_path);
- if (repo->cvn_mask)
- free(repo->cvn_path);
- free(repo->p_deriv_win_lenghts);
- delete repo;
+ if (NULL != repo)
+ {
+ if(__sync_fetch_and_add(&repo->refcount, -1) == 1)
+ {
+ if (repo->cmn_mask)
+ free(repo->cmn_path);
+ if (repo->cvn_mask)
+ free(repo->cvn_path);
+ free(repo->p_deriv_win_lenghts);
+ delete repo;
+ repo = NULL;
+ }
+ }
}
struct TNetLabelRepo {
@@ -114,6 +135,16 @@ extern "C" {
return repo;
}
+ TNetLabelRepo *tnet_label_repo_newWithId(long id)
+ {
+ return (TNetLabelRepo*)id;
+ }
+
+ long tnet_label_repo_id(TNetLabelRepo *repo)
+ {
+ return (long)(repo);
+ }
+
Matrix *tnet_label_repo_read_utterance(TNetLabelRepo *repo,
size_t frames,
size_t sample_rate,
@@ -143,6 +174,10 @@ extern "C" {
}
void tnet_label_repo_destroy(TNetLabelRepo *repo) {
- delete repo;
+ if (NULL != repo)
+ {
+ delete repo;
+ repo = NULL;
+ }
}
}
diff --git a/htk_io/src/cwrapper.h b/htk_io/src/cwrapper.h
index e1bce6e..44e77cf 100644
--- a/htk_io/src/cwrapper.h
+++ b/htk_io/src/cwrapper.h
@@ -16,6 +16,10 @@ extern "C" {
void tnet_feature_repo_next(TNetFeatureRepo *repo);
int tnet_feature_repo_is_end(TNetFeatureRepo *repo);
void tnet_feature_repo_destroy(TNetFeatureRepo *repo);
+ TNetFeatureRepo *tnet_feature_repo_newWithId(long id);
+ long tnet_feature_repo_id(TNetFeatureRepo *repo);
+
+
typedef struct TNetLabelRepo TNetLabelRepo;
@@ -31,6 +35,11 @@ extern "C" {
int debug);
void tnet_label_repo_destroy(TNetLabelRepo *repo);
+
+ TNetLabelRepo *tnet_label_repo_newWithId(long id);
+
+ long tnet_label_repo_id(TNetLabelRepo *repo);
+
#ifdef __cplusplus
}
#endif
diff --git a/htk_io/src/init.c b/htk_io/src/init.c
index 8a1ec3b..04046e9 100644
--- a/htk_io/src/init.c
+++ b/htk_io/src/init.c
@@ -7,13 +7,37 @@ const char *nerv_tnet_label_repo_tname = "nerv.TNetLabelRepo";
const char *nerv_matrix_host_float_tname = "nerv.MMatrixFloat";
static int feat_repo_new(lua_State *L) {
- const char *scp_file = luaL_checkstring(L, 1);
- const char *conf = luaL_checkstring(L, 2);
- int frm_ext = luaL_checkinteger(L, 3);
- TNetFeatureRepo *repo = tnet_feature_repo_new(scp_file, conf, frm_ext);
+ TNetFeatureRepo *repo = NULL;
+ if(lua_gettop(L) == 1)
+ {
+ long id = luaL_checkinteger(L, 1);
+ repo = tnet_feature_repo_newWithId(id);
+ }
+ else
+ {
+ const char *scp_file = luaL_checkstring(L, 1);
+ const char *conf = luaL_checkstring(L, 2);
+ int frm_ext = luaL_checkinteger(L, 3);
+ repo = tnet_feature_repo_new(scp_file, conf, frm_ext);
+ }
+
luaT_pushudata(L, repo, nerv_tnet_feat_repo_tname);
return 1;
}
+static int feat_repo_id(lua_State *L) {
+ TNetFeatureRepo *repo = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
+ lua_pushinteger(L, tnet_feature_repo_id(repo));
+ return 1;
+}
+
+static int feat_repo_tostring(lua_State *L)
+{
+ char str[128];
+ TNetFeatureRepo *repo = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
+ snprintf(str, 128, "%s <%lx>", nerv_tnet_feat_repo_tname, tnet_feature_repo_id(repo));
+ lua_pushstring(L, str);
+ return 1;
+}
static int feat_repo_destroy(lua_State *L) {
TNetFeatureRepo *repo = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
@@ -55,22 +79,40 @@ static const luaL_Reg feat_repo_methods[] = {
{"cur_tag", feat_repo_current_tag},
{"next", feat_repo_next},
{"is_end", feat_repo_is_end},
+ {"__tostring", feat_repo_tostring},
+ {"id", feat_repo_id},
{NULL, NULL}
};
static int label_repo_new(lua_State *L) {
- const char *mlf_file = luaL_checkstring(L, 1);
- const char *fmt = luaL_checkstring(L, 2);
- const char *arg = luaL_checkstring(L, 3);
- const char *dir = luaL_checkstring(L, 4);
- const char *ext = luaL_checkstring(L, 5);
- TNetLabelRepo *repo = tnet_label_repo_new(
+ TNetLabelRepo *repo = NULL;
+ if(lua_gettop(L) == 1)
+ {
+ long id = luaL_checkinteger(L, 1);
+ repo = tnet_label_repo_newWithId(id);
+ }
+ else
+ {
+ const char *mlf_file = luaL_checkstring(L, 1);
+ const char *fmt = luaL_checkstring(L, 2);
+ const char *arg = luaL_checkstring(L, 3);
+ const char *dir = luaL_checkstring(L, 4);
+ const char *ext = luaL_checkstring(L, 5);
+ repo = tnet_label_repo_new(
mlf_file, fmt, arg,
dir, ext);
+ }
luaT_pushudata(L, repo, nerv_tnet_label_repo_tname);
return 1;
}
+static int label_repo_id(lua_State *L) {
+ TNetLabelRepo *repo = luaT_checkudata(L, 1, nerv_tnet_label_repo_tname);
+ lua_pushinteger(L, tnet_label_repo_id(repo));
+ return 1;
+}
+
+
static int label_repo_read_utterance(lua_State *L) {
TNetLabelRepo *repo = luaT_checkudata(L, 1, nerv_tnet_label_repo_tname);
TNetFeatureRepo *feat_repo = luaT_checkudata(L, 2, nerv_tnet_feat_repo_tname);
@@ -95,6 +137,7 @@ static int label_repo_destroy(lua_State *L) {
static const luaL_Reg label_repo_methods[] = {
{"get_utter", label_repo_read_utterance},
+ {"id", label_repo_id},
{NULL, NULL}
};