diff options
Diffstat (limited to 'nerv/matrix/generic')
-rw-r--r-- | nerv/matrix/generic/cumatrix.c | 53 | ||||
-rw-r--r-- | nerv/matrix/generic/matrix.c | 101 | ||||
-rw-r--r-- | nerv/matrix/generic/mmatrix.c | 27 |
3 files changed, 132 insertions, 49 deletions
diff --git a/nerv/matrix/generic/cumatrix.c b/nerv/matrix/generic/cumatrix.c index b706c21..16c0e3a 100644 --- a/nerv/matrix/generic/cumatrix.c +++ b/nerv/matrix/generic/cumatrix.c @@ -6,6 +6,7 @@ #define MATRIX_BASE_TNAME nerv_matrix_cuda_tname #define NERV_GENERIC_MATRIX #define NERV_GENERIC_CUKERNEL +#include "../matrix.h" #include "../../lib/common.h" #include "../../lib/matrix/generic/matrix.h" #include "../../lib/matrix/generic/cumatrix.h" @@ -17,48 +18,58 @@ static int nerv_matrix_(lua_get_blas_op)(char ch) { static int nerv_matrix_(lua_prefixsum_row)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 3); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); - nerv_matrix_(prefixsum_row)(a, b, &status); + nerv_matrix_(prefixsum_row)(a, b, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_thres_mask)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 6); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); MATRIX_ELEM thres = luaL_checknumber(L, 3); MATRIX_ELEM low = luaL_checknumber(L, 4); MATRIX_ELEM high = luaL_checknumber(L, 5); - nerv_matrix_(thres_mask)(a, b, thres, low, high, &status); + nerv_matrix_(thres_mask)(a, b, thres, low, high, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_rand_uniform)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 2); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); - nerv_matrix_(rand_uniform)(a, &status); + nerv_matrix_(rand_uniform)(a, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_tanh)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 3); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); - nerv_matrix_(tanh)(a, b, &status); + nerv_matrix_(tanh)(a, b, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_tanh_grad)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 4); Matrix *nerr = luaT_checkudata(L, 1, nerv_matrix_(tname)); Matrix *err = luaT_checkudata(L, 2, nerv_matrix_(tname)); Matrix *output = luaT_checkudata(L, 3, nerv_matrix_(tname)); - nerv_matrix_(tanh_grad)(nerr, err, output, &status); + nerv_matrix_(tanh_grad)(nerr, err, output, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } @@ -66,39 +77,45 @@ static int nerv_matrix_(lua_tanh_grad)(lua_State *L) { extern const char *MATRIX_CUMATRIX_HOST_TNAME; static int nerv_matrix_(lua_copy_fromh)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 6); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); const Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME); int nargs = lua_gettop(L); int b_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0; int b_end = nargs > 3 ? luaL_checkinteger(L, 4) : b->nrow; int a_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0; - nerv_matrix_(copy_fromh)(a, b, a_begin, b_begin, b_end, &status); + nerv_matrix_(copy_fromh)(a, b, a_begin, b_begin, b_end, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_copy_toh)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 6); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); const Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME); int nargs = lua_gettop(L); int a_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0; int a_end = nargs > 3 ? luaL_checkinteger(L, 4) : a->nrow; int b_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0; - nerv_matrix_(copy_toh)(a, b, a_begin, a_end, b_begin, &status); + nerv_matrix_(copy_toh)(a, b, a_begin, a_end, b_begin, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_copy_fromd)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 6); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); int nargs = lua_gettop(L); int b_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0; int b_end = nargs > 3 ? luaL_checkinteger(L, 4) : b->nrow; int a_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0; - nerv_matrix_(copy_fromd)(a, b, a_begin, b_begin, b_end, &status); + nerv_matrix_(copy_fromd)(a, b, a_begin, b_begin, b_end, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } @@ -106,36 +123,42 @@ static int nerv_matrix_(lua_copy_fromd)(lua_State *L) { extern const char *nerv_matrix_host_float_tname; static int nerv_matrix_(lua_copy_rows_fromh_by_idx)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 5); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); const Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME); const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_host_float_tname); long nrow = a->nrow; int b_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0; - nerv_matrix_(copy_rows_fromh_by_idx)(a, b, idx, b_begin, &status); + nerv_matrix_(copy_rows_fromh_by_idx)(a, b, idx, b_begin, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_copy_rows_fromd_by_idx)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 5); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname)); long nrow = a->nrow; int idx_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0; - nerv_matrix_(copy_rows_fromd_by_idx)(a, b, idx, idx_begin, &status); + nerv_matrix_(copy_rows_fromd_by_idx)(a, b, idx, idx_begin, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_copy_rows_fromd_by_colidx)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 5); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname)); long nrow = a->nrow; int idx_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0; - nerv_matrix_(copy_rows_fromd_by_colidx)(a, b, idx, idx_begin, &status); + nerv_matrix_(copy_rows_fromd_by_colidx)(a, b, idx, idx_begin, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } @@ -145,12 +168,14 @@ static int nerv_matrix_(lua_update_select_rows_by_rowidx)(lua_State *L) { /* update c's select rows, * i.e. c[idx[i]] = c[idx[i]] * (1 - beta * alpha) + a[i] * alpha */ Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 6); Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname)); const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname)); MATRIX_ELEM alpha = luaL_checknumber(L, 4); MATRIX_ELEM beta = luaL_checknumber(L, 5); - nerv_matrix_(update_select_rows_by_rowidx)(c, a, idx, alpha, beta, &status); + nerv_matrix_(update_select_rows_by_rowidx)(c, a, idx, alpha, beta, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } @@ -159,12 +184,14 @@ static int nerv_matrix_(lua_update_select_rows_by_colidx)(lua_State *L) { /* update c's select rows, * i.e. c[idx[i]] = c[idx[i]] * (1 - beta * alpha) + a[i] * alpha */ Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 6); Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname)); const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname)); MATRIX_ELEM alpha = luaL_checknumber(L, 4); MATRIX_ELEM beta = luaL_checknumber(L, 5); - nerv_matrix_(update_select_rows_by_colidx)(c, a, idx, alpha, beta, &status); + nerv_matrix_(update_select_rows_by_colidx)(c, a, idx, alpha, beta, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } diff --git a/nerv/matrix/generic/matrix.c b/nerv/matrix/generic/matrix.c index c1da774..c2e57b8 100644 --- a/nerv/matrix/generic/matrix.c +++ b/nerv/matrix/generic/matrix.c @@ -1,15 +1,18 @@ #ifdef NERV_GENERIC_MATRIX +#include "../matrix.h" #include "../../lib/common.h" #include "../../lib/matrix/generic/matrix.h" extern const char *nerv_matrix_(tname); extern const char *MATRIX_BASE_TNAME; - int nerv_matrix_(lua_new)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 3); Matrix *self = nerv_matrix_(create)(luaL_checkinteger(L, 1), - luaL_checkinteger(L, 2), &status); + luaL_checkinteger(L, 2), + context, &status); NERV_LUA_CHECK_STATUS(L, status); luaT_pushudata(L, self, nerv_matrix_(tname)); return 1; @@ -17,8 +20,10 @@ int nerv_matrix_(lua_new)(lua_State *L) { int nerv_matrix_(lua_destroy)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 2); Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname)); - nerv_matrix_(destroy)(self, &status); + nerv_matrix_(destroy)(self, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 1; } @@ -128,18 +133,22 @@ void nerv_matrix_(lua_init)(lua_State *L) { static int nerv_matrix_(lua_add)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 6); Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname)); const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); const Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname)); MATRIX_ELEM alpha = luaL_checknumber(L, 4); MATRIX_ELEM beta = luaL_checknumber(L, 5); - nerv_matrix_(add)(c, a, b, alpha, beta, &status); + nerv_matrix_(add)(c, a, b, alpha, beta, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_mul)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 8); Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname)); Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname)); @@ -150,35 +159,41 @@ static int nerv_matrix_(lua_mul)(lua_State *L) { : BLAS_OP_N; int tb = nargs > 6 ? nerv_matrix_(lua_get_blas_op)(*luaL_checkstring(L, 7)) \ : BLAS_OP_N; - nerv_matrix_(mul)(c, a, b, alpha, beta, ta, tb, &status); + nerv_matrix_(mul)(c, a, b, alpha, beta, ta, tb, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_sigmoid)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 3); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); - nerv_matrix_(sigmoid)(a, b, &status); + nerv_matrix_(sigmoid)(a, b, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_sigmoid_grad)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 4); Matrix *nerr = luaT_checkudata(L, 1, nerv_matrix_(tname)); Matrix *err = luaT_checkudata(L, 2, nerv_matrix_(tname)); Matrix *output = luaT_checkudata(L, 3, nerv_matrix_(tname)); - nerv_matrix_(sigmoid_grad)(nerr, err, output, &status); + nerv_matrix_(sigmoid_grad)(nerr, err, output, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_softmax)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 3); Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname)); - Matrix *max_idx = nerv_matrix_(softmax)(b, a, &status); + Matrix *max_idx = nerv_matrix_(softmax)(b, a, context, &status); NERV_LUA_CHECK_STATUS(L, status); luaT_pushudata(L, max_idx, nerv_matrix_(tname)); return 1; @@ -186,8 +201,10 @@ static int nerv_matrix_(lua_softmax)(lua_State *L) { static int nerv_matrix_(lua_rowsum)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 2); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); - Matrix *b = nerv_matrix_(rowsum)(a, &status); + Matrix *b = nerv_matrix_(rowsum)(a, context, &status); NERV_LUA_CHECK_STATUS(L, status); luaT_pushudata(L, b, nerv_matrix_(tname)); return 1; @@ -195,8 +212,10 @@ static int nerv_matrix_(lua_rowsum)(lua_State *L) { static int nerv_matrix_(lua_colsum)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 2); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); - Matrix *b = nerv_matrix_(colsum)(a, &status); + Matrix *b = nerv_matrix_(colsum)(a, context, &status); NERV_LUA_CHECK_STATUS(L, status); luaT_pushudata(L, b, nerv_matrix_(tname)); return 1; @@ -204,9 +223,11 @@ static int nerv_matrix_(lua_colsum)(lua_State *L) { static int nerv_matrix_(lua_colsame)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 3); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); const Matrix *ref = luaT_checkudata(L, 2, nerv_matrix_(tname)); - Matrix *b = nerv_matrix_(colsame)(a, ref, &status); + Matrix *b = nerv_matrix_(colsame)(a, ref, context, &status); NERV_LUA_CHECK_STATUS(L, status); luaT_pushudata(L, b, nerv_matrix_(tname)); return 1; @@ -214,8 +235,10 @@ static int nerv_matrix_(lua_colsame)(lua_State *L) { static int nerv_matrix_(lua_rowmax)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 2); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); - Matrix *b = nerv_matrix_(rowmax)(a, &status); + Matrix *b = nerv_matrix_(rowmax)(a, context, &status); NERV_LUA_CHECK_STATUS(L, status); luaT_pushudata(L, b, nerv_matrix_(tname)); return 1; @@ -223,10 +246,12 @@ static int nerv_matrix_(lua_rowmax)(lua_State *L) { static int nerv_matrix_(lua_rowmax_idx)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 2); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); Matrix *b; Matrix *idx; - nerv_matrix_(rowmax_idx)(a, &b, &idx, &status); + nerv_matrix_(rowmax_idx)(a, &b, &idx, context, &status); NERV_LUA_CHECK_STATUS(L, status); luaT_pushudata(L, b, nerv_matrix_(tname)); luaT_pushudata(L, idx, nerv_matrix_(tname)); @@ -235,37 +260,45 @@ static int nerv_matrix_(lua_rowmax_idx)(lua_State *L) { static int nerv_matrix_(lua_add_row)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 4); const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname)); double beta = luaL_checknumber(L, 3); - nerv_matrix_(add_row)(b, a, beta, &status); + nerv_matrix_(add_row)(b, a, beta, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_fill)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 3); Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname)); double val = luaL_checknumber(L, 2); - nerv_matrix_(fill)(self, val, &status); + nerv_matrix_(fill)(self, val, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_clip)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 4); Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname)); - double val_1 = luaL_checknumber(L, 2); - double val_2 = luaL_checknumber(L, 3); - nerv_matrix_(clip)(self, val_1, val_2, &status); + double val1 = luaL_checknumber(L, 2); + double val2 = luaL_checknumber(L, 3); + nerv_matrix_(clip)(self, val1, val2, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_trans)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 2); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); - Matrix *b = nerv_matrix_(trans)(a, &status); + Matrix *b = nerv_matrix_(trans)(a, context, &status); NERV_LUA_CHECK_STATUS(L, status); luaT_pushudata(L, b, nerv_matrix_(tname)); return 1; @@ -273,28 +306,34 @@ static int nerv_matrix_(lua_trans)(lua_State *L) { static int nerv_matrix_(lua_mul_elem)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 4); const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); const Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname)); Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname)); - nerv_matrix_(mul_elem)(c, a, b, &status); + nerv_matrix_(mul_elem)(c, a, b, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_log_elem)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 3); const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname)); - nerv_matrix_(log_elem)(b, a, &status); + nerv_matrix_(log_elem)(b, a, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_decompress)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 3); const Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); int orig_col = luaL_checkinteger(L, 2); - Matrix *b = nerv_matrix_(decompress)(a, orig_col, &status); + Matrix *b = nerv_matrix_(decompress)(a, orig_col, context, &status); NERV_LUA_CHECK_STATUS(L, status); luaT_pushudata(L, b, nerv_matrix_(tname)); return 1; @@ -302,38 +341,46 @@ static int nerv_matrix_(lua_decompress)(lua_State *L) { static int nerv_matrix_(lua_expand_frm)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 4); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); - int context = luaL_checkinteger(L, 3); - nerv_matrix_(expand_frm)(a, b, context, &status); + int cont = luaL_checkinteger(L, 3); + nerv_matrix_(expand_frm)(a, b, cont, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_rearrange_frm)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 4); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); int step = luaL_checkinteger(L, 3); - nerv_matrix_(rearrange_frm)(a, b, step, &status); + nerv_matrix_(rearrange_frm)(a, b, step, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_scale_rows_by_col)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 3); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); - nerv_matrix_(scale_rows_by_col)(a, b, &status); + nerv_matrix_(scale_rows_by_col)(a, b, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_scale_rows_by_row)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 3); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); - nerv_matrix_(scale_rows_by_row)(a, b, &status); + nerv_matrix_(scale_rows_by_row)(a, b, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } diff --git a/nerv/matrix/generic/mmatrix.c b/nerv/matrix/generic/mmatrix.c index 93562d0..69000b7 100644 --- a/nerv/matrix/generic/mmatrix.c +++ b/nerv/matrix/generic/mmatrix.c @@ -1,4 +1,5 @@ #ifdef NERV_GENERIC_MMATRIX +#include "../matrix.h" #include "../../lib/matrix/generic/matrix.h" #include "../../lib/matrix/generic/elem_type.h" #define MATRIX_DATA_WRITE(L, data, idx, val) (data[idx] = val) @@ -48,8 +49,10 @@ static void host_matrix_(init)(lua_State *L) { static int nerv_matrix_(lua_load)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 2); ChunkData *cdp = luaT_checkudata(L, 1, nerv_chunk_data_tname); - Matrix *self = nerv_matrix_(load)(cdp, &status); + Matrix *self = nerv_matrix_(load)(cdp, context, &status); NERV_LUA_CHECK_STATUS(L, status); luaT_pushudata(L, self, nerv_matrix_(tname)); return 1; @@ -57,23 +60,27 @@ static int nerv_matrix_(lua_load)(lua_State *L) { static int nerv_matrix_(lua_save)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 3); ChunkFile *cfp = luaT_checkudata(L, 2, nerv_chunk_file_handle_tname); Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname)); - nerv_matrix_(save)(self, cfp, &status); + nerv_matrix_(save)(self, cfp, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } static int nerv_matrix_(lua_copy_fromh)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 6); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); int nargs = lua_gettop(L); int b_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0; int b_end = nargs > 3 ? luaL_checkinteger(L, 4) : b->nrow; int a_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0; - nerv_matrix_(copy_fromh)(a, b, a_begin, b_begin, b_end, &status); + nerv_matrix_(copy_fromh)(a, b, a_begin, b_begin, b_end, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } @@ -81,12 +88,14 @@ static int nerv_matrix_(lua_copy_fromh)(lua_State *L) { static int nerv_matrix_(lua_copy_rows_fromh_by_idx)(lua_State *L) { Status status; - Matrix *a=luaT_checkudata(L,1,nerv_matrix_(tname)); - const Matrix *b=luaT_checkudata(L,2,nerv_matrix_(tname)); - const Matrix *idx=luaT_checkudata(L,3,nerv_matrix_(tname)); - int b_begin=lua_gettop(L)>3?luaL_checkinteger(L,4):0; - nerv_matrix_(copy_rows_fromh_by_idx)(a,b,idx,b_begin,&status); - NERV_LUA_CHECK_STATUS(L,status); + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 5); + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); + const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname)); + int b_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0; + nerv_matrix_(copy_rows_fromh_by_idx)(a, b, idx, b_begin, context, &status); + NERV_LUA_CHECK_STATUS(L, status); return 0; } |