summaryrefslogtreecommitdiff
path: root/kaldi_seq/src/init.c
blob: c2002cf11a677b9f32a7f43a4bebc09b3e65b05f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include "nerv/lib/common.h"
#include "kaldi_mpe.h"
#include "kaldi_mmi.h"
#include <stdio.h>

/*
 * Lua type names passed to the luaT helpers in this file
 * (luaT_newmetatable / luaT_pushudata / luaT_checkudata).
 * These have external linkage (not static); other translation units may
 * refer to them. NOTE(review): confirm before changing linkage or type.
 * Presumably the two matrix names must match the names under which the
 * nerv core registers its matrix classes — confirm.
 */
const char *nerv_kaldi_mpe_tname = "nerv.KaldiMPE";
const char *nerv_kaldi_mmi_tname = "nerv.KaldiMMI";
/* CUDA float matrix type; required as argument 2 of the check() methods. */
const char *nerv_matrix_cuda_float_tname = "nerv.CuMatrixFloat";
/* Host float matrix type; taken and returned by the calc_diff() methods. */
const char *nerv_matrix_host_float_tname = "nerv.MMatrixFloat";

/*
 * Lua constructor for nerv.KaldiMPE (registered via luaT_newmetatable in
 * mpe_init).
 *
 * Stack: four strings at indices 1..4. They are type-checked in that order
 * and forwarded unchanged, in order, to new_KaldiMPE. Judging by the
 * parameter roles they presumably are an option string, model, lattice and
 * alignment specifier — NOTE(review): confirm against kaldi_mpe.h.
 *
 * Returns 1: the new object pushed as nerv.KaldiMPE userdata.
 */
static int mpe_new(lua_State *L) {
    const char *spec[4];
    int idx;

    for (idx = 0; idx < 4; idx++)
        spec[idx] = luaL_checkstring(L, idx + 1);

    luaT_pushudata(L, new_KaldiMPE(spec[0], spec[1], spec[2], spec[3]),
                   nerv_kaldi_mpe_tname);
    return 1;
}

/*
 * Destructor handed to luaT_newmetatable in mpe_init.
 * Checks that index 1 is nerv.KaldiMPE userdata and releases the wrapped
 * object with destroy_KaldiMPE. Pushes nothing.
 */
static int mpe_destroy(lua_State *L) {
    destroy_KaldiMPE(luaT_checkudata(L, 1, nerv_kaldi_mpe_tname));
    return 0;
}

/*
 * mpe:check(cumat, utt) -> boolean
 *
 * Arguments:
 *   1: nerv.KaldiMPE userdata (self)
 *   2: nerv.CuMatrixFloat userdata (note: the CUDA matrix type, unlike
 *      calc_diff which takes the host type)
 *   3: string utterance identifier
 * Pushes check_mpe's result converted to a Lua boolean.
 */
static int mpe_check(lua_State *L) {
    KaldiMPE *self = luaT_checkudata(L, 1, nerv_kaldi_mpe_tname);
    const Matrix *dev_mat = luaT_checkudata(L, 2, nerv_matrix_cuda_float_tname);
    const char *utt_id = luaL_checkstring(L, 3);

    lua_pushboolean(L, check_mpe(self, dev_mat, utt_id));
    return 1;
}

/*
 * mpe:calc_diff(mat, utt) -> nerv.MMatrixFloat
 *
 * Arguments:
 *   1: nerv.KaldiMPE userdata (self)
 *   2: nerv.MMatrixFloat userdata (host matrix)
 *   3: string utterance identifier
 * Pushes the Matrix returned by calc_diff_mpe as nerv.MMatrixFloat userdata.
 * NOTE(review): ownership of the returned matrix presumably passes to Lua via
 * luaT_pushudata — confirm calc_diff_mpe allocates a fresh matrix.
 */
static int mpe_calc_diff(lua_State *L) {
    KaldiMPE *self = luaT_checkudata(L, 1, nerv_kaldi_mpe_tname);
    Matrix *host_mat = luaT_checkudata(L, 2, nerv_matrix_host_float_tname);
    const char *utt_id = luaL_checkstring(L, 3);

    luaT_pushudata(L, calc_diff_mpe(self, host_mat, utt_id),
                   nerv_matrix_host_float_tname);
    return 1;
}

/*
 * mpe:get_num_frames() -> number
 * Pushes get_num_frames_mpe(self) as a Lua number.
 */
static int mpe_get_num_frames(lua_State *L) {
    lua_pushnumber(L, get_num_frames_mpe(luaT_checkudata(L, 1, nerv_kaldi_mpe_tname)));
    return 1;
}

/*
 * mpe:get_utt_frame_acc() -> number
 * Pushes get_utt_frame_acc_mpe(self) as a Lua number.
 */
static int mpe_get_utt_frame_acc(lua_State *L) {
    lua_pushnumber(L, get_utt_frame_acc_mpe(luaT_checkudata(L, 1, nerv_kaldi_mpe_tname)));
    return 1;
}

static const luaL_Reg mpe_methods[] = {
    {"check", mpe_check},
    {"calc_diff", mpe_calc_diff},
    {"get_num_frames", mpe_get_num_frames},
    {"get_utt_frame_acc", mpe_get_utt_frame_acc},
    {NULL, NULL}
};

