diff options
-rw-r--r-- | nerv/Makefile | 2 | ||||
-rw-r--r-- | nerv/lib/matrix/cukernel.h | 2 | ||||
-rw-r--r-- | nerv/lib/matrix/generic/cukernel.cu | 31 | ||||
-rw-r--r-- | nerv/lib/matrix/generic/cumatrix.c | 15 | ||||
-rw-r--r-- | nerv/lib/matrix/generic/matrix.c | 8 | ||||
-rw-r--r-- | nerv/lib/matrix/matrix.h | 2 | ||||
-rw-r--r-- | nerv/matrix/generic/cumatrix.c | 21 |
7 files changed, 80 insertions, 1 deletions
diff --git a/nerv/Makefile b/nerv/Makefile index f154cc3..df6ce98 100644 --- a/nerv/Makefile +++ b/nerv/Makefile @@ -41,7 +41,7 @@ CUDA_BASE := /usr/local/cuda CUDA_INCLUDE := -I $(CUDA_BASE)/include/ INCLUDE += $(CUDA_INCLUDE) -LDFLAGS := -L$(CUDA_BASE)/lib64/ -Wl,-rpath=$(CUDA_BASE)/lib64/ -lcudart -lcublas +LDFLAGS := -L$(CUDA_BASE)/lib64/ -Wl,-rpath=$(CUDA_BASE)/lib64/ -lcudart -lcublas -lcurand CFLAGS := -Wall -Wextra -O2 NVCC := $(CUDA_BASE)/bin/nvcc NVCC_FLAGS := -Xcompiler -fPIC,-Wall,-Wextra diff --git a/nerv/lib/matrix/cukernel.h b/nerv/lib/matrix/cukernel.h index 40f8e9f..47dc0a8 100644 --- a/nerv/lib/matrix/cukernel.h +++ b/nerv/lib/matrix/cukernel.h @@ -3,6 +3,8 @@ void cudak_(cuda_mul_elem)(const Matrix *a, const Matrix *b, Matrix *c); void cudak_(cuda_log_elem)(const Matrix *a, Matrix *b); void cudak_(cuda_sigmoid)(const Matrix *a, Matrix *b); void cudak_(cuda_sigmoid_grad)(const Matrix *output, const Matrix *err, Matrix *nerr); +void cudak_(cuda_rand_uniform)(const Matrix *a); //a's curand_gen may be modified +void cudak_(cuda_thres_mask)(const Matrix *a, double thres, double low, double high); void cudak_(cuda_tanh)(const Matrix *a, Matrix *b); void cudak_(cuda_tanh_grad)(const Matrix *output, const Matrix *err, Matrix *nerr); void cudak_(cuda_rowsum)(const Matrix *a, Matrix *b); diff --git a/nerv/lib/matrix/generic/cukernel.cu b/nerv/lib/matrix/generic/cukernel.cu index 1a20b4f..b092e4a 100644 --- a/nerv/lib/matrix/generic/cukernel.cu +++ b/nerv/lib/matrix/generic/cukernel.cu @@ -20,6 +20,19 @@ __global__ void cudak_(log_elem)(const MATRIX_ELEM *a, MATRIX_ELEM *b, b[idx] = log(tmp); } +__global__ void cudak_(thres_mask)(MATRIX_ELEM *a, double thres, double low, double high, + int nrow, int ncol, int stride) { + int j = blockIdx.x * blockDim.x + threadIdx.x; + int i = blockIdx.y * blockDim.y + threadIdx.y; + long idx; + if (i >= nrow || j >= ncol) return; + idx = j + i * stride; + if (a[idx] < thres) + a[idx] = low; + else + a[idx] = high; +} + __global__ void cudak_(mul_elem)(const MATRIX_ELEM *a, const MATRIX_ELEM *b, MATRIX_ELEM *c, int nrow, int ncol, int stride) { @@ -376,6 +389,24 @@ extern "C" { cudaStreamSynchronize(0); } + void cudak_(cuda_rand_uniform)(Matrix *a) { + #ifdef MATRIX_USE_FLOAT + curandGenerateUniform(*(a->curand_gen), MATRIX_ELEM_PTR(a), a->nrow * a->stride / sizeof(MATRIX_ELEM)); + #endif + #ifdef MATRIX_USE_DOUBLE + curandGenerateUniformDouble(*(a->curand_gen), MATRIX_ELEM_PTR(a), a->nrow * a->stride / sizeof(MATRIX_ELEM)); + #endif + } + + void cudak_(cuda_thres_mask)(const Matrix *a, double thres, double low, double high) { + dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N); + dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x), + CEIL_DIV(a->nrow, threadsPerBlock.y)); + cudak_(thres_mask)<<<numBlocks, threadsPerBlock>>> \ + (MATRIX_ELEM_PTR(a), thres, low, high, a->nrow, a->ncol, a->stride / sizeof(MATRIX_ELEM)); + cudaStreamSynchronize(0); + } + void cudak_(cuda_tanh)(const Matrix *a, Matrix *b) { dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N); dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x), diff --git a/nerv/lib/matrix/generic/cumatrix.c b/nerv/lib/matrix/generic/cumatrix.c index 77cb304..cbb0481 100644 --- a/nerv/lib/matrix/generic/cumatrix.c +++ b/nerv/lib/matrix/generic/cumatrix.c @@ -10,6 +10,7 @@ #include "../../common.h" #include "../cukernel.h" #include "../cuda_helper.h" +#include <curand.h> void nerv_matrix_(add)(Matrix *c, const Matrix *a, const Matrix *b, MATRIX_ELEM alpha, MATRIX_ELEM beta, @@ -75,6 +76,20 @@ void nerv_matrix_(sigmoid_grad)(Matrix *nerr, const Matrix *err, NERV_SET_STATUS(status, NERV_NORMAL, 0); } +void nerv_matrix_(rand_uniform)(Matrix *a, Status *status) { + PROFILE_START + cudak_(cuda_rand_uniform)(a); + PROFILE_STOP + NERV_SET_STATUS(status, NERV_NORMAL, 0); +} + +void nerv_matrix_(thres_mask)(Matrix *a, double thres, double low, double high, Status *status) { + PROFILE_START + cudak_(cuda_thres_mask)(a, thres, low, high); + PROFILE_STOP + NERV_SET_STATUS(status, NERV_NORMAL, 0); +} + void nerv_matrix_(tanh)(Matrix *a, const Matrix *b, Status *status) { CHECK_SAME_DIMENSION(a, b, status); PROFILE_START diff --git a/nerv/lib/matrix/generic/matrix.c b/nerv/lib/matrix/generic/matrix.c index 4246751..fd5d28f 100644 --- a/nerv/lib/matrix/generic/matrix.c +++ b/nerv/lib/matrix/generic/matrix.c @@ -10,6 +10,8 @@ void nerv_matrix_(data_free)(Matrix *self, Status *status) { { /* free matrix data */ MATRIX_DATA_FREE(MATRIX_ELEM_PTR(self), status); + curandDestroyGenerator(*(self->curand_gen)); + free(self->curand_gen); free(self->data_ref); free(self); } @@ -39,6 +41,11 @@ Matrix *nerv_matrix_(create)(long nrow, long ncol, Status *status) { } self->data_ref = (long *)malloc(sizeof(long)); *self->data_ref = 0; + + self->curand_gen = (curandGenerator_t*)malloc(sizeof(curandGenerator_t)); + curandCreateGenerator(self->curand_gen, CURAND_RNG_PSEUDO_DEFAULT); + curandSetPseudoRandomGeneratorSeed(*(self->curand_gen), time(NULL)); + nerv_matrix_(data_retain)(self); NERV_SET_STATUS(status, NERV_NORMAL, 0); return self; @@ -57,6 +64,7 @@ Matrix *nerv_matrix_(getrow)(Matrix *self, int row) { prow->nmax = prow->ncol; MATRIX_ELEM_PTR(prow) = MATRIX_ROW_PTR(self, row); prow->data_ref = self->data_ref; + prow->curand_gen = self->curand_gen; nerv_matrix_(data_retain)(prow); return prow; } diff --git a/nerv/lib/matrix/matrix.h b/nerv/lib/matrix/matrix.h index 67a6e30..5a85c08 100644 --- a/nerv/lib/matrix/matrix.h +++ b/nerv/lib/matrix/matrix.h @@ -2,6 +2,7 @@ #define NERV_GENERIC_MATRIX_H #include <stddef.h> +#include <curand.h> typedef struct Matrix { size_t stride; /* size of a row */ @@ -13,6 +14,7 @@ typedef struct Matrix { long *i; } data; /* pointer to actual storage */ long *data_ref; + curandGenerator_t *curand_gen; } Matrix; #define MATRIX_ROW_PTR(self, row) \ diff --git a/nerv/matrix/generic/cumatrix.c b/nerv/matrix/generic/cumatrix.c index 3d9e694..d1f763b 100644 --- a/nerv/matrix/generic/cumatrix.c +++ b/nerv/matrix/generic/cumatrix.c @@ -62,6 +62,25 @@ static int nerv_matrix_(lua_sigmoid_grad)(lua_State *L) { return 0; } +static int nerv_matrix_(lua_thres_mask)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + MATRIX_ELEM thres = luaL_checknumber(L, 2); + MATRIX_ELEM low = luaL_checknumber(L, 3); + MATRIX_ELEM high = luaL_checknumber(L, 4); + nerv_matrix_(thres_mask)(a, thres, low, high, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + +static int nerv_matrix_(lua_rand_uniform)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + nerv_matrix_(rand_uniform)(a, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + static int nerv_matrix_(lua_tanh)(lua_State *L) { Status status; Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); @@ -349,9 +368,11 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"sigmoid_grad", nerv_matrix_(lua_sigmoid_grad)}, {"tanh", nerv_matrix_(lua_tanh)}, {"tanh_grad", nerv_matrix_(lua_tanh_grad)}, + {"rand_uniform", nerv_matrix_(lua_rand_uniform)}, {"softmax", nerv_matrix_(lua_softmax)}, {"mul_elem", nerv_matrix_(lua_mul_elem)}, {"log_elem", nerv_matrix_(lua_log_elem)}, + {"thres_mask", nerv_matrix_(lua_thres_mask)}, {"copy_rows_fromh_by_idx", nerv_matrix_(lua_copy_rows_fromh_by_idx)}, {"copy_rows_fromd_by_idx", nerv_matrix_(lua_copy_rows_fromd_by_idx)}, {"expand_frm", nerv_matrix_(lua_expand_frm)}, |