/*
 * Lua bindings for the KaldiMPE object.
 *
 * Registers the Lua type "nerv.KaldiMPE" with a constructor, a destructor
 * and four methods: check, calc_diff, get_num_frames, get_utt_frame_acc.
 * The single public entry point is kaldi_seq_init().
 */
#include "nerv/common.h"
#include "kaldi_mpe.h"
/*
 * NOTE(review): the original file had a third `#include` here whose header
 * name was lost (the directive had no operand). Nothing in this file visibly
 * depends on it; restore the header if it is known.
 */

/* Type names used to tag and check userdata on the Lua stack. */
const char *nerv_kaldi_mpe_tname = "nerv.KaldiMPE";
const char *nerv_matrix_cuda_float_tname = "nerv.CuMatrixFloat";
const char *nerv_matrix_host_float_tname = "nerv.MMatrixFloat";

/*
 * Constructor: reads four string arguments (stack slots 1..4), forwards them
 * to new_KaldiMPE() and pushes the resulting object as "nerv.KaldiMPE"
 * userdata.
 *
 * The four strings are presumably the Kaldi option string, model, lattice
 * and alignment specifiers -- confirm against kaldi_mpe.h.
 *
 * Returns 1 (one value left on the Lua stack).
 */
static int mpe_new(lua_State *L)
{
    const char *arg = luaL_checkstring(L, 1);
    const char *mdl = luaL_checkstring(L, 2);
    const char *lat = luaL_checkstring(L, 3);
    const char *ali = luaL_checkstring(L, 4);
    KaldiMPE *mpe = new_KaldiMPE(arg, mdl, lat, ali);

    luaT_pushudata(L, mpe, nerv_kaldi_mpe_tname);
    return 1;
}

/*
 * Destructor: fetches the "nerv.KaldiMPE" userdata in slot 1 and hands it to
 * destroy_KaldiMPE(). Registered together with mpe_new in mpe_init().
 *
 * Returns 0 (nothing pushed).
 */
static int mpe_destroy(lua_State *L)
{
    KaldiMPE *mpe = luaT_checkudata(L, 1, nerv_kaldi_mpe_tname);

    destroy_KaldiMPE(mpe);
    return 0;
}

/*
 * Lua method `check(self, cumat, utt)`.
 *
 * Slot 2 must be a "nerv.CuMatrixFloat" userdata and slot 3 a string; both
 * are passed to check_mpe() and its result is pushed as a Lua boolean.
 *
 * Returns 1.
 */
static int mpe_check(lua_State *L)
{
    KaldiMPE *mpe = luaT_checkudata(L, 1, nerv_kaldi_mpe_tname);
    const Matrix *cumat = luaT_checkudata(L, 2, nerv_matrix_cuda_float_tname);
    const char *utt = luaL_checkstring(L, 3);

    lua_pushboolean(L, check_mpe(mpe, cumat, utt));
    return 1;
}

/*
 * Lua method `calc_diff(self, mat, utt)`.
 *
 * Slot 2 must be a "nerv.MMatrixFloat" userdata (note: a host matrix, unlike
 * mpe_check which takes a CUDA matrix) and slot 3 a string. The matrix
 * returned by calc_diff_mpe() is pushed as "nerv.MMatrixFloat" userdata.
 *
 * Returns 1.
 */
static int mpe_calc_diff(lua_State *L)
{
    KaldiMPE *mpe = luaT_checkudata(L, 1, nerv_kaldi_mpe_tname);
    Matrix *mat = luaT_checkudata(L, 2, nerv_matrix_host_float_tname);
    const char *utt = luaL_checkstring(L, 3);
    Matrix *diff = calc_diff_mpe(mpe, mat, utt);

    luaT_pushudata(L, diff, nerv_matrix_host_float_tname);
    return 1;
}

/*
 * Lua method `get_num_frames(self)`: pushes the value of
 * get_num_frames_mpe() as a Lua number.
 *
 * Returns 1.
 */
static int mpe_get_num_frames(lua_State *L)
{
    KaldiMPE *mpe = luaT_checkudata(L, 1, nerv_kaldi_mpe_tname);

    lua_pushnumber(L, get_num_frames_mpe(mpe));
    return 1;
}

/*
 * Lua method `get_utt_frame_acc(self)`: pushes the value of
 * get_utt_frame_acc_mpe() as a Lua number.
 *
 * Returns 1.
 */
static int mpe_get_utt_frame_acc(lua_State *L)
{
    KaldiMPE *mpe = luaT_checkudata(L, 1, nerv_kaldi_mpe_tname);

    lua_pushnumber(L, get_utt_frame_acc_mpe(mpe));
    return 1;
}

/* Method table for "nerv.KaldiMPE"; terminated by the {NULL, NULL} sentinel. */
static const luaL_Reg mpe_methods[] = {
    {"check", mpe_check},
    {"calc_diff", mpe_calc_diff},
    {"get_num_frames", mpe_get_num_frames},
    {"get_utt_frame_acc", mpe_get_utt_frame_acc},
    {NULL, NULL}
};

/*
 * Creates the "nerv.KaldiMPE" metatable (no parent type, mpe_new as
 * constructor, mpe_destroy as destructor), registers mpe_methods on it and
 * pops one value so the stack is left as it was found.
 */
static void mpe_init(lua_State *L)
{
    luaT_newmetatable(L, nerv_kaldi_mpe_tname, NULL,
                      mpe_new, mpe_destroy, NULL);
    /* A NULL libname makes luaL_register fill the table on top of the stack. */
    luaL_register(L, NULL, mpe_methods);
    lua_pop(L, 1);
}

/* Public entry point: registers every type provided by this module. */
void kaldi_seq_init(lua_State *L)
{
    mpe_init(L);
}