summaryrefslogtreecommitdiff
path: root/htk_io
diff options
context:
space:
mode:
authoruphantom <[email protected]>2015-08-28 18:37:12 +0800
committeruphantom <[email protected]>2015-08-28 18:37:12 +0800
commite56c91ff6eecdb1663bb1722a4081ef2f190e9c0 (patch)
tree06cbe7394ba26c6a2657fddc1ebc59006fa5eee5 /htk_io
parent9e1a0931be43ea80fe7d41154007839b637d4e08 (diff)
suport multi-thread reader
Diffstat (limited to 'htk_io')
-rw-r--r--htk_io/init.lua30
-rw-r--r--htk_io/src/cwrapper.cpp49
-rw-r--r--htk_io/src/cwrapper.h9
-rw-r--r--htk_io/src/init.c63
-rw-r--r--htk_io/tools/tnet_to_nerv.cpp78
5 files changed, 202 insertions, 27 deletions
diff --git a/htk_io/init.lua b/htk_io/init.lua
index b360b67..677d3e9 100644
--- a/htk_io/init.lua
+++ b/htk_io/init.lua
@@ -1,8 +1,9 @@
require 'libhtkio'
require 'speech_utils'
+require 'threads'
local TNetReader = nerv.class("nerv.TNetReader", "nerv.DataReader")
-function TNetReader:__init(global_conf, reader_conf)
+function TNetReader:__init(global_conf, reader_conf, feat_repo_shareid, data_mutex_shareid)
self.feat_id = reader_conf.id
self.frm_ext = reader_conf.frm_ext
self.gconf = global_conf
@@ -10,9 +11,22 @@ function TNetReader:__init(global_conf, reader_conf)
if self.debug == nil then
self.debug = false
end
- self.feat_repo = nerv.TNetFeatureRepo(reader_conf.scp_file,
+
+ if feat_repo_shareid ~= nil then
+ self.feat_repo = nerv.TNetFeatureRepo(feat_repo_shareid)
+ else
+ self.feat_repo = nerv.TNetFeatureRepo(reader_conf.scp_file,
reader_conf.conf_file,
reader_conf.frm_ext)
+ end
+
+ --print(self.feat_repo)
+ self.data_mutex = nil
+ if data_mutex_shareid ~= nil then
+ self.data_mutex = threads.Mutex(data_mutex_shareid)
+ end
+
+
self.lab_repo = {}
if reader_conf.mlfs then
for id, mlf_spec in pairs(reader_conf.mlfs) do
@@ -26,7 +40,14 @@ function TNetReader:__init(global_conf, reader_conf)
end
function TNetReader:get_data()
+ if self.data_mutex ~= nil then
+ self.data_mutex:lock()
+ end
+
if self.feat_repo:is_end() then
+ if self.data_mutex ~= nil then
+ self.data_mutex:unlock()
+ end
return nil
end
local res = {}
@@ -66,6 +87,11 @@ function TNetReader:get_data()
end
-- move the pointer to next
self.feat_repo:next()
+
+ if self.data_mutex ~= nil then
+ self.data_mutex:unlock()
+ end
+
collectgarbage("collect")
return res
end
diff --git a/htk_io/src/cwrapper.cpp b/htk_io/src/cwrapper.cpp
index b7ce2d5..66cde23 100644
--- a/htk_io/src/cwrapper.cpp
+++ b/htk_io/src/cwrapper.cpp
@@ -8,6 +8,7 @@
extern "C" {
#include "cwrapper.h"
#include "string.h"
+#include "pthread.h"
#include "nerv/common.h"
extern Matrix *nerv_matrix_host_float_create(long nrow, long ncol, Status *status);
@@ -29,6 +30,7 @@ extern "C" {
const char* cvn_mask;
const char* cvg_file;
TNet::Matrix<float> feats_host; /* KaldiLib implementation */
+ int refcount;
};
TNetFeatureRepo *tnet_feature_repo_new(const char *p_script, const char *config, int context) {
@@ -53,6 +55,18 @@ extern "C" {
return repo;
}
+ TNetFeatureRepo *tnet_feature_repo_newWithId(long id)
+ {
+ TNetFeatureRepo *repo = (TNetFeatureRepo*)id;
+ __sync_fetch_and_add(&repo->refcount, 1);
+ return repo;
+ }
+
+ long tnet_feature_repo_id(TNetFeatureRepo *repo)
+ {
+ return (long)(repo);
+ }
+
Matrix *tnet_feature_repo_read_utterance(TNetFeatureRepo *repo, lua_State *L, int debug) {
Matrix *mat; /* nerv implementation */
repo->feature_repo.ReadFullMatrix(repo->feats_host);
@@ -93,12 +107,19 @@ extern "C" {
}
void tnet_feature_repo_destroy(TNetFeatureRepo *repo) {
- if (repo->cmn_mask)
- free(repo->cmn_path);
- if (repo->cvn_mask)
- free(repo->cvn_path);
- free(repo->p_deriv_win_lenghts);
- delete repo;
+ if (NULL != repo)
+ {
+ if(__sync_fetch_and_add(&repo->refcount, -1) == 1)
+ {
+ if (repo->cmn_mask)
+ free(repo->cmn_path);
+ if (repo->cvn_mask)
+ free(repo->cvn_path);
+ free(repo->p_deriv_win_lenghts);
+ delete repo;
+ repo = NULL;
+ }
+ }
}
struct TNetLabelRepo {
@@ -114,6 +135,16 @@ extern "C" {
return repo;
}
+ TNetLabelRepo *tnet_label_repo_newWithId(long id)
+ {
+ return (TNetLabelRepo*)id;
+ }
+
+ long tnet_label_repo_id(TNetLabelRepo *repo)
+ {
+ return (long)(repo);
+ }
+
Matrix *tnet_label_repo_read_utterance(TNetLabelRepo *repo,
size_t frames,
size_t sample_rate,
@@ -143,6 +174,10 @@ extern "C" {
}
void tnet_label_repo_destroy(TNetLabelRepo *repo) {
- delete repo;
+ if (NULL != repo)
+ {
+ delete repo;
+ repo = NULL;
+ }
}
}
diff --git a/htk_io/src/cwrapper.h b/htk_io/src/cwrapper.h
index e1bce6e..44e77cf 100644
--- a/htk_io/src/cwrapper.h
+++ b/htk_io/src/cwrapper.h
@@ -16,6 +16,10 @@ extern "C" {
void tnet_feature_repo_next(TNetFeatureRepo *repo);
int tnet_feature_repo_is_end(TNetFeatureRepo *repo);
void tnet_feature_repo_destroy(TNetFeatureRepo *repo);
+ TNetFeatureRepo *tnet_feature_repo_newWithId(long id);
+ long tnet_feature_repo_id(TNetFeatureRepo *repo);
+
+
typedef struct TNetLabelRepo TNetLabelRepo;
@@ -31,6 +35,11 @@ extern "C" {
int debug);
void tnet_label_repo_destroy(TNetLabelRepo *repo);
+
+ TNetLabelRepo *tnet_label_repo_newWithId(long id);
+
+ long tnet_label_repo_id(TNetLabelRepo *repo);
+
#ifdef __cplusplus
}
#endif
diff --git a/htk_io/src/init.c b/htk_io/src/init.c
index 8a1ec3b..04046e9 100644
--- a/htk_io/src/init.c
+++ b/htk_io/src/init.c
@@ -7,13 +7,37 @@ const char *nerv_tnet_label_repo_tname = "nerv.TNetLabelRepo";
const char *nerv_matrix_host_float_tname = "nerv.MMatrixFloat";
static int feat_repo_new(lua_State *L) {
- const char *scp_file = luaL_checkstring(L, 1);
- const char *conf = luaL_checkstring(L, 2);
- int frm_ext = luaL_checkinteger(L, 3);
- TNetFeatureRepo *repo = tnet_feature_repo_new(scp_file, conf, frm_ext);
+ TNetFeatureRepo *repo = NULL;
+ if(lua_gettop(L) == 1)
+ {
+ long id = luaL_checkinteger(L, 1);
+ repo = tnet_feature_repo_newWithId(id);
+ }
+ else
+ {
+ const char *scp_file = luaL_checkstring(L, 1);
+ const char *conf = luaL_checkstring(L, 2);
+ int frm_ext = luaL_checkinteger(L, 3);
+ repo = tnet_feature_repo_new(scp_file, conf, frm_ext);
+ }
+
luaT_pushudata(L, repo, nerv_tnet_feat_repo_tname);
return 1;
}
+static int feat_repo_id(lua_State *L) {
+ TNetFeatureRepo *repo = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
+ lua_pushinteger(L, tnet_feature_repo_id(repo));
+ return 1;
+}
+
+static int feat_repo_tostring(lua_State *L)
+{
+ char str[128];
+ TNetFeatureRepo *repo = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
+ snprintf(str, 128, "%s <%lx>", nerv_tnet_feat_repo_tname, tnet_feature_repo_id(repo));
+ lua_pushstring(L, str);
+ return 1;
+}
static int feat_repo_destroy(lua_State *L) {
TNetFeatureRepo *repo = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
@@ -55,22 +79,40 @@ static const luaL_Reg feat_repo_methods[] = {
{"cur_tag", feat_repo_current_tag},
{"next", feat_repo_next},
{"is_end", feat_repo_is_end},
+ {"__tostring", feat_repo_tostring},
+ {"id", feat_repo_id},
{NULL, NULL}
};
static int label_repo_new(lua_State *L) {
- const char *mlf_file = luaL_checkstring(L, 1);
- const char *fmt = luaL_checkstring(L, 2);
- const char *arg = luaL_checkstring(L, 3);
- const char *dir = luaL_checkstring(L, 4);
- const char *ext = luaL_checkstring(L, 5);
- TNetLabelRepo *repo = tnet_label_repo_new(
+ TNetLabelRepo *repo = NULL;
+ if(lua_gettop(L) == 1)
+ {
+ long id = luaL_checkinteger(L, 1);
+ repo = tnet_label_repo_newWithId(id);
+ }
+ else
+ {
+ const char *mlf_file = luaL_checkstring(L, 1);
+ const char *fmt = luaL_checkstring(L, 2);
+ const char *arg = luaL_checkstring(L, 3);
+ const char *dir = luaL_checkstring(L, 4);
+ const char *ext = luaL_checkstring(L, 5);
+ repo = tnet_label_repo_new(
mlf_file, fmt, arg,
dir, ext);
+ }
luaT_pushudata(L, repo, nerv_tnet_label_repo_tname);
return 1;
}
+static int label_repo_id(lua_State *L) {
+ TNetLabelRepo *repo = luaT_checkudata(L, 1, nerv_tnet_label_repo_tname);
+ lua_pushinteger(L, tnet_label_repo_id(repo));
+ return 1;
+}
+
+
static int label_repo_read_utterance(lua_State *L) {
TNetLabelRepo *repo = luaT_checkudata(L, 1, nerv_tnet_label_repo_tname);
TNetFeatureRepo *feat_repo = luaT_checkudata(L, 2, nerv_tnet_feat_repo_tname);
@@ -95,6 +137,7 @@ static int label_repo_destroy(lua_State *L) {
static const luaL_Reg label_repo_methods[] = {
{"get_utter", label_repo_read_utterance},
+ {"id", label_repo_id},
{NULL, NULL}
};
diff --git a/htk_io/tools/tnet_to_nerv.cpp b/htk_io/tools/tnet_to_nerv.cpp
index a779a25..f96781a 100644
--- a/htk_io/tools/tnet_to_nerv.cpp
+++ b/htk_io/tools/tnet_to_nerv.cpp
@@ -2,14 +2,22 @@
#include <fstream>
#include <string>
#include <cstring>
-#include <cstdlib>
+#include <stdlib.h>
char token[1024];
char output[1024];
double **mat;
int main(int argc, char **argv) {
+
+ if (argc != 3)
+ {
+ printf("%s tnet.model.in nerv.model.out\n", argv[0]);
+ }
+
std::ofstream fout;
- fout.open(argv[1]);
- int cnt = 0;
+ freopen(argv[1], "r", stdin);
+ fout.open(argv[2]);
+ int cnt = 0, bias = 1, win = 1;
+ long length = 0, base = 0;
while (scanf("%s", token) != EOF)
{
int nrow, ncol;
@@ -19,13 +27,13 @@ int main(int argc, char **argv) {
scanf("%d %d", &ncol, &nrow);
scanf("%s %d %d", token, &ncol, &nrow);
printf("%d %d\n", nrow, ncol);
- mat = (double **)malloc(nrow * sizeof(double *));
+ mat = (double **)malloc(nrow * sizeof(double *));
for (i = 0; i < nrow; i++)
mat[i] = (double *)malloc(ncol * sizeof(double));
for (j = 0; j < ncol; j++)
for (i = 0; i < nrow; i++)
scanf("%lf", mat[i] + j);
- long base = fout.tellp();
+ base = fout.tellp();
sprintf(output, "%16d", 0);
fout << output;
sprintf(output, "{type=\"nerv.LinearTransParam\",id=\"affine%d_ltp\"}\n",
@@ -38,10 +46,8 @@ int main(int argc, char **argv) {
for (j = 0; j < ncol; j++)
fout << mat[i][j] << " ";
fout << std::endl;
- free(mat[i]);
}
- free(mat);
- long length = fout.tellp() - base;
+ length = fout.tellp() - base;
fout.seekp(base);
sprintf(output, "[%13lu]\n", length);
fout << output;
@@ -69,6 +75,62 @@ int main(int argc, char **argv) {
cnt++;
}
}
+ else if (strcmp(token, "<bias>") == 0)
+ {
+ scanf("%d %d", &ncol, &nrow);
+ scanf("%s %d", token, &ncol);
+ base = fout.tellp();
+ nrow = 1;
+ mat = (double **)malloc(nrow * sizeof(double *));
+ for (i = 0; i < nrow; i++)
+ mat[i] = (double *)malloc(ncol * sizeof(double));
+ for (j = 0; j < ncol; j++)
+ for (i = 0; i < nrow; i++)
+ scanf("%lf", mat[i] + j);
+ sprintf(output, "%16d", 0);
+ fout << output;
+ sprintf(output, "{type=\"nerv.MatrixParam\",id=\"bias%d\"}\n",bias);
+ fout << output;
+ sprintf(output, "1 %d\n", ncol);
+ fout << output;
+ for (j = 0; j < ncol; j++)
+ fout << mat[0][j] << " ";
+ fout << std::endl;
+ length = fout.tellp() - base;
+ fout.seekp(base);
+ sprintf(output, "[%13lu]\n", length);
+ fout << output;
+ fout.seekp(0, std::ios_base::end);
+ bias++;
+ }
+ else if (strcmp(token, "<window>") == 0)
+ {
+ scanf("%d %d", &ncol, &nrow);
+ scanf("%s %d", token, &ncol);
+ base = fout.tellp();
+ nrow = 1;
+ mat = (double **)malloc(nrow * sizeof(double *));
+ for (i = 0; i < nrow; i++)
+ mat[i] = (double *)malloc(ncol * sizeof(double));
+ for (j = 0; j < ncol; j++)
+ for (i = 0; i < nrow; i++)
+ scanf("%lf", mat[i] + j);
+ sprintf(output, "%16d", 0);
+ fout << output;
+ sprintf(output, "{type=\"nerv.MatrixParam\",id=\"window%d\"}\n",win);
+ fout << output;
+ sprintf(output, "1 %d\n", ncol);
+ fout << output;
+ for (j = 0; j < ncol; j++)
+ fout << mat[0][j] << " ";
+ fout << std::endl;
+ length = fout.tellp() - base;
+ fout.seekp(base);
+ sprintf(output, "[%13lu]\n", length);
+ fout << output;
+ fout.seekp(0, std::ios_base::end);
+ win++;
+ }
}
return 0;
}