#ifdef NERV_GENERIC_MATRIX
#include "../matrix.h"
#include "../../lib/matrix/generic/matrix.h"
#include "../../lib/common.h"
extern const char *nerv_matrix_(tname);
extern const char *MATRIX_BASE_TNAME;
/* Lua constructor for this matrix type.
 *
 * Reads two integers (arguments 1 and 2), hands argument 3 to
 * MATRIX_GET_CONTEXT, creates the matrix through nerv_matrix_(create) and
 * pushes it as userdata of this type. Any failure reported in `status` is
 * raised to Lua by NERV_LUA_CHECK_STATUS. Returns 1 (the new matrix). */
int nerv_matrix_(lua_new)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    /* presumably row and column counts — confirm against nerv_matrix_(create) */
    lua_Integer nrow = luaL_checkinteger(L, 1);
    lua_Integer ncol = luaL_checkinteger(L, 2);
    Matrix *created = nerv_matrix_(create)(nrow, ncol, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, created, nerv_matrix_(tname));
    return 1;
}
/* Lua destructor for this matrix type (passed as the destructor to
 * luaT_newmetatable in lua_init).
 *
 * Releases the matrix at argument 1 through nerv_matrix_(destroy), using the
 * context obtained from argument 2 via MATRIX_GET_CONTEXT. Errors reported in
 * `status` are raised by NERV_LUA_CHECK_STATUS.
 *
 * Nothing is pushed, so the function returns 0. Returning 1 here (as before)
 * would hand back whatever value happened to be on top of the stack. */
int nerv_matrix_(lua_destroy)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 2);
    Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
    nerv_matrix_(destroy)(self, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}
int nerv_matrix_(lua_get_elem)(lua_State *L);
int nerv_matrix_(lua_set_elem)(lua_State *L);
/* `__newindex__` hook: handles `m[i] = v` for numeric keys.
 *
 * For a 1-D matrix (dim == 1), argument 3 is written to element i, which must
 * lie in [0, ncol). A numeric key on any other matrix raises
 * "cannot assign to row vector". Pushes true when the key was numeric (handled
 * here) and false otherwise; presumably luaT then falls back to its default
 * assignment — confirm against luaT's __newindex__ protocol. */
static int nerv_matrix_(lua_newindex)(lua_State *L) {
    Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
    if (lua_isnumber(L, 2))
    {
        /* Keep the full lua_Integer: narrowing to int before the range check
         * could wrap a huge key into [0, ncol) and write the wrong element. */
        lua_Integer idx = luaL_checkinteger(L, 2);
        if (self->dim == 1)
        {
            if (idx < 0 || idx >= self->ncol)
                /* %d expects an int; cast so the vararg type always matches. */
                nerv_error(L, "index must be within range [0, %d)",
                           (int)self->ncol);
            MATRIX_DATA_WRITE(L, MATRIX_ELEM_PTR(self), idx,
                              luaL_checknumber(L, 3));
        }
        else
            nerv_error(L, "cannot assign to row vector");
        lua_pushboolean(L, 1);
        return 1;
    }
    else
    {
        lua_pushboolean(L, 0);
        return 1;
    }
}
/* `__index__` hook: handles `m[i]` for numeric keys.
 *
 * - 1-D matrix (dim == 1): pushes element i as a number; i must be in
 *   [0, ncol).
 * - Otherwise: pushes the matrix returned by nerv_matrix_(getrow)(self, i) as
 *   userdata of this type; i must be in [0, nrow).
 *
 * For numeric keys, true is pushed after the value and 2 is returned. For any
 * other key, only false is pushed and 1 is returned; presumably luaT then
 * performs its normal method lookup — confirm against luaT's __index__
 * protocol. */