/*
 * Registers the nerv.KaldiMPE class: constructor mpe_new, destructor
 * mpe_destroy, and the methods listed in mpe_methods. Leaves the Lua stack
 * balanced.
 */
static void mpe_init(lua_State *L) {
    /* The two NULLs presumably mean "no parent class" and "no factory" —
     * NOTE(review): confirm against the luaT_newmetatable signature. */
    luaT_newmetatable(L, nerv_kaldi_mpe_tname, NULL, mpe_new, mpe_destroy, NULL);

    /* With a NULL libname, luaL_register fills the table on top of the stack
     * (the one luaT_newmetatable left there); pop it afterwards. */
    luaL_register(L, NULL, mpe_methods);
    lua_pop(L, 1);
}

/*
 * Lua constructor for nerv.KaldiMMI (registered via luaT_newmetatable in
 * mmi_init).
 *
 * Stack: four strings at indices 1..4. They are type-checked in that order
 * and forwarded unchanged, in order, to new_KaldiMMI. Presumably these are
 * an option string, model, lattice and alignment specifier —
 * NOTE(review): confirm against kaldi_mmi.h.
 *
 * Returns 1: the new object pushed as nerv.KaldiMMI userdata.
 */
static int mmi_new(lua_State *L) {
    const char *spec[4];
    int idx;

    for (idx = 0; idx < 4; idx++)
        spec[idx] = luaL_checkstring(L, idx + 1);

    luaT_pushudata(L, new_KaldiMMI(spec[0], spec[1], spec[2], spec[3]),
                   nerv_kaldi_mmi_tname);
    return 1;
}

/*
 * Destructor handed to luaT_newmetatable in mmi_init.
 * Checks that index 1 is nerv.KaldiMMI userdata and releases the wrapped
 * object with destroy_KaldiMMI. Pushes nothing.
 */
static int mmi_destroy(lua_State *L) {
    destroy_KaldiMMI(luaT_checkudata(L, 1, nerv_kaldi_mmi_tname));
    return 0;
}

/*
 * mmi:check(cumat, utt) -> boolean
 *
 * Arguments:
 *   1: nerv.KaldiMMI userdata (self)
 *   2: nerv.CuMatrixFloat userdata (CUDA matrix type)
 *   3: string utterance identifier
 * Pushes check_mmi's result converted to a Lua boolean.
 */
static int mmi_check(lua_State *L) {
    KaldiMMI *self = luaT_checkudata(L, 1, nerv_kaldi_mmi_tname);
    const Matrix *dev_mat = luaT_checkudata(L, 2, nerv_matrix_cuda_float_tname);
    const char *utt_id = luaL_checkstring(L, 3);

    lua_pushboolean(L, check_mmi(self, dev_mat, utt_id));
    return 1;
}

/*
 * mmi:calc_diff(mat, utt) -> nerv.MMatrixFloat
 *
 * Arguments:
 *   1: nerv.KaldiMMI userdata (self)
 *   2: nerv.MMatrixFloat userdata (host matrix)
 *   3: string utterance identifier
 * Pushes the Matrix returned by calc_diff_mmi as nerv.MMatrixFloat userdata.
 * NOTE(review): ownership of the returned matrix presumably passes to Lua via
 * luaT_pushudata — confirm calc_diff_mmi allocates a fresh matrix.
 */
static int mmi_calc_diff(lua_State *L) {
    KaldiMMI *self = luaT_checkudata(L, 1, nerv_kaldi_mmi_tname);
    Matrix *host_mat = luaT_checkudata(L, 2, nerv_matrix_host_float_tname);
    const char *utt_id = luaL_checkstring(L, 3);

    luaT_pushudata(L, calc_diff_mmi(self, host_mat, utt_id),
                   nerv_matrix_host_float_tname);
    return 1;
}

/*
 * mmi:get_num_frames() -> number
 * Pushes get_num_frames_mmi(self) as a Lua number.
 */
static int mmi_get_num_frames(lua_State *L) {
    lua_pushnumber(L, get_num_frames_mmi(luaT_checkudata(L, 1, nerv_kaldi_mmi_tname)));
    return 1;
}

static const luaL_Reg mmi_methods[] = {
    {"check", mmi_check},
    {"calc_diff", mmi_calc_diff},
    {"get_num_frames", mmi_get_num_frames},
    {NULL, NULL}
};

/*
 * Registers the nerv.KaldiMMI class: constructor mmi_new, destructor
 * mmi_destroy, and the methods listed in mmi_methods. Leaves the Lua stack
 * balanced.
 */
static void mmi_init(lua_State *L) {
    /* The two NULLs presumably mean "no parent class" and "no factory" —
     * NOTE(review): confirm against the luaT_newmetatable signature. */
    luaT_newmetatable(L, nerv_kaldi_mmi_tname, NULL, mmi_new, mmi_destroy, NULL);

    /* With a NULL libname, luaL_register fills the table on top of the stack
     * (the one luaT_newmetatable left there); pop it afterwards. */
    luaL_register(L, NULL, mmi_methods);
    lua_pop(L, 1);
}

/*
 * Public entry point of the kaldi_seq module: registers every Lua class this
 * file provides in the given Lua state, MPE first and then MMI.
 */
void kaldi_seq_init(lua_State *L)
{
    mpe_init(L);    /* nerv.KaldiMPE */
    mmi_init(L);    /* nerv.KaldiMMI */
}