/*
 * Generic Lua bindings for the matrix type.
 *
 * This file is a template: every function name goes through the
 * nerv_matrix_() macro and the whole body is guarded by NERV_GENERIC_MATRIX.
 * NOTE(review): presumably the including file defines nerv_matrix_(),
 * Matrix, MATRIX_ELEM, MATRIX_CONTEXT, MATRIX_GET_CONTEXT, MATRIX_DATA_READ,
 * MATRIX_DATA_WRITE, MATRIX_ELEM_PTR (and optionally MATRIX_INIT) before
 * including this file, once per element type -- confirm against includers.
 *
 * Common pattern of the bindings below: the optional MATRIX_CONTEXT is read
 * by MATRIX_GET_CONTEXT from the stack slot right after the regular
 * arguments, the matching nerv_matrix_() routine is called with a Status
 * out-parameter, and the Status is then handed to NERV_LUA_CHECK_STATUS
 * (presumably raising a Lua error on failure -- confirm).
 */
#ifdef NERV_GENERIC_MATRIX
#include "../matrix.h"
#include "../../lib/matrix/generic/matrix.h"
#include "../../lib/common.h"

extern const char *nerv_matrix_(tname);
extern const char *MATRIX_BASE_TNAME;

/*
 * Lua constructor.
 * Args: 1, 2 = integers forwarded as the first two arguments of
 * nerv_matrix_(create) (presumably nrow, ncol -- confirm); 3 = optional context.
 * Returns: the new matrix as a userdata of type nerv_matrix_(tname).
 */
int nerv_matrix_(lua_new)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    Matrix *self = nerv_matrix_(create)(luaL_checkinteger(L, 1),
                                        luaL_checkinteger(L, 2),
                                        context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, self, nerv_matrix_(tname));
    return 1;
}

/*
 * Lua destructor (registered through luaT_newmetatable).
 * Args: 1 = matrix; 2 = optional context.
 * Returns nothing: the original returned 1 without pushing a value, which
 * made Lua hand back whatever argument happened to be on top of the stack.
 */