static int nerv_matrix_(lua_index)(lua_State *L) {
    Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
    if (lua_isnumber(L, 2))
    {
        /* Keep the full lua_Integer: narrowing to int before the range check
         * could wrap a huge key into range and read the wrong element/row. */
        lua_Integer idx = luaL_checkinteger(L, 2);
        if (self->dim == 1)
        {
            if (idx < 0 || idx >= self->ncol)
                /* %d expects an int; cast so the vararg type always matches. */
                nerv_error(L, "index must be within range [0, %d)",
                           (int)self->ncol);
            lua_pushnumber(L, MATRIX_DATA_READ(L, MATRIX_ELEM_PTR(self), idx));
        }
        else
        {
            if (idx < 0 || idx >= self->nrow)
                nerv_error(L, "index must be within range [0, %d)",
                           (int)self->nrow);
            luaT_pushudata(L, nerv_matrix_(getrow)(self, idx),
                           nerv_matrix_(tname));
        }
        lua_pushboolean(L, 1);
        return 2;
    }
    else
    {
        lua_pushboolean(L, 0);
        return 1;
    }
}
/* Lua method m:ncol(): pushes the matrix's `ncol` field as an integer. */
static int nerv_matrix_(lua_ncol)(lua_State *L) {
    const Matrix *mat = luaT_checkudata(L, 1, nerv_matrix_(tname));
    lua_pushinteger(L, mat->ncol);
    return 1;
}
/* Lua method m:dim(): pushes the matrix's `dim` field as an integer
 * (1 selects the vector-style indexing path in the __index__/__newindex__
 * hooks). */
static int nerv_matrix_(lua_dim)(lua_State *L) {
    const Matrix *mat = luaT_checkudata(L, 1, nerv_matrix_(tname));
    lua_pushinteger(L, mat->dim);
    return 1;
}
/* Lua method m:nrow(): pushes the matrix's `nrow` field as an integer. */
static int nerv_matrix_(lua_nrow)(lua_State *L) {
    const Matrix *mat = luaT_checkudata(L, 1, nerv_matrix_(tname));
    lua_pushinteger(L, mat->nrow);
    return 1;
}
/* Lua method m:get_dataref_value(): pushes the integer that `data_ref` points
 * to. NOTE(review): presumably a reference count shared by matrices viewing
 * the same storage — confirm against the Matrix definition. */
static int nerv_matrix_(lua_get_dataref_value)(lua_State *L) {
    const Matrix *mat = luaT_checkudata(L, 1, nerv_matrix_(tname));
    lua_pushinteger(L, mat->data_ref[0]);
    return 1;
}
/* Method table registered into this type's metatable by lua_init via
 * luaL_register. "__index__" / "__newindex__" are the handlers for numeric
 * m[i] reads/writes (presumably invoked by luaT — confirm). */
static const luaL_Reg nerv_matrix_(methods)[] = {
    /* shape queries */
    {"nrow",              nerv_matrix_(lua_nrow)},
    {"ncol",              nerv_matrix_(lua_ncol)},
    {"dim",               nerv_matrix_(lua_dim)},
    /* element access */
    {"get_elem",          nerv_matrix_(lua_get_elem)},
    {"set_elem",          nerv_matrix_(lua_set_elem)},
    {"__index__",         nerv_matrix_(lua_index)},
    {"__newindex__",      nerv_matrix_(lua_newindex)},
    /* storage internals */
    {"get_dataref_value", nerv_matrix_(lua_get_dataref_value)},
    {NULL, NULL}          /* terminator expected by luaL_register */
};
/* Registers this matrix type with Lua.
 *
 * luaT_newmetatable sets up the metatable for nerv_matrix_(tname) with parent
 * MATRIX_BASE_TNAME, lua_new as constructor and lua_destroy as destructor.
 * The base method table is registered into the table left on top of the stack.
 * If the instantiating file defines MATRIX_INIT, that hook runs next. The
 * metatable is popped before returning, so the stack is left balanced. */
void nerv_matrix_(lua_init)(lua_State *L) {
    const luaL_Reg *base_methods = nerv_matrix_(methods);
    luaT_newmetatable(L, nerv_matrix_(tname), MATRIX_BASE_TNAME,
                      nerv_matrix_(lua_new), nerv_matrix_(lua_destroy),
                      NULL);
    luaL_register(L, NULL, base_methods);
#ifdef MATRIX_INIT
    MATRIX_INIT(L);
#endif
    lua_settop(L, -2); /* equivalent to lua_pop(L, 1) */
}
/* Lua method c:add(a, b, alpha, beta [, context at 6]).
 * Forwards to nerv_matrix_(add) with argument 1 as the (non-const) target and
 * arguments 2 and 3 as const operands; presumably computes
 * c = alpha * a + beta * b — confirm against the kernel. Pushes nothing. */
