diff options
author | cloudygoose <[email protected]> | 2015-05-31 11:32:45 +0800 |
---|---|---|
committer | cloudygoose <[email protected]> | 2015-05-31 11:32:45 +0800 |
commit | 36162328956177d554891f937a13616b5476b231 (patch) | |
tree | a28c7a6f29b37ce091b03534e85d5cb28b2e0f81 /matrix/generic/cumatrix.c | |
parent | cfd06bb974c7088837a107d721b1311a4f160572 (diff) | |
parent | ab12a9583bdd39884fde9bc2444e6fd1bc5f518e (diff) |
...
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'matrix/generic/cumatrix.c')
-rw-r--r-- | matrix/generic/cumatrix.c | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c index 7b0aa2a..3bc58d7 100644 --- a/matrix/generic/cumatrix.c +++ b/matrix/generic/cumatrix.c @@ -251,6 +251,37 @@ static int nerv_matrix_(log_elem)(lua_State *L) { return 0; } +extern const char *nerv_matrix_host_int_tname; +static int nerv_matrix_(copy_rows_fromh_by_idx)(lua_State *L) { + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME); + Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_host_int_tname); + long *idx_ptr = idx->data.i; + int i; + long nrow = a->nrow; + if (idx->nrow != 1) + nerv_error(L, "index should be a vector"); + if (idx->ncol != nrow) + nerv_error(L, "index dimension mismatch"); + if (a->ncol != b->ncol) + nerv_error(L, "source/destination dimension mismatch"); + cudaStream_t *streams = (cudaStream_t*)malloc(sizeof(cudaStream_t) * nrow); + for (i = 0; i < nrow; i++) + { + CUDA_SAFE_CALL(cudaStreamCreate(streams + i)); + CUDA_SAFE_CALL(cudaMemcpyAsync(MATRIX_ROW_PTR(a, i), + MATRIX_ROW_PTR(b, idx_ptr[i]), + b->stride, + cudaMemcpyHostToDevice, streams[i])); + } + for (i = 0; i < nrow; i++) + { + CUDA_SAFE_CALL(cudaStreamSynchronize(streams[i])); + CUDA_SAFE_CALL(cudaStreamDestroy(streams[i])); + } + return 0; +} + static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"create", nerv_matrix_(create)}, {"colsum", nerv_matrix_(colsum)}, @@ -271,6 +302,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"softmax", nerv_matrix_(softmax)}, {"mul_elem", nerv_matrix_(mul_elem)}, {"log_elem", nerv_matrix_(log_elem)}, + {"copy_rows_fromh_by_idx", nerv_matrix_(copy_rows_fromh_by_idx)}, {NULL, NULL} }; |