#include <string>
#include <map>
#include "base/kaldi-common.h"
#include "hmm/posterior.h"
#include "util/table-types.h"
/* Kaldi's configured floating-point element type (the readers below assert it
 * has the size of float). */
typedef kaldi::BaseFloat BaseFloat;
/* table key -> Kaldi feature matrix */
typedef std::map<std::string, kaldi::Matrix<BaseFloat> > StringToMatrix_t;
/* table key -> table key */
typedef std::map<std::string, std::string> StringToString_t;
extern "C" {
#include "cwrapper_kaldi.h"
#include "string.h"
#include "assert.h"
#include "nerv/lib/common.h"
#include "nerv/lib/matrix/mmatrix.h"
extern Matrix *nerv_matrix_host_float_create(long nrow, long ncol, MContext *context, Status *status);
extern Matrix *nerv_matrix_host_double_create(long nrow, long ncol, MContext *context, Status *status);
/* State of a sequential walk over a Kaldi feature table. */
struct KaldiFeatureRepo {
    /* owned: created in kaldi_feature_repo_new, deleted in kaldi_feature_repo_destroy */
    kaldi::SequentialBaseFloatMatrixReader* feature_reader;
    /* key of the utterance most recently loaded by kaldi_feature_repo_read_utterance */
    std::string utt;
};
/*
 * Create a repository that walks sequentially over the feature matrices named
 * by the Kaldi rspecifier `feature_rspecifier`.
 * The caller owns the result and must release it with kaldi_feature_repo_destroy().
 */
KaldiFeatureRepo *kaldi_feature_repo_new(const char *feature_rspecifier) {
    const std::string rspec(feature_rspecifier);
    KaldiFeatureRepo *created = new KaldiFeatureRepo();
    created->feature_reader = new kaldi::SequentialBaseFloatMatrixReader(rspec);
    return created;
}
/*
 * Copy the feature matrix of the reader's current utterance into a newly
 * created nerv host matrix, and remember the utterance key in repo->utt
 * (the label and lookup repos use it to find the matching entries, and
 * kaldi_feature_repo_key() returns it).
 *
 * Does not advance the reader; call kaldi_feature_repo_next() for that.
 * If the reader is already exhausted, the error is reported through
 * nerv_error (assumed, as elsewhere in this file, not to return). The
 * matrix-creation status is checked with NERV_LUA_CHECK_STATUS.
 *
 * @param repo    feature repo created by kaldi_feature_repo_new
 * @param L       Lua state used for error reporting
 * @param debug   if nonzero, print the key and dimensions to stderr
 * @param context nerv matrix context passed to the matrix constructor
 * @return        n x m nerv host matrix holding the utterance features
 */
Matrix *kaldi_feature_repo_read_utterance(KaldiFeatureRepo *repo, lua_State *L,
                                          int debug, MContext *context) {
    Matrix *mat = NULL; /* nerv implementation */
    /* Kaldi only allows Key()/Value() while Done() is false */
    if (repo->feature_reader->Done())
        nerv_error(L, "[kaldi] no more utterances to read");
    repo->utt = repo->feature_reader->Key();
    /* Bind by reference: Kaldi keeps the value alive until Next(), so there is
     * no need to copy the whole utterance, and no object with a destructor is
     * left on the stack if NERV_LUA_CHECK_STATUS unwinds. */
    const kaldi::Matrix<BaseFloat> &kmat = repo->feature_reader->Value();
    int n = kmat.NumRows();
    int m = kmat.NumCols();
    Status status;
    assert(sizeof(BaseFloat) == sizeof(float));
    if (sizeof(BaseFloat) == sizeof(float))
        mat = nerv_matrix_host_float_create(n, m, context, &status);
    else if (sizeof(BaseFloat) == sizeof(double))
        mat = nerv_matrix_host_double_create(n, m, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    size_t stride = mat->stride; /* in bytes */
    if (debug)
        fprintf(stderr, "[kaldi] feature: %s %d %d\n", repo->utt.c_str(), n, m);
    for (int i = 0; i < n; i++)
    {
        const BaseFloat *row = kmat.RowData(i);
        BaseFloat *nerv_row = (BaseFloat *)((char *)mat->data.f + i * stride);
        /* Rows are copied one at a time because the nerv row pitch (stride)
         * may differ from m * sizeof(BaseFloat); source and destination are
         * distinct buffers, so memcpy is sufficient. */
        memcpy(nerv_row, row, sizeof(BaseFloat) * m);
    }
    return mat;
}
/* Advance the underlying feature reader to the next utterance. */
void kaldi_feature_repo_next(KaldiFeatureRepo *repo) {
    kaldi::SequentialBaseFloatMatrixReader &reader = *repo->feature_reader;
    reader.Next();
}
/* Return 1 once the feature reader reports Done(), 0 otherwise. */
int kaldi_feature_repo_is_end(KaldiFeatureRepo *repo) {
    const bool exhausted = repo->feature_reader->Done();
    return exhausted ? 1 : 0;
}
/*
 * Key of the utterance most recently loaded by kaldi_feature_repo_read_utterance.
 * The pointer refers to repo->utt's internal buffer, so it may be invalidated
 * by the next read and is invalid after kaldi_feature_repo_destroy.
 */
const char *kaldi_feature_repo_key(KaldiFeatureRepo *repo) {
    const std::string &current = repo->utt;
    return current.c_str();
}
/* Release a repo created by kaldi_feature_repo_new, including its reader. */
void kaldi_feature_repo_destroy(KaldiFeatureRepo *repo) {
    delete repo->feature_reader; /* deleting a null pointer is a no-op */
    repo->feature_reader = NULL;
    delete repo;
}
struct KaldiLookupFeatureRepo {
StringToMatrix_t key2mat;
StringToString_t map;
};
/*
 * Build a lookup repository by eagerly loading two Kaldi tables:
 *   - every feature matrix of `feature_rspecifier`, stored in key2mat under
 *     its table key (a repeated key overwrites the earlier matrix, with a
 *     warning on stderr);
 *   - the token-vector table `map_rspecifier`; only the first token of each
 *     entry is kept, in map (utterance key -> key into key2mat). Entries with
 *     no tokens are skipped with a warning on stderr.
 *
 * Both readers live on the stack and the repo is allocated only after loading
 * finished, so neither the readers nor the repo leak if a Kaldi reader throws.
 * The caller owns the result; release it with kaldi_lookup_feature_repo_destroy().
 */
KaldiLookupFeatureRepo *kaldi_lookup_feature_repo_new(const char *feature_rspecifier, const char *map_rspecifier) {
    StringToMatrix_t features;
    StringToString_t key_map;
    {
        const std::string feature_rspec(feature_rspecifier);
        kaldi::SequentialBaseFloatMatrixReader feature_reader(feature_rspec);
        for (; !feature_reader.Done(); feature_reader.Next())
        {
            const std::string &key = feature_reader.Key();
            const kaldi::Matrix<BaseFloat> &feat = feature_reader.Value();
            if (features.find(key) != features.end())
                fprintf(stderr, "[kaldi] warning: lookup feature for key %s already exists\n", key.c_str());
            features[key] = feat;
        }
    }
    {
        const std::string map_rspec(map_rspecifier);
        kaldi::SequentialTokenVectorReader map_reader(map_rspec);
        for (; !map_reader.Done(); map_reader.Next())
        {
            /* bind by reference to avoid copying the token vector */
            const std::vector<std::string> &target = map_reader.Value();
            if (target.empty())
            {
                /* previously only an assert; indexing an empty vector is UB */
                fprintf(stderr, "[kaldi] warning: empty mapping for key %s, skipped\n",
                        map_reader.Key().c_str());
                continue;
            }
            key_map[map_reader.Key()] = target.front();
        }
    }
    KaldiLookupFeatureRepo *repo = new KaldiLookupFeatureRepo();
    /* O(1) hand-over of the loaded tables */
    repo->key2mat.swap(features);
    repo->map.swap(key_map);
    return repo;
}
/*
 * Fetch the lookup feature matrix that belongs to the utterance last read by
 * `frepo`, and copy at most `nframes` of its rows into a new nerv host matrix.
 *
 * frepo->utt is first translated through repo->map; the resulting key then
 * selects a matrix in repo->key2mat. A failed lookup at either step is
 * reported through nerv_error. The matrix-creation status is checked with
 * NERV_LUA_CHECK_STATUS.
 */
Matrix *kaldi_lookup_feature_repo_read_utterance(KaldiLookupFeatureRepo *repo,
                                                 KaldiFeatureRepo *frepo,
                                                 int nframes, lua_State *L,
                                                 int debug, MContext *context) {
    const std::string &utt = frepo->utt;
    StringToString_t::iterator mapped = repo->map.find(utt);
    if (mapped == repo->map.end())
        nerv_error(L, "[kaldi] mapped key for key %s not found", utt.c_str());
    const std::string &target_key = mapped->second;

    StringToMatrix_t::iterator found = repo->key2mat.find(target_key);
    if (found == repo->key2mat.end())
        nerv_error(L, "[kaldi] lookup feature for key %s not found", target_key.c_str());
    const kaldi::Matrix<BaseFloat> &src = found->second;

    /* keep the first min(NumRows, nframes) rows */
    int rows = src.NumRows();
    if (nframes <= rows)
        rows = nframes;
    int cols = src.NumCols();

    Status status;
    Matrix *mat = NULL; /* nerv implementation */
    assert(sizeof(BaseFloat) == sizeof(float));
    if (sizeof(BaseFloat) == sizeof(float))
        mat = nerv_matrix_host_float_create(rows, cols, context, &status);
    else if (sizeof(BaseFloat) == sizeof(double))
        mat = nerv_matrix_host_double_create(rows, cols, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);

    if (debug)
        fprintf(stderr, "[kaldi] lookup feature: %s %d %d\n", utt.c_str(), rows, cols);

    /* nerv rows are `stride` bytes apart; copy one Kaldi row at a time */
    char *dst = (char *)mat->data.f;
    const size_t row_bytes = sizeof(BaseFloat) * cols;
    for (int r = 0; r < rows; ++r, dst += mat->stride)
        memcpy(dst, src.RowData(r), row_bytes);
    return mat;
}
/* Release a repo from kaldi_lookup_feature_repo_new; its maps free the loaded tables. */
void kaldi_lookup_feature_repo_destroy(KaldiLookupFeatureRepo *repo) {
    KaldiLookupFeatureRepo *doomed = repo;
    delete doomed;
}
/* Random-access view of a Kaldi posterior (label) table, keyed by utterance. */
struct KaldiLabelRepo {
    /* owned: created in kaldi_label_repo_new, deleted in kaldi_label_repo_destroy */
    kaldi::RandomAccessPosteriorReader* targets_reader;
};
/*
 * Open the posterior table `targets_rspecifier` for random access by key.
 * `fmt` is accepted for interface compatibility but not used here.
 * The caller owns the result; release it with kaldi_label_repo_destroy().
 */
KaldiLabelRepo *kaldi_label_repo_new(const char *targets_rspecifier, const char *fmt) {
    (void)fmt;
    const std::string rspec(targets_rspecifier);
    KaldiLabelRepo *created = new KaldiLabelRepo();
    created->targets_reader = new kaldi::RandomAccessPosteriorReader(rspec);
    return created;
}
/*
 * Load the targets of the utterance last read by `frepo` (key frepo->utt)
 * into a newly created nerv host matrix.
 *
 * Row i of the result holds the `first` member of each (id, weight) pair of
 * frame i of the Kaldi posterior, converted to BaseFloat; weights are not
 * copied. The column count is the number of pairs in frame 0 (0 for an empty
 * posterior). At most `nframes` rows are copied; a negative `nframes` copies
 * every frame.
 *
 * Errors are reported through nerv_error (assumed, as elsewhere in this file,
 * not to return): the key is missing from the table, or a copied frame has
 * fewer pairs than frame 0. The matrix-creation status is checked with
 * NERV_LUA_CHECK_STATUS.
 */
Matrix *kaldi_label_repo_read_utterance(KaldiLabelRepo *repo, KaldiFeatureRepo *frepo, int nframes,
                                        lua_State *L,
                                        int debug, MContext *context) {
    Matrix *mat = NULL;
    /* Value() on a missing key is a fatal Kaldi error (KALDI_ERR); report it
     * as a nerv/Lua error instead. */
    if (!repo->targets_reader->HasKey(frepo->utt))
        nerv_error(L, "[kaldi] label for key %s not found", frepo->utt.c_str());
    /* Bind by reference: avoids copying the posterior, and leaves no object
     * with a destructor on the stack if an error unwinds. */
    const kaldi::Posterior &targets = repo->targets_reader->Value(frepo->utt);
    size_t total = targets.size();
    /* negative nframes means "all frames" */
    int n = (nframes >= 0 && (size_t)nframes < total) ? nframes : (int)total;
    /* targets[0] must not be touched for an empty posterior */
    int m = total > 0 ? (int)targets[0].size() : 0;
    /* refuse ragged input instead of reading past the end of a short frame */
    for (int i = 0; i < n; i++)
        if ((int)targets[i].size() < m)
            nerv_error(L, "[kaldi] label: frame %d of %s has %d targets, expected %d",
                       i, frepo->utt.c_str(), (int)targets[i].size(), m);
    Status status;
    assert(sizeof(BaseFloat) == sizeof(float));
    if (sizeof(BaseFloat) == sizeof(float))
        mat = nerv_matrix_host_float_create(n, m, context, &status);
    else if (sizeof(BaseFloat) == sizeof(double))
        mat = nerv_matrix_host_double_create(n, m, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    size_t stride = mat->stride; /* in bytes */
    if (debug)
        fprintf(stderr, "[kaldi] label: %s %d %d\n", frepo->utt.c_str(), n, m);
    for (int i = 0; i < n; i++)
    {
        /* Locate the row by byte stride, then index elements in BaseFloat
         * units (the old code added j as a byte offset, corrupting values). */
        BaseFloat *nerv_row = (BaseFloat *)((char *)mat->data.f + i * stride);
        for (int j = 0; j < m; j++)
            nerv_row[j] = (BaseFloat)targets[i][j].first;
    }
    return mat;
}
/* Release a repo created by kaldi_label_repo_new, including its reader. */
void kaldi_label_repo_destroy(KaldiLabelRepo *repo) {
    delete repo->targets_reader; /* deleting a null pointer is a no-op */
    repo->targets_reader = NULL;
    delete repo;
}
}