static int nerv_matrix_(lua_add)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 6);
    Matrix *target = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *lhs = luaT_checkudata(L, 2, nerv_matrix_(tname));
    const Matrix *rhs = luaT_checkudata(L, 3, nerv_matrix_(tname));
    const MATRIX_ELEM lhs_scale = luaL_checknumber(L, 4);
    const MATRIX_ELEM rhs_scale = luaL_checknumber(L, 5);
    nerv_matrix_(add)(target, lhs, rhs, lhs_scale, rhs_scale, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}
/* Lua method c:mul(a, b, alpha, beta [, ta [, tb [, context]]]).
 *
 * Forwards to nerv_matrix_(mul)(c, a, b, alpha, beta, ta, tb) with c being
 * argument 1. ta / tb (arguments 6 and 7) are operation selectors for a and b
 * (presumably 'N' / 'T' transpose flags). Only the first character of each
 * string is passed to lua_get_blas_op.
 *
 * An absent OR nil flag defaults to BLAS_OP_N. This lets a caller skip either
 * flag with nil while still passing the context at argument 8; previously a
 * nil flag in that position raised "string expected". Pushes nothing. */
static int nerv_matrix_(lua_mul)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 8);
    Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
    Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
    MATRIX_ELEM alpha = luaL_checknumber(L, 4);
    MATRIX_ELEM beta = luaL_checknumber(L, 5);
    int ta = lua_isnoneornil(L, 6)
                 ? BLAS_OP_N
                 : nerv_matrix_(lua_get_blas_op)(*luaL_checkstring(L, 6));
    int tb = lua_isnoneornil(L, 7)
                 ? BLAS_OP_N
                 : nerv_matrix_(lua_get_blas_op)(*luaL_checkstring(L, 7));
    nerv_matrix_(mul)(c, a, b, alpha, beta, ta, tb, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}
/* Lua method wrapping nerv_matrix_(sigmoid).
 * Arguments 1 and 2 are matrices of this type, passed to the kernel in that
 * order (presumably destination then source — confirm); argument 3 is handed
 * to MATRIX_GET_CONTEXT. Pushes nothing. */
static int nerv_matrix_(lua_sigmoid)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    Matrix *dst = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *src = luaT_checkudata(L, 2, nerv_matrix_(tname));
    nerv_matrix_(sigmoid)(dst, src, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}
/* Lua method wrapping nerv_matrix_(sigmoid_grad).
 * Arguments 1..3 are matrices of this type forwarded in order as
 * (nerr, err, output); argument 4 is handed to MATRIX_GET_CONTEXT.
 * Pushes nothing. */
static int nerv_matrix_(lua_sigmoid_grad)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 4);
    Matrix *grad_out = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *grad_in = luaT_checkudata(L, 2, nerv_matrix_(tname));
    Matrix *activ = luaT_checkudata(L, 3, nerv_matrix_(tname));
    nerv_matrix_(sigmoid_grad)(grad_out, grad_in, activ, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}
/* Lua method b:softmax(a [, context at 3]).
 * Validates argument 2 before argument 1, then calls
 * nerv_matrix_(softmax)(arg1, arg2). The matrix returned by the kernel
 * (presumably per-row max indices — confirm) is pushed as userdata of this
 * type. Returns 1. */
static int nerv_matrix_(lua_softmax)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    Matrix *src = luaT_checkudata(L, 2, nerv_matrix_(tname));
    Matrix *dst = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *result = nerv_matrix_(softmax)(dst, src, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, result, nerv_matrix_(tname));
    return 1;
}
/* Lua method m:rowsum([context at 2]).
 * Pushes the new matrix returned by nerv_matrix_(rowsum)(m). Returns 1. */
