From 0f30b1a4b5e583cb1df7dbb349c1af4378e41369 Mon Sep 17 00:00:00 2001 From: Determinant Date: Sun, 7 Jun 2015 21:59:10 +0800 Subject: fix minor bugs in cumatrix; clean up part of code --- matrix/cumatrix.c | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) (limited to 'matrix/cumatrix.c') diff --git a/matrix/cumatrix.c b/matrix/cumatrix.c index ee5ecaa..af34fb4 100644 --- a/matrix/cumatrix.c +++ b/matrix/cumatrix.c @@ -1,11 +1,14 @@ #define NERV_GENERIC_CUMATRIX #include "../common.h" #include "cuda_helper.h" +#include +#define PROFILE_HASHMAP_SIZE 123457 static cublasHandle_t cublas_handle; static cudaEvent_t profile_start, profile_stop; static HashMap *profile; -int print_profile(lua_State *L) { +static int print_profile(lua_State *L) { + (void)L; size_t i; fprintf(stderr, "*** [nerv cumatrix profile] **\n"); for (i = 0; i < profile->size; i++) @@ -19,7 +22,8 @@ int print_profile(lua_State *L) { return 0; } -int clear_profile(lua_State *L) { +static int clear_profile(lua_State *L) { + (void)L; hashmap_clear(profile); return 0; } @@ -35,6 +39,25 @@ void accu_profile(const char *name, float delta) { *val += delta; } +static const luaL_Reg cumatrix_methods[] = { + {"print_profile", print_profile}, + {"clear_profile", clear_profile}, + {NULL, NULL} +}; + +extern void nerv_matrix_cuda_float_init(lua_State *L); +extern void nerv_matrix_cuda_double_init(lua_State *L); + +void nerv_cumatrix_init(lua_State *L) { + luaL_register(L, NULL, cumatrix_methods); + cublasCreate(&cublas_handle); + cudaEventCreate(&profile_start); + cudaEventCreate(&profile_stop); + profile = hashmap_create(PROFILE_HASHMAP_SIZE, bkdr_hash, strcmp); + nerv_matrix_cuda_float_init(L); + nerv_matrix_cuda_double_init(L); +} + #define MATRIX_USE_FLOAT #define cuda_matrix_(NAME) cuda_matrix_float_##NAME #define nerv_matrix_(NAME) nerv_matrix_cuda_float_##NAME @@ -51,6 +74,7 @@ const char *nerv_matrix_(tname) = "nerv.CuMatrixFloat"; #undef MATRIX_ELEM #undef MATRIX_ELEM_PTR #undef MATRIX_ELEM_FMT +#undef MATRIX_ELEM_WRITE_FMT #undef MATRIX_CUMATRIX_HOST_TNAME #define MATRIX_USE_DOUBLE -- cgit v1.2.3