aboutsummaryrefslogtreecommitdiff
path: root/nerv/lib
diff options
context:
space:
mode:
authorTed Yin <Determinant@users.noreply.github.com>2015-11-18 15:11:43 +0800
committerTed Yin <Determinant@users.noreply.github.com>2015-11-18 15:11:43 +0800
commit369853d0b3f2bd70f5ddce43fa2811adb956333a (patch)
treeba4c3d66d69361e4f6343e5be8f69a09ae94a07a /nerv/lib
parente516887c5338411c22102cdab051e0abe447b754 (diff)
parent22dac37c663605aa6c6fa0426696d2d01da4370f (diff)
Merge pull request #12 from cloudygoose/txh18/rnnlm
add atomicAdd for cukernel
Diffstat (limited to 'nerv/lib')
-rw-r--r--nerv/lib/matrix/cukernel.cu32
-rw-r--r--nerv/lib/matrix/generic/cukernel.cu4
2 files changed, 35 insertions, 1 deletions
diff --git a/nerv/lib/matrix/cukernel.cu b/nerv/lib/matrix/cukernel.cu
index a19030a..1e856b9 100644
--- a/nerv/lib/matrix/cukernel.cu
+++ b/nerv/lib/matrix/cukernel.cu
@@ -1,5 +1,37 @@
#define NERV_GENERIC_CUKERNEL
+#include "cumatrix.h"
+
+__device__ double atomicAdd_nvidia(double* address, double val) {
+ //nvidia provided this implementation on the net
+ //atmoicAdd is not included in CUDA for double
+ unsigned long long int* address_as_ull =
+ (unsigned long long int*)address;
+ unsigned long long int old = *address_as_ull, assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(address_as_ull, assumed,
+ __double_as_longlong(val +
+ __longlong_as_double(assumed)));
+ } while (assumed != old);
+ return __longlong_as_double(old);
+}
+
+__device__ float atomicAdd_nvidia(float* address, float val) {
+ //nvidia provided this implementation on the net
+ //I tried the included atomocAdd, but the select_liner layer result seems unreproduceable, but sadly, even if I used this implementation, the select_linear layer result is still unreproduceable
+ int* address_as_ull = (int*)address;
+ int old = *address_as_ull, assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(address_as_ull, assumed,
+ __float_as_int(val +
+ __int_as_float(assumed)));
+ } while (assumed != old);
+ return __int_as_float(old);
+}
+
+
#define cudak_(NAME) cudak_float_ ## NAME
#define MATRIX_USE_FLOAT
#include "generic/elem_type.h"
diff --git a/nerv/lib/matrix/generic/cukernel.cu b/nerv/lib/matrix/generic/cukernel.cu
index d042d48..e1063af 100644
--- a/nerv/lib/matrix/generic/cukernel.cu
+++ b/nerv/lib/matrix/generic/cukernel.cu
@@ -231,7 +231,9 @@ __global__ void cudak_(update_select_rows)(MATRIX_ELEM *c, const MATRIX_ELEM *a,
int i = blockIdx.y * blockDim.y + threadIdx.y;
if (i >= nrow_a || j >= ncol_a) return;
int i_c = lrintf(idx[i]);
- c[j + i_c * stride_c] = c[j + i_c * stride_c] * (1 - beta * alpha) + a[j + i * stride_a] * alpha;
+ //critical: i_c could conflict among threads(same index in the idx array), so atomicAdd is used
+ //c[j + i_c * stride_c] = c[j + i_c * stride_c] * (1 - beta * alpha) + a[j + i * stride_a] * alpha;
+ atomicAdd_nvidia(c + j + i_c * stride_c, c[j + i_c * stride_c] * (- beta * alpha) + a[j + i * stride_a] * alpha);
}
__global__ void cudak_(expand_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b,