static int nerv_matrix_(lua_rowsum)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 2);
    Matrix *src = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *sums;
    sums = nerv_matrix_(rowsum)(src, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, sums, nerv_matrix_(tname));
    return 1;
}
/* Lua method m:colsum([context at 2]).
 * Pushes the new matrix returned by nerv_matrix_(colsum)(m). Returns 1. */
static int nerv_matrix_(lua_colsum)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 2);
    Matrix *src = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *sums;
    sums = nerv_matrix_(colsum)(src, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, sums, nerv_matrix_(tname));
    return 1;
}
/* Lua method m:colsame(ref [, context at 3]).
 * Calls nerv_matrix_(colsame)(m, ref), where ref is read-only, and pushes the
 * matrix it returns. Returns 1. */
static int nerv_matrix_(lua_colsame)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    Matrix *subject = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *reference = luaT_checkudata(L, 2, nerv_matrix_(tname));
    Matrix *verdict;
    verdict = nerv_matrix_(colsame)(subject, reference, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, verdict, nerv_matrix_(tname));
    return 1;
}
/* Lua method m:rowmax([context at 2]).
 * Pushes the new matrix returned by nerv_matrix_(rowmax)(m). Returns 1. */
static int nerv_matrix_(lua_rowmax)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 2);
    Matrix *src = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *maxima;
    maxima = nerv_matrix_(rowmax)(src, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, maxima, nerv_matrix_(tname));
    return 1;
}
/* Lua method m:rowmax_idx([context at 2]).
 * nerv_matrix_(rowmax_idx) fills two output matrices through out-pointers;
 * both are pushed (values first, then indices) and 2 is returned. The outputs
 * are only pushed after NERV_LUA_CHECK_STATUS accepts `status`. */
static int nerv_matrix_(lua_rowmax_idx)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 2);
    Matrix *src = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *maxima, *argmax;
    nerv_matrix_(rowmax_idx)(src, &maxima, &argmax, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, maxima, nerv_matrix_(tname));
    luaT_pushudata(L, argmax, nerv_matrix_(tname));
    return 2;
}
/* Lua method b:add_row(a, beta [, context at 4]).
 * Validates argument 2 (read-only) before argument 1, reads beta from
 * argument 3, then calls nerv_matrix_(add_row)(b, a, beta). Pushes nothing. */
static int nerv_matrix_(lua_add_row)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 4);
    const Matrix *row = luaT_checkudata(L, 2, nerv_matrix_(tname));
    Matrix *target = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const double factor = luaL_checknumber(L, 3);
    nerv_matrix_(add_row)(target, row, factor, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}
/* Lua method m:fill(val [, context at 3]).
 * Forwards the number at argument 2 to nerv_matrix_(fill). Pushes nothing. */
static int nerv_matrix_(lua_fill)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    Matrix *target = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const double fill_value = luaL_checknumber(L, 2);
    nerv_matrix_(fill)(target, fill_value, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}
/* Lua method m:clip(val1, val2 [, context at 4]).
 * Forwards the two numbers (presumably lower and upper bounds — confirm) to
 * nerv_matrix_(clip) in argument order. Pushes nothing. */
static int nerv_matrix_(lua_clip)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 4);
    Matrix *target = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const double bound_a = luaL_checknumber(L, 2);
    const double bound_b = luaL_checknumber(L, 3);
    nerv_matrix_(clip)(target, bound_a, bound_b, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}
/* Lua method m:trans([context at 2]).
 * Pushes the new matrix returned by nerv_matrix_(trans)(m). Returns 1. */
static int nerv_matrix_(lua_trans)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 2);
    Matrix *src = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *transposed;
    transposed = nerv_matrix_(trans)(src, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, transposed, nerv_matrix_(tname));
    return 1;
}
/* Lua method c:mul_elem(a, b [, context at 4]).
 * Validates arguments 2 and 3 (read-only operands) before argument 1 (the
 * target), then calls nerv_matrix_(mul_elem)(c, a, b). Pushes nothing. */
static int nerv_matrix_(lua_mul_elem)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 4);
    const Matrix *lhs = luaT_checkudata(L, 2, nerv_matrix_(tname));
    const Matrix *rhs = luaT_checkudata(L, 3, nerv_matrix_(tname));
    Matrix *product = luaT_checkudata(L, 1, nerv_matrix_(tname));
    nerv_matrix_(mul_elem)(product, lhs, rhs, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}
