diff options
Diffstat (limited to 'nerv/lib/matrix/cumatrix.h')
-rw-r--r-- | nerv/lib/matrix/cumatrix.h | 20 |
1 files changed, 16 insertions, 4 deletions
diff --git a/nerv/lib/matrix/cumatrix.h b/nerv/lib/matrix/cumatrix.h index b47e14b..fd2a5ce 100644 --- a/nerv/lib/matrix/cumatrix.h +++ b/nerv/lib/matrix/cumatrix.h @@ -2,8 +2,20 @@ #define NERV_CUMATRIX_H #include "matrix.h" #include "../common.h" -void nerv_cumatrix_print_profile(); -void nerv_cumatrix_clear_profile(); -void nerv_cumatrix_init(); -void nerv_cumatrix_select_gpu(int dev, Status *status); +#include "cuda_helper.h" + +typedef struct CuContext { + int has_handle; + cublasHandle_t cublas_handle; + cudaEvent_t profile_start, profile_stop; + curandGenerator_t curand_gen; + HashMap *profile; +} CuContext; + +void nerv_cuda_context_print_profile(CuContext *context); +void nerv_cuda_context_clear_profile(CuContext *context); +void nerv_cuda_context_accu_profile(CuContext *context, const char *name, float delta); +void nerv_cuda_context_select_gpu(CuContext *context, int dev, Status *status); +CuContext *nerv_cuda_context_create(int dev, Status *status); +void nerv_cuda_context_destroy(CuContext *contex, Status *status); #endif |