summaryrefslogblamecommitdiff
path: root/htk_io/src/init.c
blob: 04046e9dab3d063863e09f710bc34febb18b6e11 (plain) (tree)
1
2
3
4
5
6
7
8
9
                        




                                                               
                                                               

                                        













                                                              


                                                       













                                                                                              














                                                                             




                                                                     









                                                                             





                                                                             



                                               
                                 

                                       



                                         













                                                      

                                                   
     



                                                        






                                                                                



                                                                                  



                                                        


                                                                        
                                                                            











                                                                            
                          




















                                                                  
#include "nerv/common.h"
#include "cwrapper.h"
#include <stdio.h>

/*
 * luaT type names used in this file to register the two repo classes
 * (feat_repo_init / label_repo_init) and to type-check userdata arguments
 * (luaT_checkudata). They have external linkage, so other translation units
 * may refer to them. NOTE(review): confirm against cwrapper.h / nerv headers.
 */
const char *nerv_tnet_feat_repo_tname = "nerv.TNetFeatureRepo";
const char *nerv_tnet_label_repo_tname = "nerv.TNetLabelRepo";
/* Type name used when pushing utterance matrices back to Lua (see
 * feat_repo_current_utterance / label_repo_read_utterance). Presumably this
 * must match the name nerv registers its host float matrix class under. */
const char *nerv_matrix_host_float_tname = "nerv.MMatrixFloat";

/*
 * Constructor for nerv.TNetFeatureRepo (registered in feat_repo_init).
 *
 * Two call forms, selected by the number of Lua arguments:
 *   (id)                      -- exactly one argument: build the repo with
 *                                tnet_feature_repo_newWithId(id).
 *                                NOTE(review): presumably `id` is a value
 *                                previously obtained from :id().
 *   (scp_file, conf, frm_ext) -- otherwise: two strings and an integer passed
 *                                to tnet_feature_repo_new().
 *
 * Pushes the repo as a nerv_tnet_feat_repo_tname userdata and returns 1.
 */
static int feat_repo_new(lua_State *L)
{
    TNetFeatureRepo *repo = NULL;
    const int single_arg = (lua_gettop(L) == 1);

    if (!single_arg) {
        /* Arguments are checked left to right so errors name the first bad one. */
        const char *scp_file = luaL_checkstring(L, 1);
        const char *conf = luaL_checkstring(L, 2);
        int frm_ext = luaL_checkinteger(L, 3);
        repo = tnet_feature_repo_new(scp_file, conf, frm_ext);
    } else {
        long id = luaL_checkinteger(L, 1);
        repo = tnet_feature_repo_newWithId(id);
    }

    luaT_pushudata(L, repo, nerv_tnet_feat_repo_tname);
    return 1;
}
/*
 * Method id(): pushes the integer returned by tnet_feature_repo_id() for
 * this repo. Returns 1.
 */
static int feat_repo_id(lua_State *L)
{
    lua_pushinteger(L, tnet_feature_repo_id(
                           luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname)));
    return 1;
}

/*
 * __tostring metamethod: produces "nerv.TNetFeatureRepo <hex-id>", where the
 * hex value is whatever tnet_feature_repo_id() returns for this repo.
 * Returns 1 (the string).
 */
static int feat_repo_tostring(lua_State *L)
{
    char str[128];
    TNetFeatureRepo *repo = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);

    /* %lx requires an unsigned long argument. Cast explicitly so the call is
     * well-defined whatever integer type or sign the id has. */
    snprintf(str, sizeof str, "%s <%lx>", nerv_tnet_feat_repo_tname,
             (unsigned long)tnet_feature_repo_id(repo));
    lua_pushstring(L, str);
    return 1;
}

/*
 * Destructor registered with luaT in feat_repo_init: releases the native
 * repo through tnet_feature_repo_destroy(). Pushes nothing.
 */
static int feat_repo_destroy(lua_State *L)
{
    tnet_feature_repo_destroy(luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname));
    return 0;
}

/*
 * Method cur_tag(): pushes the string returned by
 * tnet_feature_repo_current_tag() for this repo. Returns 1.
 */
static int feat_repo_current_tag(lua_State *L)
{
    TNetFeatureRepo *self = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
    const char *tag = tnet_feature_repo_current_tag(self);

    lua_pushstring(L, tag);
    return 1;
}

