aboutsummaryrefslogtreecommitdiff
path: root/matrix/generic/cumatrix.c
diff options
context:
space:
mode:
authorcloudygoose <[email protected]>2015-05-31 11:32:45 +0800
committercloudygoose <[email protected]>2015-05-31 11:32:45 +0800
commit36162328956177d554891f937a13616b5476b231 (patch)
treea28c7a6f29b37ce091b03534e85d5cb28b2e0f81 /matrix/generic/cumatrix.c
parentcfd06bb974c7088837a107d721b1311a4f160572 (diff)
parentab12a9583bdd39884fde9bc2444e6fd1bc5f518e (diff)
...
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'matrix/generic/cumatrix.c')
-rw-r--r--matrix/generic/cumatrix.c32
1 files changed, 32 insertions, 0 deletions
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c
index 7b0aa2a..3bc58d7 100644
--- a/matrix/generic/cumatrix.c
+++ b/matrix/generic/cumatrix.c
@@ -251,6 +251,37 @@ static int nerv_matrix_(log_elem)(lua_State *L) {
return 0;
}
+extern const char *nerv_matrix_host_int_tname;
+static int nerv_matrix_(copy_rows_fromh_by_idx)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME);
+ Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_host_int_tname);
+ long *idx_ptr = idx->data.i;
+ int i;
+ long nrow = a->nrow;
+ if (idx->nrow != 1)
+ nerv_error(L, "index should be a vector");
+ if (idx->ncol != nrow)
+ nerv_error(L, "index dimension mismatch");
+ if (a->ncol != b->ncol)
+ nerv_error(L, "source/destination dimension mismatch");
+ cudaStream_t *streams = (cudaStream_t*)malloc(sizeof(cudaStream_t) * nrow);
+ for (i = 0; i < nrow; i++)
+ {
+ CUDA_SAFE_CALL(cudaStreamCreate(streams + i));
+ CUDA_SAFE_CALL(cudaMemcpyAsync(MATRIX_ROW_PTR(a, i),
+ MATRIX_ROW_PTR(b, idx_ptr[i]),
+ b->stride,
+ cudaMemcpyHostToDevice, streams[i]));
+ }
+ for (i = 0; i < nrow; i++)
+ {
+ CUDA_SAFE_CALL(cudaStreamSynchronize(streams[i]));
+ CUDA_SAFE_CALL(cudaStreamDestroy(streams[i]));
+ }
+ return 0;
+}
+
static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"create", nerv_matrix_(create)},
{"colsum", nerv_matrix_(colsum)},
@@ -271,6 +302,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"softmax", nerv_matrix_(softmax)},
{"mul_elem", nerv_matrix_(mul_elem)},
{"log_elem", nerv_matrix_(log_elem)},
+ {"copy_rows_fromh_by_idx", nerv_matrix_(copy_rows_fromh_by_idx)},
{NULL, NULL}
};