diff options
Diffstat (limited to 'nerv/matrix/cumatrix.c')
-rw-r--r-- | nerv/matrix/cumatrix.c | 59 |
1 files changed, 44 insertions, 15 deletions
diff --git a/nerv/matrix/cumatrix.c b/nerv/matrix/cumatrix.c index 7f22d68..26b055b 100644 --- a/nerv/matrix/cumatrix.c +++ b/nerv/matrix/cumatrix.c @@ -4,45 +4,74 @@ #include "../lib/matrix/cuda_helper.h" #include <string.h> #define PROFILE_HASHMAP_SIZE 123457 -static cublasHandle_t cublas_handle; -static cudaEvent_t profile_start, profile_stop; -static HashMap *profile; -static int select_gpu(lua_State *L) { +const char *nerv_cuda_context_tname = "nerv.CuContext"; + +int nerv_cuda_context_lua_select_gpu(lua_State *L) { Status status; - int dev = luaL_checkinteger(L, 1); - nerv_cumatrix_select_gpu(dev, &status); + nerv_cuda_context_select_gpu(luaT_checkudata(L, 1, nerv_cuda_context_tname), + luaL_checkinteger(L, 1), &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } -static int print_profile(lua_State *L) { - nerv_cumatrix_print_profile(); +int nerv_cuda_context_lua_print_profile(lua_State *L) { + nerv_cuda_context_print_profile(luaT_checkudata(L, 1, nerv_cuda_context_tname)); return 0; } -static int clear_profile(lua_State *L) { - nerv_cumatrix_clear_profile(); +int nerv_cuda_context_lua_clear_profile(lua_State *L) { + nerv_cuda_context_clear_profile(luaT_checkudata(L, 1, nerv_cuda_context_tname)); return 0; } -static const luaL_Reg cumatrix_methods[] = { - {"print_profile", print_profile}, - {"clear_profile", clear_profile}, - {"select_gpu", select_gpu}, +int nerv_cuda_context_lua_new(lua_State *L) { + Status status; + CuContext *self = nerv_cuda_context_create(&status); + NERV_LUA_CHECK_STATUS(L, status); + luaT_pushudata(L, self, nerv_cuda_context_tname); + return 1; +} + +int nerv_cuda_context_lua_destroy(lua_State *L) { + Status status; + CuContext *self = luaT_checkudata(L, 1, nerv_cuda_context_tname); + nerv_cuda_context_destroy(self, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 1; +} + +static const luaL_Reg nerv_cuda_context_methods[] = { + {"print_profile", nerv_cuda_context_lua_print_profile}, + {"clear_profile", nerv_cuda_context_lua_clear_profile}, + {"select_gpu", nerv_cuda_context_lua_select_gpu}, {NULL, NULL} }; +void nerv_cuda_context_lua_init(lua_State *L) { + luaT_newmetatable(L, nerv_cuda_context_tname, NULL, + nerv_cuda_context_lua_new, + nerv_cuda_context_lua_destroy, NULL); + luaL_register(L, NULL, nerv_cuda_context_methods); +} + extern void nerv_matrix_cuda_float_lua_init(lua_State *L); extern void nerv_matrix_cuda_double_lua_init(lua_State *L); +static const luaL_Reg cumatrix_methods[] = { + {NULL, NULL} +}; + void nerv_lua_cumatrix_init(lua_State *L) { luaL_register(L, NULL, cumatrix_methods); - nerv_cumatrix_init(); + nerv_cuda_context_lua_init(L); nerv_matrix_cuda_float_lua_init(L); nerv_matrix_cuda_double_lua_init(L); } +#define MATRIX_CONTEXT CuContext +#define MATRIX_CONTEXT_TNAME nerv_cuda_context_tname + #define MATRIX_USE_FLOAT #define cuda_matrix_(NAME) cuda_matrix_float_##NAME #define nerv_matrix_(NAME) nerv_matrix_cuda_float_##NAME |