summaryrefslogtreecommitdiff
path: root/matrix/cumatrix.c
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2015-06-07 21:59:10 +0800
committerDeterminant <[email protected]>2015-06-07 21:59:10 +0800
commit0f30b1a4b5e583cb1df7dbb349c1af4378e41369 (patch)
tree967c6326b83cda2b92eee5f597dde0e74b071dbb /matrix/cumatrix.c
parent6e720b961f7edac9c3a41affe0ca40dd0ec9fc85 (diff)
fix minor bugs in cumatrix; clean up part of code
Diffstat (limited to 'matrix/cumatrix.c')
-rw-r--r--matrix/cumatrix.c28
1 files changed, 26 insertions, 2 deletions
diff --git a/matrix/cumatrix.c b/matrix/cumatrix.c
index ee5ecaa..af34fb4 100644
--- a/matrix/cumatrix.c
+++ b/matrix/cumatrix.c
@@ -1,11 +1,14 @@
#define NERV_GENERIC_CUMATRIX
#include "../common.h"
#include "cuda_helper.h"
+#include <string.h>
+#define PROFILE_HASHMAP_SIZE 123457
static cublasHandle_t cublas_handle;
static cudaEvent_t profile_start, profile_stop;
static HashMap *profile;
-int print_profile(lua_State *L) {
+static int print_profile(lua_State *L) {
+ (void)L;
size_t i;
fprintf(stderr, "*** [nerv cumatrix profile] **\n");
for (i = 0; i < profile->size; i++)
@@ -19,7 +22,8 @@ int print_profile(lua_State *L) {
return 0;
}
-int clear_profile(lua_State *L) {
+static int clear_profile(lua_State *L) {
+ (void)L;
hashmap_clear(profile);
return 0;
}
@@ -35,6 +39,25 @@ void accu_profile(const char *name, float delta) {
*val += delta;
}
+static const luaL_Reg cumatrix_methods[] = {
+ {"print_profile", print_profile},
+ {"clear_profile", clear_profile},
+ {NULL, NULL}
+};
+
+extern void nerv_matrix_cuda_float_init(lua_State *L);
+extern void nerv_matrix_cuda_double_init(lua_State *L);
+
+void nerv_cumatrix_init(lua_State *L) {
+ luaL_register(L, NULL, cumatrix_methods);
+ cublasCreate(&cublas_handle);
+ cudaEventCreate(&profile_start);
+ cudaEventCreate(&profile_stop);
+ profile = hashmap_create(PROFILE_HASHMAP_SIZE, bkdr_hash, strcmp);
+ nerv_matrix_cuda_float_init(L);
+ nerv_matrix_cuda_double_init(L);
+}
+
#define MATRIX_USE_FLOAT
#define cuda_matrix_(NAME) cuda_matrix_float_##NAME
#define nerv_matrix_(NAME) nerv_matrix_cuda_float_##NAME
@@ -51,6 +74,7 @@ const char *nerv_matrix_(tname) = "nerv.CuMatrixFloat";
#undef MATRIX_ELEM
#undef MATRIX_ELEM_PTR
#undef MATRIX_ELEM_FMT
+#undef MATRIX_ELEM_WRITE_FMT
#undef MATRIX_CUMATRIX_HOST_TNAME
#define MATRIX_USE_DOUBLE