int nerv_matrix_(lua_destroy)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 2);
    Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
    nerv_matrix_(destroy)(self, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* Element accessors; defined elsewhere for each concrete matrix type. */
int nerv_matrix_(lua_get_elem)(lua_State *L);
int nerv_matrix_(lua_set_elem)(lua_State *L);

/*
 * Assignment hook registered as "__newindex__".
 * For a numeric key idx on a one-dimensional matrix (dim == 1), writes the
 * number at arg 3 into element idx (range [0, ncol)) and returns true.
 * Assigning through a numeric key on any other matrix raises an error.
 * For a non-numeric key nothing is written and false is returned
 * (NOTE(review): presumably telling the caller to fall back to a normal
 * field assignment -- confirm).
 */
static int nerv_matrix_(lua_newindex)(lua_State *L) {
    Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
    if (!lua_isnumber(L, 2)) {
        lua_pushboolean(L, 0);
        return 1;
    }
    int idx = luaL_checkinteger(L, 2);
    if (self->dim == 1) {
        if (idx < 0 || idx >= self->ncol)
            nerv_error(L, "index must be within range [0, %d)", self->ncol);
        MATRIX_DATA_WRITE(L, MATRIX_ELEM_PTR(self), idx,
                          luaL_checknumber(L, 3));
    } else {
        nerv_error(L, "cannot assign to row vector");
    }
    lua_pushboolean(L, 1);
    return 1;
}

/*
 * Indexing hook registered as "__index__".
 * For a numeric key idx:
 *   - dim == 1: pushes element idx (range [0, ncol));
 *   - otherwise: pushes the matrix returned by nerv_matrix_(getrow) for row
 *     idx (range [0, nrow)).
 * In both cases `true` is pushed after the value (2 results).
 * For a non-numeric key only `false` is returned (NOTE(review): presumably
 * so the caller falls back to ordinary method lookup -- confirm).
 */
static int nerv_matrix_(lua_index)(lua_State *L) {
    Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
    if (!lua_isnumber(L, 2)) {
        lua_pushboolean(L, 0);
        return 1;
    }
    int idx = luaL_checkinteger(L, 2);
    if (self->dim == 1) {
        if (idx < 0 || idx >= self->ncol)
            nerv_error(L, "index must be within range [0, %d)", self->ncol);
        lua_pushnumber(L, MATRIX_DATA_READ(L, MATRIX_ELEM_PTR(self), idx));
    } else {
        if (idx < 0 || idx >= self->nrow)
            nerv_error(L, "index must be within range [0, %d)", self->nrow);
        luaT_pushudata(L, nerv_matrix_(getrow)(self, idx),
                       nerv_matrix_(tname));
    }
    lua_pushboolean(L, 1);
    return 2;
}

/* Returns self->ncol. */
static int nerv_matrix_(lua_ncol)(lua_State *L) {
    Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
    lua_pushinteger(L, self->ncol);
    return 1;
}

/* Returns self->dim. */
static int nerv_matrix_(lua_dim)(lua_State *L) {
    Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
    lua_pushinteger(L, self->dim);
    return 1;
}

/* Returns self->nrow. */
static int nerv_matrix_(lua_nrow)(lua_State *L) {
    Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
    lua_pushinteger(L, self->nrow);
    return 1;
}

/*
 * Returns the integer stored at *self->data_ref
 * (NOTE(review): presumably a reference count on the shared data buffer --
 * confirm).
 */
static int nerv_matrix_(lua_get_dataref_value)(lua_State *L) {
    Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
    lua_pushinteger(L, *(self->data_ref));
    return 1;
}

/* Methods installed into the metatable by nerv_matrix_(lua_init). */
static const luaL_Reg nerv_matrix_(methods)[] = {
    {"get_elem", nerv_matrix_(lua_get_elem)},
    {"set_elem", nerv_matrix_(lua_set_elem)},
    {"ncol", nerv_matrix_(lua_ncol)},
    {"nrow", nerv_matrix_(lua_nrow)},
    {"dim", nerv_matrix_(lua_dim)},
    {"get_dataref_value", nerv_matrix_(lua_get_dataref_value)},
    {"__index__", nerv_matrix_(lua_index)},
    {"__newindex__", nerv_matrix_(lua_newindex)},
    {NULL, NULL}
};

/*
 * Registers the matrix type: creates its metatable (parent type
 * MATRIX_BASE_TNAME, constructor lua_new, destructor lua_destroy), installs
 * nerv_matrix_(methods), runs MATRIX_INIT if the includer defined it, and
 * pops the metatable so the stack is left balanced.
 */
void nerv_matrix_(lua_init)(lua_State *L) {
    luaT_newmetatable(L, nerv_matrix_(tname), MATRIX_BASE_TNAME,
                      nerv_matrix_(lua_new), nerv_matrix_(lua_destroy), NULL);
    luaL_register(L, NULL, nerv_matrix_(methods));
#ifdef MATRIX_INIT
    MATRIX_INIT(L);
#endif
    lua_pop(L, 1);
}

/*
 * The static bindings below are not referenced in this file.
 * NOTE(review): presumably the including translation unit registers them
 * in its own method table -- confirm.
 */

/* c:add(a, b, alpha, beta [, context]) -> nerv_matrix_(add)(c, a, b, alpha, beta). */
static int nerv_matrix_(lua_add)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 6);
    Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
    const Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
    MATRIX_ELEM alpha = luaL_checknumber(L, 4);
    MATRIX_ELEM beta = luaL_checknumber(L, 5);
    nerv_matrix_(add)(c, a, b, alpha, beta, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/*
 * c:mul(a, b, alpha, beta [, ta [, tb [, context]]])
 * ta/tb are strings whose first character is mapped by
 * nerv_matrix_(lua_get_blas_op); when absent or nil they default to
 * BLAS_OP_N. Accepting nil lets callers that only want to pass a context
 * (arg 8) leave the transpose flags at their default, which previously
 * raised a "string expected" error.
 */
static int nerv_matrix_(lua_mul)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 8);
    Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
    Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
    MATRIX_ELEM alpha = luaL_checknumber(L, 4);
    MATRIX_ELEM beta = luaL_checknumber(L, 5);
    int ta = lua_isnoneornil(L, 6)
                 ? BLAS_OP_N
                 : nerv_matrix_(lua_get_blas_op)(*luaL_checkstring(L, 6));
    int tb = lua_isnoneornil(L, 7)
                 ? BLAS_OP_N
                 : nerv_matrix_(lua_get_blas_op)(*luaL_checkstring(L, 7));
    nerv_matrix_(mul)(c, a, b, alpha, beta, ta, tb, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* a:sigmoid(b [, context]) -> nerv_matrix_(sigmoid)(a, b). */
static int nerv_matrix_(lua_sigmoid)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
    nerv_matrix_(sigmoid)(a, b, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* nerr:sigmoid_grad(err, output [, context]) -> nerv_matrix_(sigmoid_grad). */
static int nerv_matrix_(lua_sigmoid_grad)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 4);
    Matrix *nerr = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *err = luaT_checkudata(L, 2, nerv_matrix_(tname));
    Matrix *output = luaT_checkudata(L, 3, nerv_matrix_(tname));
    nerv_matrix_(sigmoid_grad)(nerr, err, output, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/*
 * b:softmax(a [, context]) -> nerv_matrix_(softmax)(b, a); returns the
 * matrix produced by the routine (named max_idx here).
 */
static int nerv_matrix_(lua_softmax)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
    Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *max_idx = nerv_matrix_(softmax)(b, a, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, max_idx, nerv_matrix_(tname));
    return 1;
}

/* a:rowsum([context]) -> matrix returned by nerv_matrix_(rowsum). */
static int nerv_matrix_(lua_rowsum)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 2);
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *b = nerv_matrix_(rowsum)(a, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, b, nerv_matrix_(tname));
    return 1;
}

