summaryrefslogtreecommitdiff
path: root/tnet_io/init.c
blob: 3fa7cb8a440c56ccf2ed58afa6d85589b0b069ea (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include "../../common.h"
#include "cwrapper.h"
#include <stdio.h>

/* Matrix type name owned by another translation unit; used here as the
 * metatable name for the matrices returned by cur_utter / get_utter. */
extern const char *nerv_matrix_host_float_tname;
/* Lua-visible type names of the two classes registered by tnet_io_init. */
const char *nerv_tnet_feat_repo_tname = "nerv.TNetFeatureRepo";
const char *nerv_tnet_label_repo_tname = "nerv.TNetLabelRepo";

/* Lua constructor for the feature-repository class.
 * Lua arguments: (string, string, integer). They are forwarded unchanged,
 * in order, to tnet_feature_repo_new. The resulting handle is wrapped as
 * userdata tagged with nerv_tnet_feat_repo_tname.
 * Returns 1 value to Lua. */
static int feat_repo_new(lua_State *L) {
    const char *scp_path = luaL_checkstring(L, 1);
    const char *conf_str = luaL_checkstring(L, 2);
    int ext = luaL_checkinteger(L, 3);

    luaT_pushudata(L,
                   tnet_feature_repo_new(scp_path, conf_str, ext),
                   nerv_tnet_feat_repo_tname);
    return 1;
}

/* Destructor registered through luaT_newmetatable in feat_repo_init.
 * Hands the TNetFeatureRepo held by the userdata at stack index 1 to
 * tnet_feature_repo_destroy. Returns no Lua values. */
static int feat_repo_destroy(lua_State *L) {
    tnet_feature_repo_destroy(
        luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname));
    return 0;
}

/* Lua method cur_tag().
 * Pushes the string returned by tnet_feature_repo_current_tag for the
 * repository at stack index 1. Returns 1 value to Lua. */
static int feat_repo_current_tag(lua_State *L) {
    TNetFeatureRepo *self = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
    const char *tag = tnet_feature_repo_current_tag(self);

    lua_pushstring(L, tag);
    return 1;
}

/* Lua method cur_utter().
 * Reads the current utterance via tnet_feature_repo_read_utterance. The
 * lua_State is passed through to that call, presumably for error reporting
 * (confirm against cwrapper). The returned Matrix is pushed as userdata of
 * type nerv_matrix_host_float_tname. Returns 1 value to Lua. */
static int feat_repo_current_utterance(lua_State *L) {
    TNetFeatureRepo *self = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);

    luaT_pushudata(L,
                   tnet_feature_repo_read_utterance(self, L),
                   nerv_matrix_host_float_tname);
    return 1;
}

/* Lua method next().
 * Calls tnet_feature_repo_next on the repository at stack index 1.
 * Returns no Lua values. */
static int feat_repo_next(lua_State *L) {
    tnet_feature_repo_next(luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname));
    return 0;
}

/* Instance methods installed into the nerv.TNetFeatureRepo metatable by
 * feat_repo_init. Each entry maps a Lua method name to its C handler. */
static const luaL_Reg feat_repo_methods[] = {
    { "cur_utter", feat_repo_current_utterance },
    { "cur_tag",   feat_repo_current_tag       },
    { "next",      feat_repo_next              },
    { NULL,        NULL                        }  /* sentinel ends the list */
};

/* Lua constructor for the label-repository class.
 * Lua arguments: five strings. They are forwarded unchanged, in order, to
 * tnet_label_repo_new. Judging by the original local names they are an MLF
 * file, format, argument, directory and extension (NOTE(review): confirm
 * against cwrapper). The handle is wrapped as userdata tagged with
 * nerv_tnet_label_repo_tname. Returns 1 value to Lua. */
static int label_repo_new(lua_State *L) {
    const char *mlf_path = luaL_checkstring(L, 1);
    const char *fmt_str  = luaL_checkstring(L, 2);
    const char *arg_str  = luaL_checkstring(L, 3);
    const char *dir_str  = luaL_checkstring(L, 4);
    const char *ext_str  = luaL_checkstring(L, 5);

    luaT_pushudata(L,
                   tnet_label_repo_new(mlf_path, fmt_str, arg_str,
                                       dir_str, ext_str),
                   nerv_tnet_label_repo_tname);
    return 1;
}

