aboutsummaryrefslogtreecommitdiff
path: root/nerv/matrix/generic/cumatrix.c
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/matrix/generic/cumatrix.c')
-rw-r--r--nerv/matrix/generic/cumatrix.c53
1 files changed, 40 insertions, 13 deletions
diff --git a/nerv/matrix/generic/cumatrix.c b/nerv/matrix/generic/cumatrix.c
index b706c21..16c0e3a 100644
--- a/nerv/matrix/generic/cumatrix.c
+++ b/nerv/matrix/generic/cumatrix.c
@@ -6,6 +6,7 @@
#define MATRIX_BASE_TNAME nerv_matrix_cuda_tname
#define NERV_GENERIC_MATRIX
#define NERV_GENERIC_CUKERNEL
+#include "../matrix.h"
#include "../../lib/common.h"
#include "../../lib/matrix/generic/matrix.h"
#include "../../lib/matrix/generic/cumatrix.h"
@@ -17,48 +18,58 @@ static int nerv_matrix_(lua_get_blas_op)(char ch) {
static int nerv_matrix_(lua_prefixsum_row)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- nerv_matrix_(prefixsum_row)(a, b, &status);
+ nerv_matrix_(prefixsum_row)(a, b, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_thres_mask)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 6);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
MATRIX_ELEM thres = luaL_checknumber(L, 3);
MATRIX_ELEM low = luaL_checknumber(L, 4);
MATRIX_ELEM high = luaL_checknumber(L, 5);
- nerv_matrix_(thres_mask)(a, b, thres, low, high, &status);
+ nerv_matrix_(thres_mask)(a, b, thres, low, high, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_rand_uniform)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 2);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- nerv_matrix_(rand_uniform)(a, &status);
+ nerv_matrix_(rand_uniform)(a, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_tanh)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- nerv_matrix_(tanh)(a, b, &status);
+ nerv_matrix_(tanh)(a, b, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_tanh_grad)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 4);
Matrix *nerr = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *err = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *output = luaT_checkudata(L, 3, nerv_matrix_(tname));
- nerv_matrix_(tanh_grad)(nerr, err, output, &status);
+ nerv_matrix_(tanh_grad)(nerr, err, output, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
@@ -66,39 +77,45 @@ static int nerv_matrix_(lua_tanh_grad)(lua_State *L) {
extern const char *MATRIX_CUMATRIX_HOST_TNAME;
static int nerv_matrix_(lua_copy_fromh)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 6);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME);
int nargs = lua_gettop(L);
int b_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0;
int b_end = nargs > 3 ? luaL_checkinteger(L, 4) : b->nrow;
int a_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0;
- nerv_matrix_(copy_fromh)(a, b, a_begin, b_begin, b_end, &status);
+ nerv_matrix_(copy_fromh)(a, b, a_begin, b_begin, b_end, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_copy_toh)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 6);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME);
int nargs = lua_gettop(L);
int a_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0;
int a_end = nargs > 3 ? luaL_checkinteger(L, 4) : a->nrow;
int b_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0;
- nerv_matrix_(copy_toh)(a, b, a_begin, a_end, b_begin, &status);
+ nerv_matrix_(copy_toh)(a, b, a_begin, a_end, b_begin, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_copy_fromd)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 6);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
int nargs = lua_gettop(L);
int b_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0;
int b_end = nargs > 3 ? luaL_checkinteger(L, 4) : b->nrow;
int a_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0;
- nerv_matrix_(copy_fromd)(a, b, a_begin, b_begin, b_end, &status);
+ nerv_matrix_(copy_fromd)(a, b, a_begin, b_begin, b_end, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
@@ -106,36 +123,42 @@ static int nerv_matrix_(lua_copy_fromd)(lua_State *L) {
extern const char *nerv_matrix_host_float_tname;
static int nerv_matrix_(lua_copy_rows_fromh_by_idx)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 5);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME);
const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_host_float_tname);
long nrow = a->nrow;
int b_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0;
- nerv_matrix_(copy_rows_fromh_by_idx)(a, b, idx, b_begin, &status);
+ nerv_matrix_(copy_rows_fromh_by_idx)(a, b, idx, b_begin, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_copy_rows_fromd_by_idx)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 5);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname));
long nrow = a->nrow;
int idx_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0;
- nerv_matrix_(copy_rows_fromd_by_idx)(a, b, idx, idx_begin, &status);
+ nerv_matrix_(copy_rows_fromd_by_idx)(a, b, idx, idx_begin, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_copy_rows_fromd_by_colidx)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 5);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname));
long nrow = a->nrow;
int idx_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0;
- nerv_matrix_(copy_rows_fromd_by_colidx)(a, b, idx, idx_begin, &status);
+ nerv_matrix_(copy_rows_fromd_by_colidx)(a, b, idx, idx_begin, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
@@ -145,12 +168,14 @@ static int nerv_matrix_(lua_update_select_rows_by_rowidx)(lua_State *L) {
/* update c's select rows,
* i.e. c[idx[i]] = c[idx[i]] * (1 - beta * alpha) + a[i] * alpha */
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 6);
Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname));
MATRIX_ELEM alpha = luaL_checknumber(L, 4);
MATRIX_ELEM beta = luaL_checknumber(L, 5);
- nerv_matrix_(update_select_rows_by_rowidx)(c, a, idx, alpha, beta, &status);
+ nerv_matrix_(update_select_rows_by_rowidx)(c, a, idx, alpha, beta, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
@@ -159,12 +184,14 @@ static int nerv_matrix_(lua_update_select_rows_by_colidx)(lua_State *L) {
/* update c's select rows,
* i.e. c[idx[i]] = c[idx[i]] * (1 - beta * alpha) + a[i] * alpha */
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 6);
Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname));
MATRIX_ELEM alpha = luaL_checknumber(L, 4);
MATRIX_ELEM beta = luaL_checknumber(L, 5);
- nerv_matrix_(update_select_rows_by_colidx)(c, a, idx, alpha, beta, &status);
+ nerv_matrix_(update_select_rows_by_colidx)(c, a, idx, alpha, beta, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}