/*
 * Method cur_utter(debug): reads the current utterance with
 * tnet_feature_repo_read_utterance() and pushes it as a
 * nerv_matrix_host_float_tname userdata.
 *
 * `debug` (argument 2) is required. A non-boolean value is reported through
 * nerv_error. Returns 1.
 */
static int feat_repo_current_utterance(lua_State *L)
{
    TNetFeatureRepo *self = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
    Matrix *utterance;
    int debug_flag;

    if (!lua_isboolean(L, 2))
        nerv_error(L, "debug flag should be a boolean");
    debug_flag = lua_toboolean(L, 2);

    utterance = tnet_feature_repo_read_utterance(self, L, debug_flag);
    luaT_pushudata(L, utterance, nerv_matrix_host_float_tname);
    return 1;
}

/*
 * Method next(): steps the repo forward by calling tnet_feature_repo_next().
 * Pushes nothing.
 */
static int feat_repo_next(lua_State *L)
{
    tnet_feature_repo_next(luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname));
    return 0;
}

/*
 * Method is_end(): pushes the result of tnet_feature_repo_is_end() as a Lua
 * boolean. Returns 1.
 */
static int feat_repo_is_end(lua_State *L)
{
    TNetFeatureRepo *self = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
    int at_end = tnet_feature_repo_is_end(self);

    lua_pushboolean(L, at_end);
    return 1;
}

/*
 * Methods registered by feat_repo_init() into the table that
 * luaT_newmetatable leaves on the stack (presumably the class metatable).
 * The list ends with a {NULL, NULL} sentinel, as luaL_register requires.
 */
static const luaL_Reg feat_repo_methods[] = {
    { "__tostring", feat_repo_tostring          },
    { "cur_tag",    feat_repo_current_tag       },
    { "cur_utter",  feat_repo_current_utterance },
    { "id",         feat_repo_id                },
    { "is_end",     feat_repo_is_end            },
    { "next",       feat_repo_next              },
    { NULL,         NULL                        }
};

/*
 * Constructor for nerv.TNetLabelRepo (registered in label_repo_init).
 *
 * Two call forms, selected by the number of Lua arguments:
 *   (id)                              -- exactly one argument: build the repo
 *                                        with tnet_label_repo_newWithId(id).
 *                                        NOTE(review): presumably `id` comes
 *                                        from :id() on another instance.
 *   (mlf_file, fmt, arg, dir, ext)    -- otherwise: five strings passed to
 *                                        tnet_label_repo_new().
 *
 * Pushes the repo as a nerv_tnet_label_repo_tname userdata and returns 1.
 */
static int label_repo_new(lua_State *L)
{
    TNetLabelRepo *repo = NULL;
    const int single_arg = (lua_gettop(L) == 1);

    if (!single_arg) {
        /* Checked in argument order so errors name the first bad one. */
        const char *mlf_file = luaL_checkstring(L, 1);
        const char *fmt = luaL_checkstring(L, 2);
        const char *arg = luaL_checkstring(L, 3);
        const char *dir = luaL_checkstring(L, 4);
        const char *ext = luaL_checkstring(L, 5);
        repo = tnet_label_repo_new(mlf_file, fmt, arg, dir, ext);
    } else {
        long id = luaL_checkinteger(L, 1);
        repo = tnet_label_repo_newWithId(id);
    }

    luaT_pushudata(L, repo, nerv_tnet_label_repo_tname);
    return 1;
}

/*
 * Method id(): pushes the integer returned by tnet_label_repo_id() for this
 * repo. Returns 1.
 */
static int label_repo_id(lua_State *L)
{
    lua_pushinteger(L, tnet_label_repo_id(
                           luaT_checkudata(L, 1, nerv_tnet_label_repo_tname)));
    return 1;
}


/*
 * Method get_utter(feat_repo, frames, debug) on nerv.TNetLabelRepo.
 *
 * Reads labels with tnet_label_repo_read_utterance(). It passes along the
 * current sample rate and current tag of `feat_repo`, then pushes the result
 * as a nerv_matrix_host_float_tname userdata.
 *
 * Lua arguments:
 *   1: nerv.TNetLabelRepo   self
 *   2: nerv.TNetFeatureRepo feat_repo
 *   3: integer frames  -- must be >= 0 (NOTE(review): presumably the number
 *                         of feature frames of the current utterance)
 *   4: boolean debug   -- required; a non-boolean is reported via nerv_error
 * Returns 1 (the matrix).
 */
