aboutsummaryrefslogtreecommitdiff
path: root/nerv/lib/matrix/cumatrix.c
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/lib/matrix/cumatrix.c')
-rw-r--r--nerv/lib/matrix/cumatrix.c8
1 files changed, 8 insertions, 0 deletions
diff --git a/nerv/lib/matrix/cumatrix.c b/nerv/lib/matrix/cumatrix.c
index 04205e4..58bdfe7 100644
--- a/nerv/lib/matrix/cumatrix.c
+++ b/nerv/lib/matrix/cumatrix.c
@@ -9,6 +9,14 @@ static cudaEvent_t profile_start, profile_stop;
curandGenerator_t curand_gen;
static HashMap *profile;
+void nerv_cumatrix_select_gpu(int dev, Status *status) {
+ fprintf(stderr, "** selecting GPU %d\n", dev);
+ NERV_SET_STATUS(status, NERV_NORMAL, 0);
+ CUDA_SAFE_SYNC_CALL(cudaSetDevice(dev), status);
+ CUDA_SAFE_SYNC_CALL(cublasDestroy(cublas_handle), status);
+ CUDA_SAFE_SYNC_CALL(cublasCreate(&cublas_handle), status);
+}
+
void nerv_cumatrix_print_profile() {
size_t i;
fprintf(stderr, "*** [nerv cumatrix profile] **\n");