/*
 * Lua bindings for the Kaldi sequence-discriminative training helpers
 * (MPE and MMI), registered as the luaT classes "nerv.KaldiMPE" and
 * "nerv.KaldiMMI".
 *
 * NOTE: each preprocessor directive must be on its own line; with the
 * directives fused onto one line, everything after the first #include is
 * discarded by the preprocessor as "extra tokens".
 */
#include "nerv/common.h"
#include "kaldi_mpe.h"
#include "kaldi_mmi.h"
/* NOTE(review): the name of this system header was lost in the copy this
 * file was restored from; <stddef.h> (for NULL) is a reconstruction —
 * confirm against history. */
#include <stddef.h>

/* luaT type names.  Kept with external linkage, as in the original, in case
 * other translation units refer to them. */
const char *nerv_kaldi_mpe_tname = "nerv.KaldiMPE";
const char *nerv_kaldi_mmi_tname = "nerv.KaldiMMI";
const char *nerv_matrix_cuda_float_tname = "nerv.CuMatrixFloat";
const char *nerv_matrix_host_float_tname = "nerv.MMatrixFloat";

/* ------------------------------------------------------------------ */
/* nerv.KaldiMPE                                                       */
/* ------------------------------------------------------------------ */

/*
 * Constructor: nerv.KaldiMPE(arg, mdl, lat, ali)
 *
 * All four arguments must be strings (luaL_checkstring raises a Lua error
 * otherwise); they are passed unchanged to new_KaldiMPE.
 * NOTE(review): presumably an option string, a model path, and lattice /
 * alignment specifiers — confirm against kaldi_mpe.h.
 * Pushes the new object wrapped as a nerv.KaldiMPE userdata (luaT pushes
 * nil instead if new_KaldiMPE returns NULL).
 */
static int mpe_new(lua_State *L) {
    const char *arg = luaL_checkstring(L, 1);
    const char *mdl = luaL_checkstring(L, 2);
    const char *lat = luaL_checkstring(L, 3);
    const char *ali = luaL_checkstring(L, 4);
    KaldiMPE *mpe = new_KaldiMPE(arg, mdl, lat, ali);
    luaT_pushudata(L, mpe, nerv_kaldi_mpe_tname);
    return 1;
}

/* Destructor registered with luaT for nerv.KaldiMPE; releases the wrapped
 * object via destroy_KaldiMPE.  Returns nothing to Lua. */
static int mpe_destroy(lua_State *L) {
    KaldiMPE *mpe = luaT_checkudata(L, 1, nerv_kaldi_mpe_tname);
    destroy_KaldiMPE(mpe);
    return 0;
}

/*
 * mpe:check(cumat, utt)
 *
 * cumat must be a nerv.CuMatrixFloat and utt a string.  Pushes the result
 * of check_mpe as a Lua boolean.
 */
static int mpe_check(lua_State *L) {
    KaldiMPE *mpe = luaT_checkudata(L, 1, nerv_kaldi_mpe_tname);
    const Matrix *cumat = luaT_checkudata(L, 2, nerv_matrix_cuda_float_tname);
    const char *utt = luaL_checkstring(L, 3);
    lua_pushboolean(L, check_mpe(mpe, cumat, utt));
    return 1;
}

/*
 * mpe:calc_diff(mat, utt)
 *
 * mat must be a host nerv.MMatrixFloat (not a CUDA matrix, unlike check)
 * and utt a string.  Pushes the Matrix returned by calc_diff_mpe wrapped as
 * a nerv.MMatrixFloat.
 * NOTE(review): assumes calc_diff_mpe returns a freshly allocated matrix
 * whose lifetime is then owned by the MMatrixFloat userdata — confirm.
 */
static int mpe_calc_diff(lua_State *L) {
    KaldiMPE *mpe = luaT_checkudata(L, 1, nerv_kaldi_mpe_tname);
    Matrix *mat = luaT_checkudata(L, 2, nerv_matrix_host_float_tname);
    const char *utt = luaL_checkstring(L, 3);
    Matrix *diff = calc_diff_mpe(mpe, mat, utt);
    luaT_pushudata(L, diff, nerv_matrix_host_float_tname);
    return 1;
}

/* mpe:get_num_frames() -> number returned by get_num_frames_mpe. */
static int mpe_get_num_frames(lua_State *L) {
    KaldiMPE *mpe = luaT_checkudata(L, 1, nerv_kaldi_mpe_tname);
    lua_pushnumber(L, get_num_frames_mpe(mpe));
    return 1;
}

/* mpe:get_utt_frame_acc() -> number returned by get_utt_frame_acc_mpe. */
static int mpe_get_utt_frame_acc(lua_State *L) {
    KaldiMPE *mpe = luaT_checkudata(L, 1, nerv_kaldi_mpe_tname);
    lua_pushnumber(L, get_utt_frame_acc_mpe(mpe));
    return 1;
}

