aboutsummaryrefslogtreecommitdiff
path: root/nerv/matrix
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/matrix')
-rw-r--r--nerv/matrix/cumatrix.c5
-rw-r--r--nerv/matrix/generic/mmatrix.c2
-rw-r--r--nerv/matrix/init.lua21
3 files changed, 24 insertions, 4 deletions
diff --git a/nerv/matrix/cumatrix.c b/nerv/matrix/cumatrix.c
index f6a4ed5..b8eef9c 100644
--- a/nerv/matrix/cumatrix.c
+++ b/nerv/matrix/cumatrix.c
@@ -9,7 +9,7 @@ const char *nerv_cuda_context_tname = "nerv.CuContext";
int nerv_cuda_context_lua_select_gpu(lua_State *L) {
Status status;
nerv_cuda_context_select_gpu(luaT_checkudata(L, 1, nerv_cuda_context_tname),
- luaL_checkinteger(L, 1), &status);
+ luaL_checkinteger(L, 2), &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
@@ -26,7 +26,8 @@ int nerv_cuda_context_lua_clear_profile(lua_State *L) {
int nerv_cuda_context_lua_new(lua_State *L) {
Status status;
- CuContext *self = nerv_cuda_context_create(&status);
+ int dev = lua_gettop(L) > 0 ? luaL_checkinteger(L, 1) : -1;
+ CuContext *self = nerv_cuda_context_create(dev, &status);
NERV_LUA_CHECK_STATUS(L, status);
luaT_pushudata(L, self, nerv_cuda_context_tname);
return 1;
diff --git a/nerv/matrix/generic/mmatrix.c b/nerv/matrix/generic/mmatrix.c
index 69000b7..1f37173 100644
--- a/nerv/matrix/generic/mmatrix.c
+++ b/nerv/matrix/generic/mmatrix.c
@@ -8,10 +8,10 @@
#define MATRIX_BASE_TNAME nerv_matrix_host_tname
#define NERV_GENERIC_MATRIX
#include "../../lib/common.h"
+#include "../../lib/cblas.h"
#include "../../lib/matrix/generic/mmatrix.h"
#include "../../io/chunk_file.h"
#include <string.h>
-#include <cblas.h>
#define BLAS_OP_N CblasNoTrans
static int nerv_matrix_(lua_get_blas_op)(char ch) {
diff --git a/nerv/matrix/init.lua b/nerv/matrix/init.lua
index ef2fb6b..cf85004 100644
--- a/nerv/matrix/init.lua
+++ b/nerv/matrix/init.lua
@@ -87,6 +87,17 @@ function nerv.Matrix:__mul__(b)
return c
end
+--- A wrapper function for `copy_from`
+function nerv.Matrix:copy_to(b, ...)
+ b:copy_from(self, ...)
+end
+
+--- The base class for all device (in-GPU) matrices
+-- @type nerv.CuMatrix
+
+--- A wrapper function for `copy_fromd`
+nerv.CuMatrix.copy_tod = nerv.Matrix.copy_to
+
--- CUDA float matrices
-- @type nerv.CuMatrixFloat
@@ -127,6 +138,14 @@ end
-- @type nerv.MMatrix
--- A wrapper function for `copy_fromh`
-function nerv.MMatrix:copy_toh(b, ...)
+nerv.MMatrix.copy_toh = nerv.Matrix.copy_to
+
+--- A wrapper function for `nerv.CuMatrix` copy
+function nerv.MMatrix:copy_fromd(b, ...)
+ b:copy_toh(self, ...)
+end
+
+--- A wrapper function for `nerv.CuMatrix` copy
+function nerv.MMatrix:copy_tod(b, ...)
b:copy_fromh(self, ...)
end