/* a:colsum([context]) -> matrix returned by nerv_matrix_(colsum). */
static int nerv_matrix_(lua_colsum)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 2);
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *b = nerv_matrix_(colsum)(a, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, b, nerv_matrix_(tname));
    return 1;
}

/* a:colsame(ref [, context]) -> matrix returned by nerv_matrix_(colsame). */
static int nerv_matrix_(lua_colsame)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *ref = luaT_checkudata(L, 2, nerv_matrix_(tname));
    Matrix *b = nerv_matrix_(colsame)(a, ref, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, b, nerv_matrix_(tname));
    return 1;
}

/* a:rowmax([context]) -> matrix returned by nerv_matrix_(rowmax). */
static int nerv_matrix_(lua_rowmax)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 2);
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *b = nerv_matrix_(rowmax)(a, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, b, nerv_matrix_(tname));
    return 1;
}

/*
 * a:rowmax_idx([context]) -> (b, idx), the two matrices written by
 * nerv_matrix_(rowmax_idx) through its out-parameters.
 */
static int nerv_matrix_(lua_rowmax_idx)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 2);
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *b;
    Matrix *idx;
    nerv_matrix_(rowmax_idx)(a, &b, &idx, context, &status);
    /* b and idx are only pushed once the status check has passed. */
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, b, nerv_matrix_(tname));
    luaT_pushudata(L, idx, nerv_matrix_(tname));
    return 2;
}

/* b:add_row(a, beta [, context]) -> nerv_matrix_(add_row)(b, a, beta). */
static int nerv_matrix_(lua_add_row)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 4);
    const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
    Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
    double beta = luaL_checknumber(L, 3);
    nerv_matrix_(add_row)(b, a, beta, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* self:fill(val [, context]) -> nerv_matrix_(fill)(self, val). */
static int nerv_matrix_(lua_fill)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
    double val = luaL_checknumber(L, 2);
    nerv_matrix_(fill)(self, val, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* self:clip(val1, val2 [, context]) -> nerv_matrix_(clip)(self, val1, val2). */
static int nerv_matrix_(lua_clip)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 4);
    Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
    double val1 = luaL_checknumber(L, 2);
    double val2 = luaL_checknumber(L, 3);
    nerv_matrix_(clip)(self, val1, val2, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* a:trans([context]) -> matrix returned by nerv_matrix_(trans). */
static int nerv_matrix_(lua_trans)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 2);
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *b = nerv_matrix_(trans)(a, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, b, nerv_matrix_(tname));
    return 1;
}

