summaryrefslogblamecommitdiff
path: root/tnet_io/init.c
blob: 3e3f90dd4cbe492e1ef9f146b4ced19fbb7d8fce (plain) (tree)
1
2
3
4
5
6
7
                              




                                                               
                                                               























                                                                             




                                                                     









                                                                             





                                                                             



                                               
                                 



















                                                                                  



                                                        


                                                                        
                                                                            
































                                                                            
#include "../../nerv/common.h"
#include "cwrapper.h"
#include <stdio.h>

/* luaT type names for the two repository userdata classes this module
 * registers with Lua (see feat_repo_init / label_repo_init). */
const char *nerv_tnet_feat_repo_tname = "nerv.TNetFeatureRepo";
const char *nerv_tnet_label_repo_tname = "nerv.TNetLabelRepo";
/* Type name attached to the matrices pushed by the utterance readers.
 * NOTE(review): presumably must match the name under which nerv registers
 * its host float matrix class — confirm against the nerv core. */
const char *nerv_matrix_host_float_tname = "nerv.MMatrixFloat";

/*
 * Lua constructor for nerv.TNetFeatureRepo (registered in feat_repo_init).
 *
 * Lua args: 1 = scp file path (string), 2 = config (string),
 *           3 = frm_ext (integer; passed through unchanged to
 *               tnet_feature_repo_new — meaning defined there).
 * Pushes the new repository userdata and returns 1.
 */
static int feat_repo_new(lua_State *L) {
    const char *scp_path = luaL_checkstring(L, 1);
    const char *conf_str = luaL_checkstring(L, 2);
    int ext_frames = luaL_checkinteger(L, 3);

    luaT_pushudata(L,
                   tnet_feature_repo_new(scp_path, conf_str, ext_frames),
                   nerv_tnet_feat_repo_tname);
    return 1;
}

/*
 * Destructor for nerv.TNetFeatureRepo, handed to luaT_newmetatable in
 * feat_repo_init. Releases the repository held by argument 1.
 * Pushes nothing; returns 0.
 */
static int feat_repo_destroy(lua_State *L) {
    tnet_feature_repo_destroy(
        luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname));
    return 0;
}

/*
 * Lua method `cur_tag()`: pushes the string returned by
 * tnet_feature_repo_current_tag for the repository in argument 1.
 * Returns 1.
 */
static int feat_repo_current_tag(lua_State *L) {
    TNetFeatureRepo *self = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
    const char *tag = tnet_feature_repo_current_tag(self);

    lua_pushstring(L, tag);
    return 1;
}

/*
 * Lua method `cur_utter(debug)`: reads the repository's current utterance
 * via tnet_feature_repo_read_utterance and pushes it as a
 * nerv.MMatrixFloat userdata.
 *
 * Lua args: 1 = feature repo, 2 = debug flag (must be a boolean;
 *           anything else raises "debug flag should be a boolean").
 * Returns 1.
 */
static int feat_repo_current_utterance(lua_State *L) {
    TNetFeatureRepo *self = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);

    if (!lua_isboolean(L, 2)) {
        nerv_error(L, "debug flag should be a boolean");
    }

    Matrix *mat = tnet_feature_repo_read_utterance(self, L,
                                                   lua_toboolean(L, 2));
    luaT_pushudata(L, mat, nerv_matrix_host_float_tname);
    return 1;
}

/*
 * Lua method `next()`: calls tnet_feature_repo_next on the repository in
 * argument 1. Pushes nothing; returns 0.
 */
static int feat_repo_next(lua_State *L) {
    tnet_feature_repo_next(luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname));
    return 0;
}

/*
 * Lua method `is_end()`: pushes, as a Lua boolean, the result of
 * tnet_feature_repo_is_end for the repository in argument 1. Returns 1.
 */
static int feat_repo_is_end(lua_State *L) {
    TNetFeatureRepo *self = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
    int at_end = tnet_feature_repo_is_end(self);

    lua_pushboolean(L, at_end);
    return 1;
}

