aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortxh18 <[email protected]>2015-10-28 16:29:29 +0800
committertxh18 <[email protected]>2015-10-28 16:29:29 +0800
commite0fa1a48cb9f91bfcfc60b732b6f137a7a2071ba (patch)
treed375f9684970bcfa2977d74074f949bf3f98107a
parentaf99db1c6bc4823cc6ff094f24e963acd4788ef1 (diff)
changed copy_rows_fromd_by_idx a bit to make it clearer
-rw-r--r--nerv/doc/nerv_matrix.md8
-rw-r--r--nerv/lib/matrix/generic/cukernel.cu8
-rw-r--r--nerv/matrix/generic/cumatrix.c4
3 files changed, 10 insertions, 10 deletions
diff --git a/nerv/doc/nerv_matrix.md b/nerv/doc/nerv_matrix.md
index b915dee..dfd843d 100644
--- a/nerv/doc/nerv_matrix.md
+++ b/nerv/doc/nerv_matrix.md
@@ -63,10 +63,10 @@ Copy the content of a __CuMatrix__ `a` to __Matrix__ `self`, they should be of t
Copy the content of the __Matrix__ `self` to a __MMatrix__ `a`.
* __void Matrix.copy_tod(Matrix self, CuMatrix a)__
Copy the content of the __Matrix__ `self` to a __CuMatrix__ `a`.
-* __void Matrix.copy_rows_fromh_by_idx(Matrix self, MMatrix ma, MMatrixInt idx)__
-`idx` should be a row vector. This function copy the rows of `ma` to `self` according to `idx`, in other words, it assigns `ma[idx[i]]` to `self[i]`.
-* __void Matrix.copy_rows_fromd_by_idx(Matrix self, CuMatrix b, CuMatrix idx)__
-`idx` needs to a row vector matrix, it stacks the rows of index `idx` of the __CuMatrix__ `b` and copies to `self`.
+* __void Matrix.copy_rows_fromh_by_idx(Matrix self, MMatrix ma, MMatrixInt idx, int idx_begin)__
+`idx` should be a row vector. This function copy the rows of `ma` to `self` according to `idx`, in other words, it assigns `ma[idx[i+idx_begin]]` to `self[i]`.
+* __void Matrix.copy_rows_fromd_by_idx(Matrix self, Matrix b, Matrix idx, int idx_begin)__
+`idx` needs to a row vector matrix, it stacks the rows of index `idx` of the __CuMatrix__ `b` and copies to `self`. `idx_begin` is used as an offset in the `idx` index array.
* __void Matrix.update_select_rows(Matrix self, Matrix err, Matrix idx, double alpha, double beta)__
Update selected rows of `self`, i.e. `self[idx[i]] = self[idx[i]] * (1 - beta * alpha) + alpha * err[i]`.
* __void Matrix.add(Matrix self, Matrix ma, Matrix mb, Element_type alpha, Element_type beta)__
diff --git a/nerv/lib/matrix/generic/cukernel.cu b/nerv/lib/matrix/generic/cukernel.cu
index 6c8e64a..d042d48 100644
--- a/nerv/lib/matrix/generic/cukernel.cu
+++ b/nerv/lib/matrix/generic/cukernel.cu
@@ -294,7 +294,7 @@ __global__ void cudak_(gen_col_idx)(MATRIX_ELEM *b,
}
__global__ void cudak_(copy_rows_by_idx)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- const MATRIX_ELEM *idx, int b_begin,
+ const MATRIX_ELEM *idx,
int nrow, int ncol, int stride) {
int j = blockIdx.x * blockDim.x + threadIdx.x;
int i = blockIdx.y * blockDim.y + threadIdx.y;
@@ -620,13 +620,13 @@ extern "C" {
}
void cudak_(cuda_copy_rows_by_idx)(const Matrix *a, Matrix *b,
- const Matrix *idx, int b_begin) {
+ const Matrix *idx, int idx_begin) {
dim3 threadsPerBlock(CUDA_THREADS_NN, 1);
dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x), b->nrow);
cudak_(copy_rows_by_idx)<<<numBlocks, threadsPerBlock>>> \
(MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
- MATRIX_ELEM_PTR(idx) + b_begin,
- b_begin, b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
+ MATRIX_ELEM_PTR(idx) + idx_begin,
+ b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
cudaStreamSynchronize(0);
}
}
diff --git a/nerv/matrix/generic/cumatrix.c b/nerv/matrix/generic/cumatrix.c
index 623352e..f675149 100644
--- a/nerv/matrix/generic/cumatrix.c
+++ b/nerv/matrix/generic/cumatrix.c
@@ -247,8 +247,8 @@ static int nerv_matrix_(lua_copy_rows_fromd_by_idx)(lua_State *L) {
const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname));
long nrow = a->nrow;
- int b_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0;
- nerv_matrix_(copy_rows_fromd_by_idx)(a, b, idx, b_begin, &status);
+ int idx_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0;
+ nerv_matrix_(copy_rows_fromd_by_idx)(a, b, idx, idx_begin, &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}