diff options
Diffstat (limited to 'nerv/lib/matrix/generic/cukernel.cu')
-rw-r--r-- | nerv/lib/matrix/generic/cukernel.cu | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/nerv/lib/matrix/generic/cukernel.cu b/nerv/lib/matrix/generic/cukernel.cu index 995059c..51e3b6a 100644 --- a/nerv/lib/matrix/generic/cukernel.cu +++ b/nerv/lib/matrix/generic/cukernel.cu @@ -4,6 +4,7 @@ #include "../matrix.h" #include "cuda.h" #include "float.h" +#include "curand.h" #define CUDA_THREADS_N 16 #define CUDA_THREADS_NN ((CUDA_THREADS_N) * (CUDA_THREADS_N)) #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) @@ -356,11 +357,17 @@ __global__ void cudak_(copy_rows_by_idx)(const MATRIX_ELEM *a, MATRIX_ELEM *b, int j = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.y * blockDim.y + threadIdx.y; if (i >= nrow || j >= ncol) return; + /* int k = lrintf(idx[i]); if (k < 0 || k >= a_nrow) { printf("error in kernel copy_rows_by_idx k(%d) out of range\n", k); } b[j + i * stride] = a[j + k * stride]; + */ + /* NOTE: in most cases it is guaranteed + * the idx is within the range, checking + * would bring some overhead. */ + b[j + i * stride] = a[j + lrintf(idx[i]) * stride]; } __global__ void cudak_(copy_rows_by_colidx)(const MATRIX_ELEM *a, MATRIX_ELEM *b, @@ -391,8 +398,6 @@ __global__ void cudak_(prefixsum_row_reduce)(const MATRIX_ELEM *a, MATRIX_ELEM * b[idx_b] = a[idx_a]; } - - extern "C" { #include "../cukernel.h" void cudak_(cuda_log_elem)(const Matrix *a, Matrix *b) { @@ -440,12 +445,13 @@ extern "C" { cudaStreamSynchronize(0); } + extern curandGenerator_t curand_gen; void cudak_(cuda_rand_uniform)(const Matrix *a) { #ifdef MATRIX_USE_FLOAT - curandGenerateUniform(*(a->curand_gen), MATRIX_ELEM_PTR(a), a->nrow * a->stride / sizeof(MATRIX_ELEM)); + curandGenerateUniform(curand_gen, MATRIX_ELEM_PTR(a), a->nrow * a->stride / sizeof(MATRIX_ELEM)); #endif #ifdef MATRIX_USE_DOUBLE - curandGenerateUniformDouble(*(a->curand_gen), MATRIX_ELEM_PTR(a), a->nrow * a->stride / sizeof(MATRIX_ELEM)); + curandGenerateUniformDouble(curand_gen, MATRIX_ELEM_PTR(a), a->nrow * a->stride / sizeof(MATRIX_ELEM)); #endif } |