static int label_repo_read_utterance(lua_State *L) {
    TNetLabelRepo *repo = luaT_checkudata(L, 1, nerv_tnet_label_repo_tname);
    TNetFeatureRepo *feat_repo = luaT_checkudata(L, 2, nerv_tnet_feat_repo_tname);
    lua_Integer frames_arg = luaL_checkinteger(L, 3);
    /* Reject negatives up front. Converting one straight to size_t would
     * silently wrap to an enormous frame count. */
    luaL_argcheck(L, frames_arg >= 0, 3, "frame count must be non-negative");
    size_t frames = (size_t)frames_arg;
    int debug;
    if (!lua_isboolean(L, 4))
        nerv_error(L, "debug flag should be a boolean");
    debug = lua_toboolean(L, 4);
    Matrix *utter = tnet_label_repo_read_utterance(repo,
                        frames,
                        tnet_feature_repo_current_samplerate(feat_repo),
                        tnet_feature_repo_current_tag(feat_repo), L, debug);
    luaT_pushudata(L, utter, nerv_matrix_host_float_tname);
    return 1;
}

/*
 * Destructor registered with luaT in label_repo_init: releases the native
 * repo through tnet_label_repo_destroy(). Pushes nothing.
 */
static int label_repo_destroy(lua_State *L)
{
    tnet_label_repo_destroy(luaT_checkudata(L, 1, nerv_tnet_label_repo_tname));
    return 0;
}

/*
 * Methods registered by label_repo_init() into the table that
 * luaT_newmetatable leaves on the stack (presumably the class metatable).
 * The list ends with a {NULL, NULL} sentinel, as luaL_register requires.
 */
static const luaL_Reg label_repo_methods[] = {
    { "id",        label_repo_id             },
    { "get_utter", label_repo_read_utterance },
    { NULL,        NULL                      }
};

/*
 * Register the nerv.TNetFeatureRepo class with luaT, with feat_repo_new and
 * feat_repo_destroy as its constructor and destructor. Then install
 * feat_repo_methods into the table luaT_newmetatable leaves on the stack, and
 * pop that table so the stack is unchanged overall.
 * NOTE(review): the argument roles below follow luaT's
 * (tname, parent, ctor, dtor, factory) order; confirm against luaT.h.
 */
static void feat_repo_init(lua_State *L)
{
    luaT_newmetatable(L, nerv_tnet_feat_repo_tname,
                      NULL,              /* parent type name */
                      feat_repo_new,     /* constructor */
                      feat_repo_destroy, /* destructor */
                      NULL);             /* factory */
    luaL_register(L, NULL, feat_repo_methods);
    lua_pop(L, 1);
}

/*
 * Register the nerv.TNetLabelRepo class with luaT, with label_repo_new and
 * label_repo_destroy as its constructor and destructor. Then install
 * label_repo_methods into the table luaT_newmetatable leaves on the stack,
 * and pop that table so the stack is unchanged overall.
 * NOTE(review): the argument roles below follow luaT's
 * (tname, parent, ctor, dtor, factory) order; confirm against luaT.h.
 */
static void label_repo_init(lua_State *L)
{
    luaT_newmetatable(L, nerv_tnet_label_repo_tname,
                      NULL,               /* parent type name */
                      label_repo_new,     /* constructor */
                      label_repo_destroy, /* destructor */
                      NULL);              /* factory */
    luaL_register(L, NULL, label_repo_methods);
    lua_pop(L, 1);
}

/*
 * Module entry point: registers nerv.TNetFeatureRepo and then
 * nerv.TNetLabelRepo with the given Lua state.
 */
void tnet_io_init(lua_State *L)
{
    static void (*const registrars[])(lua_State *) = {
        feat_repo_init,
        label_repo_init,
    };
    size_t k;

    for (k = 0; k < sizeof registrars / sizeof registrars[0]; ++k)
        registrars[k](L);
}