aboutsummaryrefslogtreecommitdiff
path: root/nerv/matrix
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/matrix')
-rw-r--r--nerv/matrix/cumatrix.c59
-rw-r--r--nerv/matrix/generic/cumatrix.c53
-rw-r--r--nerv/matrix/generic/matrix.c101
-rw-r--r--nerv/matrix/generic/mmatrix.c27
-rw-r--r--nerv/matrix/init.lua9
-rw-r--r--nerv/matrix/matrix.h24
-rw-r--r--nerv/matrix/mmatrix.c51
7 files changed, 250 insertions, 74 deletions
diff --git a/nerv/matrix/cumatrix.c b/nerv/matrix/cumatrix.c
index 7f22d68..26b055b 100644
--- a/nerv/matrix/cumatrix.c
+++ b/nerv/matrix/cumatrix.c
@@ -4,45 +4,74 @@
#include "../lib/matrix/cuda_helper.h"
#include <string.h>
#define PROFILE_HASHMAP_SIZE 123457
-static cublasHandle_t cublas_handle;
-static cudaEvent_t profile_start, profile_stop;
-static HashMap *profile;
-static int select_gpu(lua_State *L) {
+const char *nerv_cuda_context_tname = "nerv.CuContext";
+
+int nerv_cuda_context_lua_select_gpu(lua_State *L) {
Status status;
- int dev = luaL_checkinteger(L, 1);
- nerv_cumatrix_select_gpu(dev, &status);
+ nerv_cuda_context_select_gpu(luaT_checkudata(L, 1, nerv_cuda_context_tname),
+ luaL_checkinteger(L, 1), &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
-static int print_profile(lua_State *L) {
- nerv_cumatrix_print_profile();
+int nerv_cuda_context_lua_print_profile(lua_State *L) {
+ nerv_cuda_context_print_profile(luaT_checkudata(L, 1, nerv_cuda_context_tname));
return 0;
}
-static int clear_profile(lua_State *L) {
- nerv_cumatrix_clear_profile();
+int nerv_cuda_context_lua_clear_profile(lua_State *L) {
+ nerv_cuda_context_clear_profile(luaT_checkudata(L, 1, nerv_cuda_context_tname));
return 0;
}
-static const luaL_Reg cumatrix_methods[] = {
- {"print_profile", print_profile},
- {"clear_profile", clear_profile},
- {"select_gpu", select_gpu},
+int nerv_cuda_context_lua_new(lua_State *L) {
+ Status status;
+ CuContext *self = nerv_cuda_context_create(&status);
+ NERV_LUA_CHECK_STATUS(L, status);
+ luaT_pushudata(L, self, nerv_cuda_context_tname);
+ return 1;
+}
+
+int nerv_cuda_context_lua_destroy(lua_State *L) {
+ Status status;
+ CuContext *self = luaT_checkudata(L, 1, nerv_cuda_context_tname);
+ nerv_cuda_context_destroy(self, &status);
+ NERV_LUA_CHECK_STATUS(L, status);
+ return 1;
+}
+
+static const luaL_Reg nerv_cuda_context_methods[] = {
+ {"print_profile", nerv_cuda_context_lua_print_profile},
+ {"clear_profile", nerv_cuda_context_lua_clear_profile},
+ {"select_gpu", nerv_cuda_context_lua_select_gpu},
{NULL, NULL}
};
+void nerv_cuda_context_lua_init(lua_State *L) {
+ luaT_newmetatable(L, nerv_cuda_context_tname, NULL,
+ nerv_cuda_context_lua_new,
+ nerv_cuda_context_lua_destroy, NULL);
+ luaL_register(L, NULL, nerv_cuda_context_methods);
+}
+
extern void nerv_matrix_cuda_float_lua_init(lua_State *L);
extern void nerv_matrix_cuda_double_lua_init(lua_State *L);
+static const luaL_Reg cumatrix_methods[] = {
+ {NULL, NULL}
+};
+
void nerv_lua_cumatrix_init(lua_State *L) {
luaL_register(L, NULL, cumatrix_methods);
- nerv_cumatrix_init();
+ nerv_cuda_context_lua_init(L);
nerv_matrix_cuda_float_lua_init(L);
nerv_matrix_cuda_double_lua_init(L);
}
+#define MATRIX_CONTEXT CuContext
+#define MATRIX_CONTEXT_TNAME nerv_cuda_context_tname
+
#define MATRIX_USE_FLOAT
#define cuda_matrix_(NAME) cuda_matrix_float_##NAME
#define nerv_matrix_(NAME) nerv_matrix_cuda_float_##NAME
diff --git a/nerv/matrix/generic/cumatrix.c b/nerv/matrix/generic/cumatrix.c
index b706c21..16c0e3a 100644
--- a/nerv/matrix/generic/cumatrix.c
+++ b/nerv/matrix/generic/cumatrix.c
@@ -6,6 +6,7 @@
#define MATRIX_BASE_TNAME nerv_matrix_cuda_tname
#define NERV_GENERIC_MATRIX
#define NERV_GENERIC_CUKERNEL
+#include "../matrix.h"
#include "../../lib/common.h"
#include "../../lib/matrix/generic/matrix.h"
#include "../../lib/matrix/generic/cumatrix.h"
@@ -17,48 +18,58 @@ static int nerv_matrix_(lua_get_blas_op)(char ch) {
static int nerv_matrix_(lua_prefixsum_row)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- nerv_matrix_(prefixsum_row)(a, b, &status);
+ nerv_matrix_(prefixsum_row)(a, b, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_thres_mask)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 6);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
MATRIX_ELEM thres = luaL_checknumber(L, 3);
MATRIX_ELEM low = luaL_checknumber(L, 4);
MATRIX_ELEM high = luaL_checknumber(L, 5);
- nerv_matrix_(thres_mask)(a, b, thres, low, high, &status);
+ nerv_matrix_(thres_mask)(a, b, thres, low, high, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_rand_uniform)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 2);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- nerv_matrix_(rand_uniform)(a, &status);
+ nerv_matrix_(rand_uniform)(a, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_tanh)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- nerv_matrix_(tanh)(a, b, &status);
+ nerv_matrix_(tanh)(a, b, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_tanh_grad)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 4);
Matrix *nerr = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *err = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *output = luaT_checkudata(L, 3, nerv_matrix_(tname));
- nerv_matrix_(tanh_grad)(nerr, err, output, &status);
+ nerv_matrix_(tanh_grad)(nerr, err, output, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
@@ -66,39 +77,45 @@ static int nerv_matrix_(lua_tanh_grad)(lua_State *L) {
extern const char *MATRIX_CUMATRIX_HOST_TNAME;
static int nerv_matrix_(lua_copy_fromh)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 6);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME);
int nargs = lua_gettop(L);
int b_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0;
int b_end = nargs > 3 ? luaL_checkinteger(L, 4) : b->nrow;
int a_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0;
- nerv_matrix_(copy_fromh)(a, b, a_begin, b_begin, b_end, &status);
+ nerv_matrix_(copy_fromh)(a, b, a_begin, b_begin, b_end, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_copy_toh)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 6);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME);
int nargs = lua_gettop(L);
int a_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0;
int a_end = nargs > 3 ? luaL_checkinteger(L, 4) : a->nrow;
int b_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0;
- nerv_matrix_(copy_toh)(a, b, a_begin, a_end, b_begin, &status);
+ nerv_matrix_(copy_toh)(a, b, a_begin, a_end, b_begin, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_copy_fromd)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 6);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
int nargs = lua_gettop(L);
int b_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0;
int b_end = nargs > 3 ? luaL_checkinteger(L, 4) : b->nrow;
int a_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0;
- nerv_matrix_(copy_fromd)(a, b, a_begin, b_begin, b_end, &status);
+ nerv_matrix_(copy_fromd)(a, b, a_begin, b_begin, b_end, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
@@ -106,36 +123,42 @@ static int nerv_matrix_(lua_copy_fromd)(lua_State *L) {
extern const char *nerv_matrix_host_float_tname;
static int nerv_matrix_(lua_copy_rows_fromh_by_idx)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 5);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME);
const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_host_float_tname);
long nrow = a->nrow;
int b_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0;
- nerv_matrix_(copy_rows_fromh_by_idx)(a, b, idx, b_begin, &status);
+ nerv_matrix_(copy_rows_fromh_by_idx)(a, b, idx, b_begin, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_copy_rows_fromd_by_idx)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 5);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname));
long nrow = a->nrow;
int idx_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0;
- nerv_matrix_(copy_rows_fromd_by_idx)(a, b, idx, idx_begin, &status);
+ nerv_matrix_(copy_rows_fromd_by_idx)(a, b, idx, idx_begin, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_copy_rows_fromd_by_colidx)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 5);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname));
long nrow = a->nrow;
int idx_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0;
- nerv_matrix_(copy_rows_fromd_by_colidx)(a, b, idx, idx_begin, &status);
+ nerv_matrix_(copy_rows_fromd_by_colidx)(a, b, idx, idx_begin, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
@@ -145,12 +168,14 @@ static int nerv_matrix_(lua_update_select_rows_by_rowidx)(lua_State *L) {
/* update c's select rows,
* i.e. c[idx[i]] = c[idx[i]] * (1 - beta * alpha) + a[i] * alpha */
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 6);
Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname));
MATRIX_ELEM alpha = luaL_checknumber(L, 4);
MATRIX_ELEM beta = luaL_checknumber(L, 5);
- nerv_matrix_(update_select_rows_by_rowidx)(c, a, idx, alpha, beta, &status);
+ nerv_matrix_(update_select_rows_by_rowidx)(c, a, idx, alpha, beta, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
@@ -159,12 +184,14 @@ static int nerv_matrix_(lua_update_select_rows_by_colidx)(lua_State *L) {
/* update c's select rows,
* i.e. c[idx[i]] = c[idx[i]] * (1 - beta * alpha) + a[i] * alpha */
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 6);
Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname));
MATRIX_ELEM alpha = luaL_checknumber(L, 4);
MATRIX_ELEM beta = luaL_checknumber(L, 5);
- nerv_matrix_(update_select_rows_by_colidx)(c, a, idx, alpha, beta, &status);
+ nerv_matrix_(update_select_rows_by_colidx)(c, a, idx, alpha, beta, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
diff --git a/nerv/matrix/generic/matrix.c b/nerv/matrix/generic/matrix.c
index c1da774..c2e57b8 100644
--- a/nerv/matrix/generic/matrix.c
+++ b/nerv/matrix/generic/matrix.c
@@ -1,15 +1,18 @@
#ifdef NERV_GENERIC_MATRIX
+#include "../matrix.h"
#include "../../lib/common.h"
#include "../../lib/matrix/generic/matrix.h"
extern const char *nerv_matrix_(tname);
extern const char *MATRIX_BASE_TNAME;
-
int nerv_matrix_(lua_new)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *self = nerv_matrix_(create)(luaL_checkinteger(L, 1),
- luaL_checkinteger(L, 2), &status);
+ luaL_checkinteger(L, 2),
+ context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, self, nerv_matrix_(tname));
return 1;
@@ -17,8 +20,10 @@ int nerv_matrix_(lua_new)(lua_State *L) {
int nerv_matrix_(lua_destroy)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 2);
Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
- nerv_matrix_(destroy)(self, &status);
+ nerv_matrix_(destroy)(self, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 1;
}
@@ -128,18 +133,22 @@ void nerv_matrix_(lua_init)(lua_State *L) {
static int nerv_matrix_(lua_add)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 6);
Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
MATRIX_ELEM alpha = luaL_checknumber(L, 4);
MATRIX_ELEM beta = luaL_checknumber(L, 5);
- nerv_matrix_(add)(c, a, b, alpha, beta, &status);
+ nerv_matrix_(add)(c, a, b, alpha, beta, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_mul)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 8);
Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
@@ -150,35 +159,41 @@ static int nerv_matrix_(lua_mul)(lua_State *L) {
: BLAS_OP_N;
int tb = nargs > 6 ? nerv_matrix_(lua_get_blas_op)(*luaL_checkstring(L, 7)) \
: BLAS_OP_N;
- nerv_matrix_(mul)(c, a, b, alpha, beta, ta, tb, &status);
+ nerv_matrix_(mul)(c, a, b, alpha, beta, ta, tb, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_sigmoid)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- nerv_matrix_(sigmoid)(a, b, &status);
+ nerv_matrix_(sigmoid)(a, b, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_sigmoid_grad)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 4);
Matrix *nerr = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *err = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *output = luaT_checkudata(L, 3, nerv_matrix_(tname));
- nerv_matrix_(sigmoid_grad)(nerr, err, output, &status);
+ nerv_matrix_(sigmoid_grad)(nerr, err, output, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_softmax)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *max_idx = nerv_matrix_(softmax)(b, a, &status);
+ Matrix *max_idx = nerv_matrix_(softmax)(b, a, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, max_idx, nerv_matrix_(tname));
return 1;
@@ -186,8 +201,10 @@ static int nerv_matrix_(lua_softmax)(lua_State *L) {
static int nerv_matrix_(lua_rowsum)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 2);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = nerv_matrix_(rowsum)(a, &status);
+ Matrix *b = nerv_matrix_(rowsum)(a, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
@@ -195,8 +212,10 @@ static int nerv_matrix_(lua_rowsum)(lua_State *L) {
static int nerv_matrix_(lua_colsum)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 2);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = nerv_matrix_(colsum)(a, &status);
+ Matrix *b = nerv_matrix_(colsum)(a, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
@@ -204,9 +223,11 @@ static int nerv_matrix_(lua_colsum)(lua_State *L) {
static int nerv_matrix_(lua_colsame)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *ref = luaT_checkudata(L, 2, nerv_matrix_(tname));
- Matrix *b = nerv_matrix_(colsame)(a, ref, &status);
+ Matrix *b = nerv_matrix_(colsame)(a, ref, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
@@ -214,8 +235,10 @@ static int nerv_matrix_(lua_colsame)(lua_State *L) {
static int nerv_matrix_(lua_rowmax)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 2);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = nerv_matrix_(rowmax)(a, &status);
+ Matrix *b = nerv_matrix_(rowmax)(a, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
@@ -223,10 +246,12 @@ static int nerv_matrix_(lua_rowmax)(lua_State *L) {
static int nerv_matrix_(lua_rowmax_idx)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 2);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b;
Matrix *idx;
- nerv_matrix_(rowmax_idx)(a, &b, &idx, &status);
+ nerv_matrix_(rowmax_idx)(a, &b, &idx, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, b, nerv_matrix_(tname));
luaT_pushudata(L, idx, nerv_matrix_(tname));
@@ -235,37 +260,45 @@ static int nerv_matrix_(lua_rowmax_idx)(lua_State *L) {
static int nerv_matrix_(lua_add_row)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 4);
const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
double beta = luaL_checknumber(L, 3);
- nerv_matrix_(add_row)(b, a, beta, &status);
+ nerv_matrix_(add_row)(b, a, beta, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_fill)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
double val = luaL_checknumber(L, 2);
- nerv_matrix_(fill)(self, val, &status);
+ nerv_matrix_(fill)(self, val, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_clip)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 4);
Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
- double val_1 = luaL_checknumber(L, 2);
- double val_2 = luaL_checknumber(L, 3);
- nerv_matrix_(clip)(self, val_1, val_2, &status);
+ double val1 = luaL_checknumber(L, 2);
+ double val2 = luaL_checknumber(L, 3);
+ nerv_matrix_(clip)(self, val1, val2, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_trans)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 2);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = nerv_matrix_(trans)(a, &status);
+ Matrix *b = nerv_matrix_(trans)(a, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
@@ -273,28 +306,34 @@ static int nerv_matrix_(lua_trans)(lua_State *L) {
static int nerv_matrix_(lua_mul_elem)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 4);
const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
- nerv_matrix_(mul_elem)(c, a, b, &status);
+ nerv_matrix_(mul_elem)(c, a, b, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_log_elem)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
- nerv_matrix_(log_elem)(b, a, &status);
+ nerv_matrix_(log_elem)(b, a, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_decompress)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
const Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
int orig_col = luaL_checkinteger(L, 2);
- Matrix *b = nerv_matrix_(decompress)(a, orig_col, &status);
+ Matrix *b = nerv_matrix_(decompress)(a, orig_col, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
@@ -302,38 +341,46 @@ static int nerv_matrix_(lua_decompress)(lua_State *L) {
static int nerv_matrix_(lua_expand_frm)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 4);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- int context = luaL_checkinteger(L, 3);
- nerv_matrix_(expand_frm)(a, b, context, &status);
+ int cont = luaL_checkinteger(L, 3);
+ nerv_matrix_(expand_frm)(a, b, cont, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_rearrange_frm)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 4);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
int step = luaL_checkinteger(L, 3);
- nerv_matrix_(rearrange_frm)(a, b, step, &status);
+ nerv_matrix_(rearrange_frm)(a, b, step, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_scale_rows_by_col)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- nerv_matrix_(scale_rows_by_col)(a, b, &status);
+ nerv_matrix_(scale_rows_by_col)(a, b, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_scale_rows_by_row)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- nerv_matrix_(scale_rows_by_row)(a, b, &status);
+ nerv_matrix_(scale_rows_by_row)(a, b, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
diff --git a/nerv/matrix/generic/mmatrix.c b/nerv/matrix/generic/mmatrix.c
index 93562d0..69000b7 100644
--- a/nerv/matrix/generic/mmatrix.c
+++ b/nerv/matrix/generic/mmatrix.c
@@ -1,4 +1,5 @@
#ifdef NERV_GENERIC_MMATRIX
+#include "../matrix.h"
#include "../../lib/matrix/generic/matrix.h"
#include "../../lib/matrix/generic/elem_type.h"
#define MATRIX_DATA_WRITE(L, data, idx, val) (data[idx] = val)
@@ -48,8 +49,10 @@ static void host_matrix_(init)(lua_State *L) {
static int nerv_matrix_(lua_load)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 2);
ChunkData *cdp = luaT_checkudata(L, 1, nerv_chunk_data_tname);
- Matrix *self = nerv_matrix_(load)(cdp, &status);
+ Matrix *self = nerv_matrix_(load)(cdp, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, self, nerv_matrix_(tname));
return 1;
@@ -57,23 +60,27 @@ static int nerv_matrix_(lua_load)(lua_State *L) {
static int nerv_matrix_(lua_save)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
ChunkFile *cfp = luaT_checkudata(L, 2,
nerv_chunk_file_handle_tname);
Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
- nerv_matrix_(save)(self, cfp, &status);
+ nerv_matrix_(save)(self, cfp, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
static int nerv_matrix_(lua_copy_fromh)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 6);
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
int nargs = lua_gettop(L);
int b_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0;
int b_end = nargs > 3 ? luaL_checkinteger(L, 4) : b->nrow;
int a_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0;
- nerv_matrix_(copy_fromh)(a, b, a_begin, b_begin, b_end, &status);
+ nerv_matrix_(copy_fromh)(a, b, a_begin, b_begin, b_end, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
@@ -81,12 +88,14 @@ static int nerv_matrix_(lua_copy_fromh)(lua_State *L) {
static int nerv_matrix_(lua_copy_rows_fromh_by_idx)(lua_State *L)
{
Status status;
- Matrix *a=luaT_checkudata(L,1,nerv_matrix_(tname));
- const Matrix *b=luaT_checkudata(L,2,nerv_matrix_(tname));
- const Matrix *idx=luaT_checkudata(L,3,nerv_matrix_(tname));
- int b_begin=lua_gettop(L)>3?luaL_checkinteger(L,4):0;
- nerv_matrix_(copy_rows_fromh_by_idx)(a,b,idx,b_begin,&status);
- NERV_LUA_CHECK_STATUS(L,status);
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 5);
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname));
+ int b_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0;
+ nerv_matrix_(copy_rows_fromh_by_idx)(a, b, idx, b_begin, context, &status);
+ NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
diff --git a/nerv/matrix/init.lua b/nerv/matrix/init.lua
index da76e1b..ef2fb6b 100644
--- a/nerv/matrix/init.lua
+++ b/nerv/matrix/init.lua
@@ -130,12 +130,3 @@ end
function nerv.MMatrix:copy_toh(b, ...)
b:copy_fromh(self, ...)
end
-
---- Print profiling info of host matrices
-function nerv.MMatrix.print_profile()
- nerv.info("mmatrix profile not available")
-end
-
---- Clear profiling info of host matrices
-function nerv.MMatrix.clear_profile()
-end
diff --git a/nerv/matrix/matrix.h b/nerv/matrix/matrix.h
new file mode 100644
index 0000000..788f596
--- /dev/null
+++ b/nerv/matrix/matrix.h
@@ -0,0 +1,24 @@
+#ifndef NERV_LUA_MATRIX_H
+#define NERV_LUA_MATRIX_H
+#include "../lib/luaT/luaT.h"
+#define _MATRIX_GET_CONTEXT(L, p, tname, ctname) \
+ do { \
+ if (lua_gettop(L) < p) \
+ { \
+ luaT_pushmetatable(L, tname); \
+ lua_getfield(L, -1, "_default_context"); \
+ context = luaT_checkudata(L, -1, ctname); \
+ lua_pop(L, 2); \
+ } \
+ else \
+ { \
+ context = luaT_checkudata(L, p, ctname); \
+ } \
+ } while (0)
+
+extern const char *nerv_cuda_context_tname;
+extern const char *nerv_host_context_tname;
+extern const char *nerv_matrix_host_tname;
+#define MATRIX_GET_CONTEXT(L, p) _MATRIX_GET_CONTEXT(L, p, nerv_matrix_(tname), MATRIX_CONTEXT_TNAME)
+#define MMATRIX_GET_CONTEXT(L, p) _MATRIX_GET_CONTEXT(L, p, nerv_matrix_host_tname, nerv_host_context_tname)
+#endif
diff --git a/nerv/matrix/mmatrix.c b/nerv/matrix/mmatrix.c
index a68506d..45cb238 100644
--- a/nerv/matrix/mmatrix.c
+++ b/nerv/matrix/mmatrix.c
@@ -1,17 +1,64 @@
#define NERV_GENERIC_MMATRIX
#include <stdlib.h>
+#include "../lib/matrix/mmatrix.h"
#include "../lib/common.h"
+
+const char *nerv_host_context_tname = "nerv.MContext";
+
+int nerv_host_context_lua_print_profile(lua_State *L) {
+ nerv_host_context_print_profile(luaT_checkudata(L, 1, nerv_host_context_tname));
+ return 0;
+}
+
+int nerv_host_context_lua_clear_profile(lua_State *L) {
+ nerv_host_context_clear_profile(luaT_checkudata(L, 1, nerv_host_context_tname));
+ return 0;
+}
+
+int nerv_host_context_lua_new(lua_State *L) {
+ Status status;
+ MContext *self = nerv_host_context_create(&status);
+ NERV_LUA_CHECK_STATUS(L, status);
+ luaT_pushudata(L, self, nerv_host_context_tname);
+ return 1;
+}
+
+int nerv_host_context_lua_destroy(lua_State *L) {
+ Status status;
+ MContext *self = luaT_checkudata(L, 1, nerv_host_context_tname);
+ nerv_host_context_destroy(self, &status);
+ NERV_LUA_CHECK_STATUS(L, status);
+ return 1;
+}
+
+static const luaL_Reg nerv_host_context_methods[] = {
+ {"print_profile", nerv_host_context_lua_print_profile},
+ {"clear_profile", nerv_host_context_lua_clear_profile},
+ {NULL, NULL}
+};
+
+void nerv_host_context_lua_init(lua_State *L) {
+ luaT_newmetatable(L, nerv_host_context_tname, NULL,
+ nerv_host_context_lua_new,
+ nerv_host_context_lua_destroy, NULL);
+ luaL_register(L, NULL, nerv_host_context_methods);
+}
+
void nerv_matrix_host_float_lua_init(lua_State *L);
void nerv_matrix_host_double_lua_init(lua_State *L);
void nerv_matrix_host_int_lua_init(lua_State *L);
void nerv_lua_mmatrix_init(lua_State *L) {
srand(1);
+ nerv_host_context_lua_init(L);
nerv_matrix_host_float_lua_init(L);
nerv_matrix_host_double_lua_init(L);
nerv_matrix_host_int_lua_init(L);
}
+#define MATRIX_CONTEXT MContext
+#define MATRIX_CONTEXT_TNAME nerv_host_context_tname
+
#define MATRIX_USE_FLOAT
#define host_matrix_(NAME) host_matrix_float_##NAME
#define nerv_matrix_(NAME) nerv_matrix_host_float_##NAME
@@ -29,8 +76,10 @@ static void host_matrix_(init_extra)(lua_State *L) {
static int nerv_matrix_(lua_perm_gen)(lua_State *L) {
Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 2);
int i, ncol = luaL_checkinteger(L, 1);
- Matrix *self = nerv_matrix_(perm_gen)(ncol, &status);
+ Matrix *self = nerv_matrix_(perm_gen)(ncol, context, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, self, nerv_matrix_(tname));
return 1;