/* Lua method get_utter(feat_repo, frames).
 *
 * Arguments:
 *   1: the label repository (self)
 *   2: a nerv.TNetFeatureRepo; its current sample rate and current tag are
 *      forwarded to the label reader
 *   3: frame count, an integer
 *
 * Pushes the Matrix from tnet_label_repo_read_utterance as userdata of type
 * nerv_matrix_host_float_tname and returns 1 value to Lua.
 *
 * Raises a Lua argument error when frames is negative. Converting a
 * negative lua_Integer straight to size_t would wrap around to an enormous
 * frame count and hand it to the reader. */
static int label_repo_read_utterance(lua_State *L) {
    TNetLabelRepo *repo = luaT_checkudata(L, 1, nerv_tnet_label_repo_tname);
    TNetFeatureRepo *feat_repo = luaT_checkudata(L, 2, nerv_tnet_feat_repo_tname);
    lua_Integer nframes = luaL_checkinteger(L, 3);
    Matrix *utter;

    luaL_argcheck(L, nframes >= 0, 3, "frame count must be non-negative");
    utter = tnet_label_repo_read_utterance(repo,
                (size_t)nframes,
                tnet_feature_repo_current_samplerate(feat_repo),
                tnet_feature_repo_current_tag(feat_repo), L);
    luaT_pushudata(L, utter, nerv_matrix_host_float_tname);
    return 1;
}

/* Destructor registered through luaT_newmetatable in label_repo_init.
 * Hands the TNetLabelRepo held by the userdata at stack index 1 to
 * tnet_label_repo_destroy. Returns no Lua values. */
static int label_repo_destroy(lua_State *L) {
    tnet_label_repo_destroy(
        luaT_checkudata(L, 1, nerv_tnet_label_repo_tname));
    return 0;
}

/* Instance methods installed into the nerv.TNetLabelRepo metatable by
 * label_repo_init. */
static const luaL_Reg label_repo_methods[] = {
    { "get_utter", label_repo_read_utterance },
    { NULL,        NULL                      }  /* sentinel ends the list */
};

/* Registers the nerv.TNetFeatureRepo class.
 * - luaT_newmetatable creates its metatable with feat_repo_new as the
 *   constructor and feat_repo_destroy as the destructor, leaving the
 *   metatable on the stack.
 * - luaL_register(L, NULL, ...) then fills that top-of-stack table with
 *   feat_repo_methods.
 * - The final pop restores the stack.
 * NOTE(review): the two NULL arguments presumably mean "no parent type" and
 * "no factory"; confirm against the luaT_newmetatable signature. */
static void feat_repo_init(lua_State *L) {
    const char *tname = nerv_tnet_feat_repo_tname;

    luaT_newmetatable(L, tname, NULL,
                      feat_repo_new,
                      feat_repo_destroy,
                      NULL);
    luaL_register(L, NULL, feat_repo_methods);
    lua_pop(L, 1);
}

/* Registers the nerv.TNetLabelRepo class.
 * - luaT_newmetatable creates its metatable with label_repo_new as the
 *   constructor and label_repo_destroy as the destructor, leaving the
 *   metatable on the stack.
 * - luaL_register(L, NULL, ...) then fills that top-of-stack table with
 *   label_repo_methods.
 * - The final pop restores the stack.
 * NOTE(review): the two NULL arguments presumably mean "no parent type" and
 * "no factory"; confirm against the luaT_newmetatable signature. */
static void label_repo_init(lua_State *L) {
    const char *tname = nerv_tnet_label_repo_tname;

    luaT_newmetatable(L, tname, NULL,
                      label_repo_new,
                      label_repo_destroy,
                      NULL);
    luaL_register(L, NULL, label_repo_methods);
    lua_pop(L, 1);
}

/* Module entry point (external linkage).
 * Registers the feature-repository class first, then the label-repository
 * class, on the given Lua state. */
void tnet_io_init(lua_State *L) {
    feat_repo_init(L);   /* nerv.TNetFeatureRepo */
    label_repo_init(L);  /* nerv.TNetLabelRepo */
}