/* Instance methods exposed on nerv.KaldiMPE. */
static const luaL_Reg mpe_methods[] = {
    {"check", mpe_check},
    {"calc_diff", mpe_calc_diff},
    {"get_num_frames", mpe_get_num_frames},
    {"get_utt_frame_acc", mpe_get_utt_frame_acc},
    {NULL, NULL}
};

/* Registers the nerv.KaldiMPE class.  luaT_newmetatable leaves the class
 * table on the stack; the methods are installed into it and it is then
 * popped, so the stack is balanced on return. */
static void mpe_init(lua_State *L) {
    luaT_newmetatable(L, nerv_kaldi_mpe_tname, NULL,
                      mpe_new, mpe_destroy, NULL);
    luaL_register(L, NULL, mpe_methods);
    lua_pop(L, 1);
}

/* ------------------------------------------------------------------ */
/* nerv.KaldiMMI                                                       */
/* ------------------------------------------------------------------ */

/*
 * Constructor: nerv.KaldiMMI(arg, mdl, lat, ali)
 *
 * Same argument contract as nerv.KaldiMPE: four strings (checked with
 * luaL_checkstring) forwarded to new_KaldiMMI.  Pushes the new object
 * wrapped as a nerv.KaldiMMI userdata (or nil if new_KaldiMMI returns NULL).
 */
static int mmi_new(lua_State *L) {
    const char *arg = luaL_checkstring(L, 1);
    const char *mdl = luaL_checkstring(L, 2);
    const char *lat = luaL_checkstring(L, 3);
    const char *ali = luaL_checkstring(L, 4);
    KaldiMMI *mmi = new_KaldiMMI(arg, mdl, lat, ali);
    luaT_pushudata(L, mmi, nerv_kaldi_mmi_tname);
    return 1;
}

/* Destructor registered with luaT for nerv.KaldiMMI; releases the wrapped
 * object via destroy_KaldiMMI.  Returns nothing to Lua. */
static int mmi_destroy(lua_State *L) {
    KaldiMMI *mmi = luaT_checkudata(L, 1, nerv_kaldi_mmi_tname);
    destroy_KaldiMMI(mmi);
    return 0;
}

/*
 * mmi:check(cumat, utt)
 *
 * cumat must be a nerv.CuMatrixFloat and utt a string.  Pushes the result
 * of check_mmi as a Lua boolean.
 */
static int mmi_check(lua_State *L) {
    KaldiMMI *mmi = luaT_checkudata(L, 1, nerv_kaldi_mmi_tname);
    const Matrix *cumat = luaT_checkudata(L, 2, nerv_matrix_cuda_float_tname);
    const char *utt = luaL_checkstring(L, 3);
    lua_pushboolean(L, check_mmi(mmi, cumat, utt));
    return 1;
}

/*
 * mmi:calc_diff(mat, utt)
 *
 * mat must be a host nerv.MMatrixFloat and utt a string.  Pushes the Matrix
 * returned by calc_diff_mmi wrapped as a nerv.MMatrixFloat.
 * NOTE(review): ownership of the returned matrix is assumed to pass to the
 * userdata — confirm.
 */
static int mmi_calc_diff(lua_State *L) {
    KaldiMMI *mmi = luaT_checkudata(L, 1, nerv_kaldi_mmi_tname);
    Matrix *mat = luaT_checkudata(L, 2, nerv_matrix_host_float_tname);
    const char *utt = luaL_checkstring(L, 3);
    Matrix *diff = calc_diff_mmi(mmi, mat, utt);
    luaT_pushudata(L, diff, nerv_matrix_host_float_tname);
    return 1;
}

/* mmi:get_num_frames() -> number returned by get_num_frames_mmi. */
static int mmi_get_num_frames(lua_State *L) {
    KaldiMMI *mmi = luaT_checkudata(L, 1, nerv_kaldi_mmi_tname);
    lua_pushnumber(L, get_num_frames_mmi(mmi));
    return 1;
}

/* Instance methods exposed on nerv.KaldiMMI (no get_utt_frame_acc here,
 * unlike MPE). */
static const luaL_Reg mmi_methods[] = {
    {"check", mmi_check},
    {"calc_diff", mmi_calc_diff},
    {"get_num_frames", mmi_get_num_frames},
    {NULL, NULL}
};

/* Registers the nerv.KaldiMMI class; leaves the Lua stack balanced. */
static void mmi_init(lua_State *L) {
    luaT_newmetatable(L, nerv_kaldi_mmi_tname, NULL,
                      mmi_new, mmi_destroy, NULL);
    luaL_register(L, NULL, mmi_methods);
    lua_pop(L, 1);
}

/* Module entry point: registers both nerv.KaldiMPE and nerv.KaldiMMI. */
void kaldi_seq_init(lua_State *L) {
    mpe_init(L);
    mmi_init(L);
}