aboutsummaryrefslogtreecommitdiff
path: root/matrix/cumatrix.c
diff options
context:
space:
mode:
Diffstat (limited to 'matrix/cumatrix.c')
-rw-r--r--matrix/cumatrix.c16
1 files changed, 13 insertions, 3 deletions
diff --git a/matrix/cumatrix.c b/matrix/cumatrix.c
index 7759ca1..49b7fbf 100644
--- a/matrix/cumatrix.c
+++ b/matrix/cumatrix.c
@@ -9,6 +9,7 @@
#include "generic/matrix.h"
#include "cukernel.h"
#include "cuda.h"
+#include "cuda_runtime.h"
#include "driver_types.h"
#include "cublas_v2.h"
@@ -65,10 +66,19 @@ static int nerv_float_matrix_(sigmoid)(lua_State *L) {
return 1;
}
+static int nerv_float_matrix_(rowsum)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
+ Matrix *b = nerv_float_matrix_(new_)(a->nrow, 1);
+ cuda_rowsum(a, b);
+ luaT_pushudata(L, b, nerv_float_matrix_(tname));
+ return 1;
+}
+
static const luaL_Reg nerv_float_matrix_(extra_methods)[] = {
{"__add__", nerv_float_matrix_(add)},
{"__mul__", nerv_float_matrix_(mul)},
{"sigmoid", nerv_float_matrix_(sigmoid)},
+ {"rowsum", nerv_float_matrix_(rowsum)},
{NULL, NULL}
};
@@ -77,13 +87,13 @@ static void cuda_float_init(lua_State *L) {
cublasCreate(&cublas_handle);
}
-static cuda_float_array_free(float *ptr) {
+static void cuda_float_array_free(float *ptr) {
cudaFree(ptr);
}
-static cuda_float_array_alloc(float **dptr, long *stride,
+static void cuda_float_array_alloc(float **dptr, size_t *stride,
long width, long height) {
- cudaMallocPitch(dptr, stride, width, height);
+ cudaMallocPitch((void **)dptr, stride, width, height);
}
static float cuda_float_array_read(float *data, int idx) {