/* c:mul_elem(a, b [, context]) -> nerv_matrix_(mul_elem)(c, a, b). */
static int nerv_matrix_(lua_mul_elem)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 4);
    const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
    const Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
    Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
    nerv_matrix_(mul_elem)(c, a, b, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* b:log_elem(a [, context]) -> nerv_matrix_(log_elem)(b, a). */
static int nerv_matrix_(lua_log_elem)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
    Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
    nerv_matrix_(log_elem)(b, a, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* a:decompress(orig_col [, context]) -> matrix returned by nerv_matrix_(decompress). */
static int nerv_matrix_(lua_decompress)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    const Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    int orig_col = luaL_checkinteger(L, 2);
    Matrix *b = nerv_matrix_(decompress)(a, orig_col, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, b, nerv_matrix_(tname));
    return 1;
}

/* a:expand_frm(b, cont [, context]) -> nerv_matrix_(expand_frm)(a, b, cont). */
static int nerv_matrix_(lua_expand_frm)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 4);
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
    int cont = luaL_checkinteger(L, 3);
    nerv_matrix_(expand_frm)(a, b, cont, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* a:rearrange_frm(b, step [, context]) -> nerv_matrix_(rearrange_frm)(a, b, step). */
static int nerv_matrix_(lua_rearrange_frm)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 4);
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
    int step = luaL_checkinteger(L, 3);
    nerv_matrix_(rearrange_frm)(a, b, step, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* a:scale_rows_by_col(b [, context]) -> nerv_matrix_(scale_rows_by_col)(a, b). */
static int nerv_matrix_(lua_scale_rows_by_col)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
    nerv_matrix_(scale_rows_by_col)(a, b, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* a:scale_rows_by_row(b [, context]) -> nerv_matrix_(scale_rows_by_row)(a, b). */
static int nerv_matrix_(lua_scale_rows_by_row)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
    nerv_matrix_(scale_rows_by_row)(a, b, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* a:diagonalize([context]) -> nerv_matrix_(diagonalize)(a). */
static int nerv_matrix_(lua_diagonalize)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 2);
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    nerv_matrix_(diagonalize)(a, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* a:set_values_by_mask(mask, val [, context]) -> nerv_matrix_(set_values_by_mask). */
static int nerv_matrix_(lua_set_values_by_mask)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 4);
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *mask = luaT_checkudata(L, 2, nerv_matrix_(tname));
    double val = luaL_checknumber(L, 3);
    nerv_matrix_(set_values_by_mask)(a, mask, val, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* a:tanh(b [, context]) -> nerv_matrix_(tanh)(a, b). */
static int nerv_matrix_(lua_tanh)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
    nerv_matrix_(tanh)(a, b, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* nerr:tanh_grad(err, output [, context]) -> nerv_matrix_(tanh_grad). */
static int nerv_matrix_(lua_tanh_grad)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 4);
    Matrix *nerr = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *err = luaT_checkudata(L, 2, nerv_matrix_(tname));
    Matrix *output = luaT_checkudata(L, 3, nerv_matrix_(tname));
    nerv_matrix_(tanh_grad)(nerr, err, output, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* a:relu(b [, context]) -> nerv_matrix_(relu)(a, b). */
static int nerv_matrix_(lua_relu)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
    nerv_matrix_(relu)(a, b, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* nerr:relu_grad(err, output [, context]) -> nerv_matrix_(relu_grad). */
static int nerv_matrix_(lua_relu_grad)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 4);
    Matrix *nerr = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *err = luaT_checkudata(L, 2, nerv_matrix_(tname));
    Matrix *output = luaT_checkudata(L, 3, nerv_matrix_(tname));
    nerv_matrix_(relu_grad)(nerr, err, output, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}
#endif