aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/lib/matrix/generic/cukernel.cu30
-rw-r--r--nerv/lib/matrix/generic/cumatrix.c18
-rw-r--r--nerv/lib/matrix/generic/cumatrix.h6
-rw-r--r--nerv/matrix/generic/cumatrix.c21
4 files changed, 65 insertions, 10 deletions
diff --git a/nerv/lib/matrix/generic/cukernel.cu b/nerv/lib/matrix/generic/cukernel.cu
index 552f7a4..9244783 100644
--- a/nerv/lib/matrix/generic/cukernel.cu
+++ b/nerv/lib/matrix/generic/cukernel.cu
@@ -262,7 +262,7 @@ __global__ void cudak_(clip)(MATRIX_ELEM *a,
}
#ifdef __NERV_FUTURE_CUDA_7
-__global__ void cudak_(update_select_rows)(MATRIX_ELEM *c, const MATRIX_ELEM *a, const MATRIX_ELEM *idx,
+__global__ void cudak_(update_select_rows_by_rowidx)(MATRIX_ELEM *c, const MATRIX_ELEM *a, const MATRIX_ELEM *idx,
int nrow_a, int ncol_a, int nrow_c, int stride_c, int stride_a, double alpha, double beta) {
int j = blockIdx.x * blockDim.x + threadIdx.x;
int i = blockIdx.y * blockDim.y + threadIdx.y;
@@ -275,6 +275,20 @@ __global__ void cudak_(update_select_rows)(MATRIX_ELEM *c, const MATRIX_ELEM *a,
//c[j + i_c * stride_c] = c[j + i_c * stride_c] * (1 - beta * alpha) + a[j + i * stride_a] * alpha;
atomicAdd_nvidia(c + j + i_c * stride_c, c[j + i_c * stride_c] * (- beta * alpha) + a[j + i * stride_a] * alpha);
}
+
+__global__ void cudak_(update_select_rows_by_colidx)(MATRIX_ELEM *c, const MATRIX_ELEM *a, const MATRIX_ELEM *idx,
+ int nrow_a, int ncol_a, int nrow_c, int stride_c, int stride_a, int stride_idx, double alpha, double beta) {
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = blockIdx.y * blockDim.y + threadIdx.y;
+ if (i >= nrow_a || j >= ncol_a) return;
+ int i_c = lrintf(idx[stride_idx * i]);
+ if (i_c < 0 || i_c >= nrow_c) {
+ printf("ERROR inside kernel update_select_rows, i_c(%d) out of range!", i_c);
+ }
+ //critical: i_c could conflict among threads(same index in the idx array), so atomicAdd is used
+ //c[j + i_c * stride_c] = c[j + i_c * stride_c] * (1 - beta * alpha) + a[j + i * stride_a] * alpha;
+ atomicAdd_nvidia(c + j + i_c * stride_c, c[j + i_c * stride_c] * (- beta * alpha) + a[j + i * stride_a] * alpha);
+}
#endif
__global__ void cudak_(expand_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
@@ -640,16 +654,26 @@ extern "C" {
}
#ifdef __NERV_FUTURE_CUDA_7
- void cudak_(cuda_update_select_rows)(Matrix *c, const Matrix *a, const Matrix *idx, double alpha, double beta) {
+ void cudak_(cuda_update_select_rows_by_rowidx)(Matrix *c, const Matrix *a, const Matrix *idx, double alpha, double beta) {
dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x),
CEIL_DIV(a->nrow, threadsPerBlock.y));
- cudak_(update_select_rows)<<<numBlocks, threadsPerBlock>>> \
+ cudak_(update_select_rows_by_rowidx)<<<numBlocks, threadsPerBlock>>> \
(MATRIX_ELEM_PTR(c), MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(idx),
a->nrow, a->ncol, c->nrow, c->stride / sizeof(MATRIX_ELEM),
a->stride / sizeof(MATRIX_ELEM), alpha, beta);
cudaStreamSynchronize(0);
}
+ void cudak_(cuda_update_select_rows_by_colidx)(Matrix *c, const Matrix *a, const Matrix *idx, double alpha, double beta) {
+ dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
+ dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x),
+ CEIL_DIV(a->nrow, threadsPerBlock.y));
+ cudak_(update_select_rows_by_colidx)<<<numBlocks, threadsPerBlock>>> \
+ (MATRIX_ELEM_PTR(c), MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(idx),
+ a->nrow, a->ncol, c->nrow, c->stride / sizeof(MATRIX_ELEM),
+ a->stride / sizeof(MATRIX_ELEM), idx->stride / sizeof(MATRIX_ELEM), alpha, beta);
+ cudaStreamSynchronize(0);
+ }
#endif
void cudak_(cuda_expand_frm)(const Matrix *a, Matrix *b, int context) {
diff --git a/nerv/lib/matrix/generic/cumatrix.c b/nerv/lib/matrix/generic/cumatrix.c
index 68889ad..31d6b06 100644
--- a/nerv/lib/matrix/generic/cumatrix.c
+++ b/nerv/lib/matrix/generic/cumatrix.c
@@ -394,14 +394,26 @@ void nerv_matrix_(copy_rows_fromd_by_idx)(Matrix *a, const Matrix *b,
}
#ifdef __NERV_FUTURE_CUDA_7
-void nerv_matrix_(update_select_rows)(Matrix *c, const Matrix *a, const Matrix *idx, double alpha, double beta, Status *status) {
+void nerv_matrix_(update_select_rows_by_rowidx)(Matrix *c, const Matrix *a, const Matrix *idx, double alpha, double beta, Status *status) {
long nrow = a->nrow;
- if (idx->nrow != 1)
+ if (idx->nrow != 1 || idx->ncol != a->nrow)
+ NERV_EXIT_STATUS(status, MAT_IDX_VECTOR_EXP, 0);
+ if (a->ncol != c->ncol)
+ NERV_EXIT_STATUS(status, MAT_MISMATCH_DIM, 0);
+ PROFILE_START
+ cudak_(cuda_update_select_rows_by_rowidx)(c, a, idx, alpha, beta);
+ PROFILE_STOP
+ NERV_SET_STATUS(status, NERV_NORMAL, 0);
+}
+
+void nerv_matrix_(update_select_rows_by_colidx)(Matrix *c, const Matrix *a, const Matrix *idx, double alpha, double beta, Status *status) {
+ long nrow = a->nrow;
+ if (idx->ncol != 1 || idx->nrow != a->nrow)
NERV_EXIT_STATUS(status, MAT_IDX_VECTOR_EXP, 0);
if (a->ncol != c->ncol)
NERV_EXIT_STATUS(status, MAT_MISMATCH_DIM, 0);
PROFILE_START
- cudak_(cuda_update_select_rows)(c, a, idx, alpha, beta);
+ cudak_(cuda_update_select_rows_by_colidx)(c, a, idx, alpha, beta);
PROFILE_STOP
NERV_SET_STATUS(status, NERV_NORMAL, 0);
}
diff --git a/nerv/lib/matrix/generic/cumatrix.h b/nerv/lib/matrix/generic/cumatrix.h
index aa8805a..560311e 100644
--- a/nerv/lib/matrix/generic/cumatrix.h
+++ b/nerv/lib/matrix/generic/cumatrix.h
@@ -45,7 +45,11 @@ void nerv_matrix_(copy_rows_fromh_by_idx)(Matrix *a, const Matrix *b,
const Matrix *idx, int b_begin, Status *status);
void nerv_matrix_(copy_rows_fromd_by_idx)(Matrix *a, const Matrix *b,
const Matrix *idx, int b_begin, Status *status);
-void nerv_matrix_(update_select_rows)(Matrix *c, const Matrix *a, const Matrix *idx, double alpha, double beta, Status *status);
+
+#ifdef __NERV_FUTURE_CUDA_7
+void nerv_matrix_(update_select_rows_by_rowidx)(Matrix *c, const Matrix *a, const Matrix *idx, double alpha, double beta, Status *status);
+void nerv_matrix_(update_select_rows_by_colidx)(Matrix *c, const Matrix *a, const Matrix *idx, double alpha, double beta, Status *status);
+#endif
void nerv_matrix_(expand_frm)(Matrix *a, const Matrix *b,
int context, Status *status);
diff --git a/nerv/matrix/generic/cumatrix.c b/nerv/matrix/generic/cumatrix.c
index df858e6..95a0132 100644
--- a/nerv/matrix/generic/cumatrix.c
+++ b/nerv/matrix/generic/cumatrix.c
@@ -331,7 +331,7 @@ static int nerv_matrix_(lua_scale_rows_by_row)(lua_State *L) {
}
#ifdef __NERV_FUTURE_CUDA_7
-static int nerv_matrix_(lua_update_select_rows)(lua_State *L) {
+static int nerv_matrix_(lua_update_select_rows_by_rowidx)(lua_State *L) {
/* update c's select rows,
* i.e. c[idx[i]] = c[idx[i]] * (1 - beta * alpha) + a[i] * alpha */
Status status;
@@ -340,7 +340,21 @@ static int nerv_matrix_(lua_update_select_rows)(lua_State *L) {
const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname));
MATRIX_ELEM alpha = luaL_checknumber(L, 4);
MATRIX_ELEM beta = luaL_checknumber(L, 5);
- nerv_matrix_(update_select_rows)(c, a, idx, alpha, beta, &status);
+ nerv_matrix_(update_select_rows_by_rowidx)(c, a, idx, alpha, beta, &status);
+ NERV_LUA_CHECK_STATUS(L, status);
+ return 0;
+}
+
+static int nerv_matrix_(lua_update_select_rows_by_colidx)(lua_State *L) {
+ /* update c's select rows,
+ * i.e. c[idx[i]] = c[idx[i]] * (1 - beta * alpha) + a[i] * alpha */
+ Status status;
+ Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname));
+ MATRIX_ELEM alpha = luaL_checknumber(L, 4);
+ MATRIX_ELEM beta = luaL_checknumber(L, 5);
+ nerv_matrix_(update_select_rows_by_colidx)(c, a, idx, alpha, beta, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
@@ -381,7 +395,8 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"scale_rows_by_row", nerv_matrix_(lua_scale_rows_by_row)},
{"scale_rows_by_col", nerv_matrix_(lua_scale_rows_by_col)},
#ifdef __NERV_FUTURE_CUDA_7
- {"update_select_rows", nerv_matrix_(lua_update_select_rows)},
+ {"update_select_rows_by_rowidx", nerv_matrix_(lua_update_select_rows_by_rowidx)},
+ {"update_select_rows_by_colidx", nerv_matrix_(lua_update_select_rows_by_colidx)},
#endif
{NULL, NULL}
};