#include "KaldiLib/Features.h"
#include "KaldiLib/Labels.h"
#include "KaldiLib/Common.h"
#include "KaldiLib/UserInterface.h"
#include <string>
#include <vector>
#include <stdio.h>
#include <stdlib.h>
#define SNAME "TNET"
extern "C" {
#include "cwrapper.h"
#include "string.h"
#include "nerv/lib/common.h"
#include "nerv/lib/matrix/mmatrix.h"
extern Matrix *nerv_matrix_host_float_create(long nrow, long ncol, MContext *context, Status *status);
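
/* Thin C-visible handle around TNet's FeatureRepository: it keeps the
 * feature-reading parameters parsed from the TNet config file together with a
 * host-side KaldiLib matrix used as a staging buffer for each utterance. */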
struct TNetFeatureRepo {
    TNet::FeatureRepository feature_repo;
    TNet::UserInterface ui;
    bool swap_features;
    int target_kind;
    int deriv_order;
    int* p_deriv_win_lenghts;
    int start_frm_ext;
    int end_frm_ext;
    char* cmn_path;
    char* cmn_file;
    const char* cmn_mask;
    char* cvn_path;
    char* cvn_file;
    const char* cvn_mask;
    const char* cvg_file;
    TNet::Matrix<float> feats_host; /* KaldiLib implementation */
};
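
/* Create a feature repository from a script (.scp) file list and a TNet-style
 * config file; `context` overrides the configured start/end frame extension. */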
TNetFeatureRepo *tnet_feature_repo_new(const char *p_script, const char *config, int context) {
    TNetFeatureRepo *repo = new TNetFeatureRepo();
    repo->ui.ReadConfig(config);
    repo->swap_features = !repo->ui.GetBool(SNAME":NATURALREADORDER", TNet::IsBigEndian());
    /* load feature-reading defaults from the config file */
    repo->target_kind = repo->ui.GetFeatureParams(&repo->deriv_order,
            &repo->p_deriv_win_lenghts,
            &repo->start_frm_ext, &repo->end_frm_ext,
            &repo->cmn_path, &repo->cmn_file, &repo->cmn_mask,
            &repo->cvn_path, &repo->cvn_file, &repo->cvn_mask,
            &repo->cvg_file, SNAME":", 0);
    /* override the configured frame extension with the caller-supplied context size */
    repo->start_frm_ext = repo->end_frm_ext = context;
    repo->feature_repo.Init(repo->swap_features,
            repo->start_frm_ext, repo->end_frm_ext, repo->target_kind,
            repo->deriv_order, repo->p_deriv_win_lenghts,
            repo->cmn_path, repo->cmn_mask,
            repo->cvn_path, repo->cvn_mask, repo->cvg_file);
    repo->feature_repo.AddFileList(p_script);
    repo->feature_repo.Rewind();
    return repo;
}
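
/* Read the current utterance into a newly created nerv host float matrix
 * (one frame per row).  Matrix-creation errors are raised through
 * NERV_LUA_CHECK_STATUS on the given lua_State. */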
Matrix *tnet_feature_repo_read_utterance(TNetFeatureRepo *repo, lua_State *L,
                                         int debug, MContext *context) {
    Matrix *mat; /* nerv implementation */
    repo->feature_repo.ReadFullMatrix(repo->feats_host);
    std::string utter_str = repo->feature_repo.Current().Logical();
    repo->feats_host.CheckData(utter_str);
    int n = repo->feats_host.Rows();
    int m = repo->feats_host.Cols();
    Status status;
    mat = nerv_matrix_host_float_create(n, m, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    size_t stride = mat->stride;
    if (debug)
        fprintf(stderr, "[tnet] feature: %s %d %d\n", utter_str.c_str(), n, m);
    for (int i = 0; i < n; i++)
    {
        float *row = repo->feats_host.pRowData(i);
        float *nerv_row = (float *)((char *)mat->data.f + i * stride);
        /* copy one row at a time: the KaldiLib and nerv matrices may pad their
         * rows differently, so the data cannot be copied as one contiguous block */
        memmove(nerv_row, row, sizeof(float) * m);
    }
    return mat;
}
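
/* Iteration helpers: advance to the next utterance in the script list and
 * test whether the list has been exhausted. */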
void tnet_feature_repo_next(TNetFeatureRepo *repo) {
    repo->feature_repo.MoveNext();
}
int tnet_feature_repo_is_end(TNetFeatureRepo *repo) {
    return repo->feature_repo.EndOfList();
}
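
/* Note: despite the name, this returns the mSamplePeriod field of the current
 * utterance's HTK header (a sample period, in HTK's 100 ns units, rather than a
 * rate); it is presumably forwarded as the `sample_rate` argument of
 * tnet_label_repo_read_utterance. */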
size_t tnet_feature_repo_current_samplerate(TNetFeatureRepo *repo) {
    return repo->feature_repo.CurrentHeader().mSamplePeriod;
}
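
/* Return the logical name (tag) of the current utterance.  Assuming
 * Current().Logical() hands back a reference to a string owned by the
 * repository, the returned pointer is only valid while the repository is
 * alive; copy it if it must outlive the repository. */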
const char *tnet_feature_repo_current_tag(TNetFeatureRepo *repo) {
    return repo->feature_repo.Current().Logical().c_str();
}
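
/* Release the repository together with the parameter buffers handed out by
 * GetFeatureParams in tnet_feature_repo_new. */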
void tnet_feature_repo_destroy(TNetFeatureRepo *repo) {
    /* the CMN/CVN path buffers are freed only when the corresponding mask was
     * configured, i.e. when GetFeatureParams is expected to have allocated them */
    if (repo->cmn_mask)
        free(repo->cmn_path);
    if (repo->cvn_mask)
        free(repo->cvn_path);
    free(repo->p_deriv_win_lenghts);
    delete repo;
}
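
/* Thin C-visible handle around TNet's LabelRepository (MLF-driven label reading). */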
struct TNetLabelRepo {
    TNet::LabelRepository label_repo;
};
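
/* Create a label repository from an MLF; `fmt` and `fmt_arg` select the label
 * format, while `dir` and `ext` are presumably the HTK-style label directory
 * and extension used to resolve per-utterance label files. */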
TNetLabelRepo *tnet_label_repo_new(const char *mlf, const char *fmt,
                                   const char *fmt_arg, const char *dir,
                                   const char *ext) {
    TNetLabelRepo *repo = new TNetLabelRepo();
    repo->label_repo.InitExt(mlf, fmt, fmt_arg, dir, ext);
    /* repo->label_repo.Init(mlf, fmt_arg, dir, ext); */
    return repo;
}
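
/* Expand the labels of utterance `tag` into a desired-output matrix (one row
 * per frame) via GenDesiredMatrixExt and copy it into a newly created nerv
 * host float matrix, mirroring tnet_feature_repo_read_utterance. */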
Matrix *tnet_label_repo_read_utterance(TNetLabelRepo *repo,
                                       size_t frames,
                                       size_t sample_rate,
                                       const char *tag,
                                       lua_State *L,
                                       int debug,
                                       MContext *context) {
    std::vector<TNet::Matrix<float> > labs_hosts; /* KaldiLib implementation */
    Matrix *mat;
    repo->label_repo.GenDesiredMatrixExt(labs_hosts, frames,
                                         sample_rate, tag);
    int n = labs_hosts[0].Rows();
    int m = labs_hosts[0].Cols();
    Status status;
    mat = nerv_matrix_host_float_create(n, m, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    size_t stride = mat->stride;
    if (debug)
        fprintf(stderr, "[tnet] label: %s %d %d\n", tag, n, m);
    for (int i = 0; i < n; i++)
    {
        float *row = labs_hosts[0].pRowData(i);
        float *nerv_row = (float *)((char *)mat->data.f + i * stride);
        /* copy one row at a time: the KaldiLib and nerv matrices may pad their
         * rows differently, so the data cannot be copied as one contiguous block */
        memmove(nerv_row, row, sizeof(float) * m);
    }
    return mat;
}
void tnet_label_repo_destroy(TNetLabelRepo *repo) {
    delete repo;
}
}
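
/*
 * A minimal usage sketch (not part of the wrapper): drive the feature
 * repository from C in a read loop.  The script name "train.scp", config name
 * "tnet.cfg" and the context size 5 are placeholders, and the lua_State `L`
 * and MContext `context` are assumed to be set up by the surrounding nerv
 * bindings.
 *
 *   TNetFeatureRepo *frepo = tnet_feature_repo_new("train.scp", "tnet.cfg", 5);
 *   while (!tnet_feature_repo_is_end(frepo))
 *   {
 *       Matrix *feat = tnet_feature_repo_read_utterance(frepo, L, 0, context);
 *       // ... hand `feat` over to the training code ...
 *       tnet_feature_repo_next(frepo);
 *   }
 *   tnet_feature_repo_destroy(frepo);
 */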