aboutsummaryrefslogtreecommitdiff
path: root/nerv/matrix/generic/matrix.c
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/matrix/generic/matrix.c')
-rw-r--r--nerv/matrix/generic/matrix.c101
1 files changed, 74 insertions, 27 deletions
diff --git a/nerv/matrix/generic/matrix.c b/nerv/matrix/generic/matrix.c
index c1da774..c2e57b8 100644
--- a/nerv/matrix/generic/matrix.c
+++ b/nerv/matrix/generic/matrix.c
@@ -1,15 +1,18 @@
#ifdef NERV_GENERIC_MATRIX
+#include "../matrix.h"
#include "../../lib/common.h"
#include "../../lib/matrix/generic/matrix.h"
extern const char *nerv_matrix_(tname);
extern const char *MATRIX_BASE_TNAME;
-
int nerv_matrix_(lua_new)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *self = nerv_matrix_(create)(luaL_checkinteger(L, 1),
- luaL_checkinteger(L, 2), &status);
+ luaL_checkinteger(L, 2),
+ context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, self, nerv_matrix_(tname));
return 1;
@@ -17,8 +20,10 @@ int nerv_matrix_(lua_new)(lua_State *L) {
int nerv_matrix_(lua_destroy)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 2);
Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
- nerv_matrix_(destroy)(self, &status);
+ nerv_matrix_(destroy)(self, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 1;
}
@@ -128,18 +133,22 @@ void nerv_matrix_(lua_init)(lua_State *L) {
static int nerv_matrix_(lua_add)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 6);
Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
MATRIX_ELEM alpha = luaL_checknumber(L, 4);
MATRIX_ELEM beta = luaL_checknumber(L, 5);
- nerv_matrix_(add)(c, a, b, alpha, beta, &status);
+ nerv_matrix_(add)(c, a, b, alpha, beta, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_mul)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 8);
Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
@@ -150,35 +159,41 @@ static int nerv_matrix_(lua_mul)(lua_State *L) {
: BLAS_OP_N;
int tb = nargs > 6 ? nerv_matrix_(lua_get_blas_op)(*luaL_checkstring(L, 7)) \
: BLAS_OP_N;
- nerv_matrix_(mul)(c, a, b, alpha, beta, ta, tb, &status);
+ nerv_matrix_(mul)(c, a, b, alpha, beta, ta, tb, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_sigmoid)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- nerv_matrix_(sigmoid)(a, b, &status);
+ nerv_matrix_(sigmoid)(a, b, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_sigmoid_grad)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 4);
Matrix *nerr = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *err = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *output = luaT_checkudata(L, 3, nerv_matrix_(tname));
- nerv_matrix_(sigmoid_grad)(nerr, err, output, &status);
+ nerv_matrix_(sigmoid_grad)(nerr, err, output, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_softmax)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *max_idx = nerv_matrix_(softmax)(b, a, &status);
+ Matrix *max_idx = nerv_matrix_(softmax)(b, a, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, max_idx, nerv_matrix_(tname));
return 1;
@@ -186,8 +201,10 @@ static int nerv_matrix_(lua_softmax)(lua_State *L) {
static int nerv_matrix_(lua_rowsum)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 2);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = nerv_matrix_(rowsum)(a, &status);
+ Matrix *b = nerv_matrix_(rowsum)(a, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
@@ -195,8 +212,10 @@ static int nerv_matrix_(lua_rowsum)(lua_State *L) {
static int nerv_matrix_(lua_colsum)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 2);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = nerv_matrix_(colsum)(a, &status);
+ Matrix *b = nerv_matrix_(colsum)(a, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
@@ -204,9 +223,11 @@ static int nerv_matrix_(lua_colsum)(lua_State *L) {
static int nerv_matrix_(lua_colsame)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *ref = luaT_checkudata(L, 2, nerv_matrix_(tname));
- Matrix *b = nerv_matrix_(colsame)(a, ref, &status);
+ Matrix *b = nerv_matrix_(colsame)(a, ref, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
@@ -214,8 +235,10 @@ static int nerv_matrix_(lua_colsame)(lua_State *L) {
static int nerv_matrix_(lua_rowmax)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 2);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = nerv_matrix_(rowmax)(a, &status);
+ Matrix *b = nerv_matrix_(rowmax)(a, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
@@ -223,10 +246,12 @@ static int nerv_matrix_(lua_rowmax)(lua_State *L) {
static int nerv_matrix_(lua_rowmax_idx)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 2);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b;
Matrix *idx;
- nerv_matrix_(rowmax_idx)(a, &b, &idx, &status);
+ nerv_matrix_(rowmax_idx)(a, &b, &idx, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, b, nerv_matrix_(tname));
luaT_pushudata(L, idx, nerv_matrix_(tname));
@@ -235,37 +260,45 @@ static int nerv_matrix_(lua_rowmax_idx)(lua_State *L) {
static int nerv_matrix_(lua_add_row)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 4);
const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
double beta = luaL_checknumber(L, 3);
- nerv_matrix_(add_row)(b, a, beta, &status);
+ nerv_matrix_(add_row)(b, a, beta, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_fill)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
double val = luaL_checknumber(L, 2);
- nerv_matrix_(fill)(self, val, &status);
+ nerv_matrix_(fill)(self, val, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_clip)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 4);
Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
- double val_1 = luaL_checknumber(L, 2);
- double val_2 = luaL_checknumber(L, 3);
- nerv_matrix_(clip)(self, val_1, val_2, &status);
+ double val1 = luaL_checknumber(L, 2);
+ double val2 = luaL_checknumber(L, 3);
+ nerv_matrix_(clip)(self, val1, val2, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_trans)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 2);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = nerv_matrix_(trans)(a, &status);
+ Matrix *b = nerv_matrix_(trans)(a, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
@@ -273,28 +306,34 @@ static int nerv_matrix_(lua_trans)(lua_State *L) {
static int nerv_matrix_(lua_mul_elem)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 4);
const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
- nerv_matrix_(mul_elem)(c, a, b, &status);
+ nerv_matrix_(mul_elem)(c, a, b, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_log_elem)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
- nerv_matrix_(log_elem)(b, a, &status);
+ nerv_matrix_(log_elem)(b, a, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_decompress)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
const Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
int orig_col = luaL_checkinteger(L, 2);
- Matrix *b = nerv_matrix_(decompress)(a, orig_col, &status);
+ Matrix *b = nerv_matrix_(decompress)(a, orig_col, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
@@ -302,38 +341,46 @@ static int nerv_matrix_(lua_decompress)(lua_State *L) {
static int nerv_matrix_(lua_expand_frm)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 4);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- int context = luaL_checkinteger(L, 3);
- nerv_matrix_(expand_frm)(a, b, context, &status);
+ int cont = luaL_checkinteger(L, 3);
+ nerv_matrix_(expand_frm)(a, b, cont, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_rearrange_frm)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 4);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
int step = luaL_checkinteger(L, 3);
- nerv_matrix_(rearrange_frm)(a, b, step, &status);
+ nerv_matrix_(rearrange_frm)(a, b, step, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_scale_rows_by_col)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- nerv_matrix_(scale_rows_by_col)(a, b, &status);
+ nerv_matrix_(scale_rows_by_col)(a, b, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_scale_rows_by_row)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- nerv_matrix_(scale_rows_by_row)(a, b, &status);
+ nerv_matrix_(scale_rows_by_row)(a, b, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}