/* Lua method b:log_elem(a [, context at 3]).
 * Validates argument 2 (read-only source) before argument 1 (target), then
 * calls nerv_matrix_(log_elem)(b, a). Pushes nothing. */
static int nerv_matrix_(lua_log_elem)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    const Matrix *src = luaT_checkudata(L, 2, nerv_matrix_(tname));
    Matrix *dst = luaT_checkudata(L, 1, nerv_matrix_(tname));
    nerv_matrix_(log_elem)(dst, src, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}
/* Lua method m:decompress(orig_col [, context at 3]).
 * Calls nerv_matrix_(decompress) with the read-only matrix and the integer
 * orig_col, and pushes the matrix it returns. Returns 1. */
static int nerv_matrix_(lua_decompress)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    const Matrix *packed = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const int width = luaL_checkinteger(L, 2);
    Matrix *unpacked;
    unpacked = nerv_matrix_(decompress)(packed, width, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    luaT_pushudata(L, unpacked, nerv_matrix_(tname));
    return 1;
}
/* Lua method a:expand_frm(b, cont [, context at 4]).
 * Forwards the target, the read-only source and the integer `cont` to
 * nerv_matrix_(expand_frm). Pushes nothing. */
static int nerv_matrix_(lua_expand_frm)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 4);
    Matrix *dst = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *src = luaT_checkudata(L, 2, nerv_matrix_(tname));
    const int context_width = luaL_checkinteger(L, 3);
    nerv_matrix_(expand_frm)(dst, src, context_width, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}
/* Lua method a:rearrange_frm(b, step [, context at 4]).
 * Forwards the target, the read-only source and the integer `step` to
 * nerv_matrix_(rearrange_frm). Pushes nothing. */
static int nerv_matrix_(lua_rearrange_frm)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 4);
    Matrix *dst = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *src = luaT_checkudata(L, 2, nerv_matrix_(tname));
    const int stride_step = luaL_checkinteger(L, 3);
    nerv_matrix_(rearrange_frm)(dst, src, stride_step, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}
/* Lua method a:scale_rows_by_col(b [, context at 3]).
 * Forwards the target and the read-only scale matrix to
 * nerv_matrix_(scale_rows_by_col). Pushes nothing. */
static int nerv_matrix_(lua_scale_rows_by_col)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    Matrix *target = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *scales = luaT_checkudata(L, 2, nerv_matrix_(tname));
    nerv_matrix_(scale_rows_by_col)(target, scales, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}
/* Lua method a:scale_rows_by_row(b [, context at 3]).
 * Forwards the target and the read-only scale matrix to
 * nerv_matrix_(scale_rows_by_row). Pushes nothing. */
static int nerv_matrix_(lua_scale_rows_by_row)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 3);
    Matrix *target = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *scales = luaT_checkudata(L, 2, nerv_matrix_(tname));
    nerv_matrix_(scale_rows_by_row)(target, scales, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}
/* Lua method m:diagonalize([context at 2]).
 * Applies nerv_matrix_(diagonalize) to the matrix in place (no new matrix is
 * returned by the kernel call). Pushes nothing. */
static int nerv_matrix_(lua_diagonalize)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 2);
    Matrix *target = luaT_checkudata(L, 1, nerv_matrix_(tname));
    nerv_matrix_(diagonalize)(target, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}
/* Lua method a:set_values_by_mask(mask, val [, context at 4]).
 * Forwards the target, the mask matrix and the number `val` to
 * nerv_matrix_(set_values_by_mask). Pushes nothing. */
static int nerv_matrix_(lua_set_values_by_mask)(lua_State *L) {
    Status status;
    MATRIX_CONTEXT *context;
    MATRIX_GET_CONTEXT(L, 4);
    Matrix *target = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *selector = luaT_checkudata(L, 2, nerv_matrix_(tname));
    const double new_value = luaL_checknumber(L, 3);
    nerv_matrix_(set_values_by_mask)(target, selector, new_value, context, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}
#endif