static const luaL_Reg feat_repo_methods[] = {
    {"cur_utter", feat_repo_current_utterance},
    {"cur_tag", feat_repo_current_tag},
    {"next", feat_repo_next},
    {"is_end", feat_repo_is_end},
    {NULL, NULL}
};

/*
 * Lua constructor for nerv.TNetLabelRepo (registered in label_repo_init).
 *
 * Lua args (all strings, forwarded in order to tnet_label_repo_new):
 *   1 = MLF file path, 2 = fmt, 3 = arg, 4 = dir, 5 = ext.
 * NOTE(review): the exact meaning of fmt/arg/dir/ext is defined by
 * tnet_label_repo_new — confirm there.
 * Pushes the new repository userdata and returns 1.
 */
static int label_repo_new(lua_State *L) {
    const char *mlf_path = luaL_checkstring(L, 1);
    const char *fmt_name = luaL_checkstring(L, 2);
    const char *fmt_arg  = luaL_checkstring(L, 3);
    const char *dir_name = luaL_checkstring(L, 4);
    const char *ext_name = luaL_checkstring(L, 5);

    luaT_pushudata(L,
                   tnet_label_repo_new(mlf_path, fmt_name, fmt_arg,
                                       dir_name, ext_name),
                   nerv_tnet_label_repo_tname);
    return 1;
}

/*
 * Lua method `get_utter(feat_repo, frames, debug)` on nerv.TNetLabelRepo.
 *
 * Reads labels via tnet_label_repo_read_utterance and pushes the result as
 * a nerv.MMatrixFloat userdata. The given feature repository supplies
 * its current sample rate and current tag.
 *
 * Lua args: 1 = label repo, 2 = feature repo (nerv.TNetFeatureRepo),
 *           3 = frame count (integer, must be >= 0),
 *           4 = debug flag (must be a boolean).
 * Returns 1.
 */
static int label_repo_read_utterance(lua_State *L) {
    TNetLabelRepo *repo = luaT_checkudata(L, 1, nerv_tnet_label_repo_tname);
    TNetFeatureRepo *feat_repo = luaT_checkudata(L, 2, nerv_tnet_feat_repo_tname);

    /* luaL_checkinteger yields a signed value; converting a negative one
     * straight to size_t would wrap it into an enormous frame count, so
     * reject it before the conversion. */
    long long n_frames = luaL_checkinteger(L, 3);
    if (n_frames < 0) {
        nerv_error(L, "frame count should be non-negative");
    }
    size_t frames = (size_t)n_frames;

    if (!lua_isboolean(L, 4)) {
        nerv_error(L, "debug flag should be a boolean");
    }
    int debug = lua_toboolean(L, 4);

    Matrix *utter = tnet_label_repo_read_utterance(repo,
                        frames,
                        tnet_feature_repo_current_samplerate(feat_repo),
                        tnet_feature_repo_current_tag(feat_repo), L, debug);
    luaT_pushudata(L, utter, nerv_matrix_host_float_tname);
    return 1;
}

/*
 * Destructor for nerv.TNetLabelRepo, handed to luaT_newmetatable in
 * label_repo_init. Releases the repository held by argument 1.
 * Pushes nothing; returns 0.
 */
static int label_repo_destroy(lua_State *L) {
    tnet_label_repo_destroy(
        luaT_checkudata(L, 1, nerv_tnet_label_repo_tname));
    return 0;
}

static const luaL_Reg label_repo_methods[] = {
    {"get_utter", label_repo_read_utterance},
    {NULL, NULL}
};

static void feat_repo_init(lua_State *L) {
    luaT_newmetatable(L, nerv_tnet_feat_repo_tname, NULL,
                        feat_repo_new, feat_repo_destroy, NULL);
    luaL_register(L, NULL, feat_repo_methods);
    lua_pop(L, 1);
}

static void label_repo_init(lua_State *L) {
    luaT_newmetatable(L, nerv_tnet_label_repo_tname, NULL,
                        label_repo_new, label_repo_destroy, NULL);
    luaL_register(L, NULL, label_repo_methods);
    lua_pop(L, 1);
}

void tnet_io_init(lua_State *L) {
    feat_repo_init(L);
    label_repo_init(L);
}