summaryrefslogblamecommitdiff
path: root/tnet_io/init.c
blob: 7195eb7b9123a5f4da900e030a1862d644126a9c (plain) (tree)






































































































                                                                                  
#include "../../common.h"
#include "cwrapper.h"
#include <stdio.h>

/* Type name of the host float matrix userdata; defined in another unit. */
extern const char *nerv_matrix_host_float_tname;

/* Type names under which the two repository userdata kinds are registered. */
const char *nerv_tnet_feat_repo_tname  = "nerv.TNetFeatureRepo";
const char *nerv_tnet_label_repo_tname = "nerv.TNetLabelRepo";

/*
 * Constructor registered for nerv.TNetFeatureRepo.
 *
 * Lua arguments:
 *   1: string  - scp file
 *   2: string  - configuration
 *   3: integer - forwarded as the third argument of tnet_feature_repo_new
 *
 * Pushes the newly created repository as userdata and returns 1.
 */
static int feat_repo_new(lua_State *L) {
    const char *scp_path = luaL_checkstring(L, 1);
    const char *config = luaL_checkstring(L, 2);
    int frame_ext = luaL_checkinteger(L, 3);

    luaT_pushudata(L,
                   tnet_feature_repo_new(scp_path, config, frame_ext),
                   nerv_tnet_feat_repo_tname);
    return 1;
}

/*
 * Destructor registered for nerv.TNetFeatureRepo: checks that argument 1
 * is a feature repository and hands it to tnet_feature_repo_destroy.
 * Returns no values to Lua.
 */
static int feat_repo_destroy(lua_State *L) {
    TNetFeatureRepo *self = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);

    tnet_feature_repo_destroy(self);
    return 0;
}

/*
 * Lua method "cur_tag": pushes the string returned by
 * tnet_feature_repo_current_tag for the repository in argument 1.
 * Returns 1 value.
 */
static int feat_repo_current_tag(lua_State *L) {
    TNetFeatureRepo *self = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
    const char *tag = tnet_feature_repo_current_tag(self);

    lua_pushstring(L, tag);
    return 1;
}

/*
 * Lua method "cur_utter": reads an utterance from the feature repository
 * in argument 1 and pushes the resulting Matrix as userdata under the
 * host float matrix type name. Returns 1 value.
 */
static int feat_repo_current_utterance(lua_State *L) {
    TNetFeatureRepo *self = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);

    luaT_pushudata(L,
                   tnet_feature_repo_read_utterance(self),
                   nerv_matrix_host_float_tname);
    return 1;
}

/*
 * Lua method "next": calls tnet_feature_repo_next on the repository in
 * argument 1. Returns no values to Lua.
 */
static int feat_repo_next(lua_State *L) {
    TNetFeatureRepo *self = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);

    tnet_feature_repo_next(self);
    return 0;
}

/* Methods registered on the nerv.TNetFeatureRepo metatable. */
static const luaL_Reg feat_repo_methods[] = {
    { "cur_utter", feat_repo_current_utterance },
    { "cur_tag",   feat_repo_current_tag },
    { "next",      feat_repo_next },
    { NULL,        NULL }   /* terminating entry */
};

/*
 * Constructor registered for nerv.TNetLabelRepo.
 *
 * Takes five string arguments from Lua (mlf file, fmt, arg, dir, ext) and
 * forwards them, in that order, to tnet_label_repo_new.
 *
 * Pushes the newly created repository as userdata and returns 1.
 */
static int label_repo_new(lua_State *L) {
    const char *mlf_path = luaL_checkstring(L, 1);
    const char *format = luaL_checkstring(L, 2);
    const char *format_arg = luaL_checkstring(L, 3);
    const char *directory = luaL_checkstring(L, 4);
    const char *extension = luaL_checkstring(L, 5);

    luaT_pushudata(L,
                   tnet_label_repo_new(mlf_path, format, format_arg,
                                       directory, extension),
                   nerv_tnet_label_repo_tname);
    return 1;
}

/*
 * Lua method "get_utter" of nerv.TNetLabelRepo.
 *
 * Lua arguments:
 *   1: nerv.TNetLabelRepo   - the label repository
 *   2: nerv.TNetFeatureRepo - supplies the current sample rate and tag
 *   3: integer              - number of frames; must be non-negative
 *
 * Pushes the Matrix returned by tnet_label_repo_read_utterance as userdata
 * under the host float matrix type name and returns 1.
 * Raises a Lua argument error if argument 3 is negative.
 */
static int label_repo_read_utterance(lua_State *L) {
    TNetLabelRepo *repo = luaT_checkudata(L, 1, nerv_tnet_label_repo_tname);
    TNetFeatureRepo *feat_repo = luaT_checkudata(L, 2, nerv_tnet_feat_repo_tname);
    lua_Integer nframes = luaL_checkinteger(L, 3);
    /* Reject negatives before the cast: converting a negative value to
     * size_t would silently produce an enormous frame count. */
    luaL_argcheck(L, nframes >= 0, 3, "number of frames must be non-negative");
    size_t frames = (size_t)nframes;
    Matrix *utter = tnet_label_repo_read_utterance(repo,
                        frames,
                        tnet_feature_repo_current_samplerate(feat_repo),
                        tnet_feature_repo_current_tag(feat_repo));
    luaT_pushudata(L, utter, nerv_matrix_host_float_tname);
    return 1;
}

/*
 * Destructor registered for nerv.TNetLabelRepo: checks that argument 1
 * is a label repository and hands it to tnet_label_repo_destroy.
 * Returns no values to Lua.
 */
static int label_repo_destroy(lua_State *L) {
    TNetLabelRepo *self = luaT_checkudata(L, 1, nerv_tnet_label_repo_tname);

    tnet_label_repo_destroy(self);
    return 0;
}

/* Methods registered on the nerv.TNetLabelRepo metatable. */
static const luaL_Reg label_repo_methods[] = {
    { "get_utter", label_repo_read_utterance },
    { NULL,        NULL }   /* terminating entry */
};

/*
 * Registers the nerv.TNetFeatureRepo type: creates its metatable with
 * feat_repo_new / feat_repo_destroy, adds the entries of
 * feat_repo_methods, then pops one value off the Lua stack.
 */
static void feat_repo_init(lua_State *L) {
    luaT_newmetatable(L,
                      nerv_tnet_feat_repo_tname,
                      NULL,
                      feat_repo_new,
                      feat_repo_destroy,
                      NULL);
    luaL_register(L, NULL, feat_repo_methods);
    lua_pop(L, 1);
}

/*
 * Registers the nerv.TNetLabelRepo type: creates its metatable with
 * label_repo_new / label_repo_destroy, adds the entries of
 * label_repo_methods, then pops one value off the Lua stack.
 */
static void label_repo_init(lua_State *L) {
    luaT_newmetatable(L,
                      nerv_tnet_label_repo_tname,
                      NULL,
                      label_repo_new,
                      label_repo_destroy,
                      NULL);
    luaL_register(L, NULL, label_repo_methods);
    lua_pop(L, 1);
}

/*
 * Entry point of this module: registers the feature repository type and
 * then the label repository type with the given Lua state.
 */
void tnet_io_init(lua_State *L)
{
    feat_repo_init(L);   /* nerv.TNetFeatureRepo */
    label_repo_init(L);  /* nerv.TNetLabelRepo */
}