aboutsummaryrefslogtreecommitdiff
path: root/nerv/lib/matrix/generic/cukernel.cu
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/lib/matrix/generic/cukernel.cu')
-rw-r--r--nerv/lib/matrix/generic/cukernel.cu14
1 files changed, 10 insertions, 4 deletions
diff --git a/nerv/lib/matrix/generic/cukernel.cu b/nerv/lib/matrix/generic/cukernel.cu
index 995059c..51e3b6a 100644
--- a/nerv/lib/matrix/generic/cukernel.cu
+++ b/nerv/lib/matrix/generic/cukernel.cu
@@ -4,6 +4,7 @@
#include "../matrix.h"
#include "cuda.h"
#include "float.h"
+#include "curand.h"
#define CUDA_THREADS_N 16
#define CUDA_THREADS_NN ((CUDA_THREADS_N) * (CUDA_THREADS_N))
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
@@ -356,11 +357,17 @@ __global__ void cudak_(copy_rows_by_idx)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
int j = blockIdx.x * blockDim.x + threadIdx.x;
int i = blockIdx.y * blockDim.y + threadIdx.y;
if (i >= nrow || j >= ncol) return;
+ /*
int k = lrintf(idx[i]);
if (k < 0 || k >= a_nrow) {
printf("error in kernel copy_rows_by_idx k(%d) out of range\n", k);
}
b[j + i * stride] = a[j + k * stride];
+ */
+ /* NOTE: in most cases it is guaranteed
+ * the idx is within the range, checking
+ * would bring some overhead. */
+ b[j + i * stride] = a[j + lrintf(idx[i]) * stride];
}
__global__ void cudak_(copy_rows_by_colidx)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
@@ -391,8 +398,6 @@ __global__ void cudak_(prefixsum_row_reduce)(const MATRIX_ELEM *a, MATRIX_ELEM *
b[idx_b] = a[idx_a];
}
-
-
extern "C" {
#include "../cukernel.h"
void cudak_(cuda_log_elem)(const Matrix *a, Matrix *b) {
@@ -440,12 +445,13 @@ extern "C" {
cudaStreamSynchronize(0);
}
+ extern curandGenerator_t curand_gen;
void cudak_(cuda_rand_uniform)(const Matrix *a) {
#ifdef MATRIX_USE_FLOAT
- curandGenerateUniform(*(a->curand_gen), MATRIX_ELEM_PTR(a), a->nrow * a->stride / sizeof(MATRIX_ELEM));
+ curandGenerateUniform(curand_gen, MATRIX_ELEM_PTR(a), a->nrow * a->stride / sizeof(MATRIX_ELEM));
#endif
#ifdef MATRIX_USE_DOUBLE
- curandGenerateUniformDouble(*(a->curand_gen), MATRIX_ELEM_PTR(a), a->nrow * a->stride / sizeof(MATRIX_ELEM));
+ curandGenerateUniformDouble(curand_gen, MATRIX_ELEM_PTR(a), a->nrow * a->stride / sizeof(MATRIX_ELEM));
#endif
}