diff options
author | Determinant <ted.sybil@gmail.com> | 2015-08-28 13:21:52 +0800 |
---|---|---|
committer | Determinant <ted.sybil@gmail.com> | 2015-08-28 13:21:52 +0800 |
commit | 1a9f63e351582f54fec7817927168cb1dbb0c1d6 (patch) | |
tree | c340b648c60d93b956be5956fa03233383e37e5d /nerv/matrix | |
parent | 8bf9c7575ffeeabb3924e9e02a35afe187071fe2 (diff) |
support gpu buffering
Diffstat (limited to 'nerv/matrix')
-rw-r--r-- | nerv/matrix/generic/cumatrix.c | 22 | ||||
-rw-r--r-- | nerv/matrix/mmatrix.c | 46 |
2 files changed, 44 insertions, 24 deletions
diff --git a/nerv/matrix/generic/cumatrix.c b/nerv/matrix/generic/cumatrix.c index ab7f7c4..08cb4c2 100644 --- a/nerv/matrix/generic/cumatrix.c +++ b/nerv/matrix/generic/cumatrix.c @@ -228,12 +228,12 @@ static int nerv_matrix_(lua_decompress)(lua_State *L) { return 1; } -extern const char *nerv_matrix_host_int_tname; +extern const char *nerv_matrix_host_float_tname; static int nerv_matrix_(lua_copy_rows_fromh_by_idx)(lua_State *L) { Status status; Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); const Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME); - const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_host_int_tname); + const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_host_float_tname); long nrow = a->nrow; int b_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0; nerv_matrix_(copy_rows_fromh_by_idx)(a, b, idx, b_begin, &status); @@ -241,6 +241,18 @@ static int nerv_matrix_(lua_copy_rows_fromh_by_idx)(lua_State *L) { return 0; } +static int nerv_matrix_(lua_copy_rows_fromd_by_idx)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); + const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname)); + long nrow = a->nrow; + int b_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0; + nerv_matrix_(copy_rows_fromd_by_idx)(a, b, idx, b_begin, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + static int nerv_matrix_(lua_expand_frm)(lua_State *L) { Status status; Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); @@ -290,6 +302,8 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { /* in-place calc */ {"copy_fromh", nerv_matrix_(lua_copy_fromh)}, {"copy_fromd", nerv_matrix_(lua_copy_fromd)}, + /* alias for copy_fromd */ + {"copy_from", nerv_matrix_(lua_copy_fromd)}, {"copy_toh", nerv_matrix_(lua_copy_toh)}, {"add", nerv_matrix_(lua_add)}, {"mul", nerv_matrix_(lua_mul)}, @@ -302,6 +316,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"mul_elem", nerv_matrix_(lua_mul_elem)}, {"log_elem", nerv_matrix_(lua_log_elem)}, {"copy_rows_fromh_by_idx", nerv_matrix_(lua_copy_rows_fromh_by_idx)}, + {"copy_rows_fromd_by_idx", nerv_matrix_(lua_copy_rows_fromd_by_idx)}, {"expand_frm", nerv_matrix_(lua_expand_frm)}, {"rearrange_frm", nerv_matrix_(lua_rearrange_frm)}, {"scale_rows_by_row", nerv_matrix_(lua_scale_rows_by_row)}, @@ -311,6 +326,9 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { static void cuda_matrix_(init)(lua_State *L) { luaN_append_methods(L, nerv_matrix_(extra_methods)); +#ifdef CUMATRIX_INIT + CUMATRIX_INIT(L); +#endif } int nerv_matrix_(lua_get_elem)(lua_State *L) { diff --git a/nerv/matrix/mmatrix.c b/nerv/matrix/mmatrix.c index 5561572..961059c 100644 --- a/nerv/matrix/mmatrix.c +++ b/nerv/matrix/mmatrix.c @@ -16,7 +16,30 @@ void nerv_lua_mmatrix_init(lua_State *L) { #define host_matrix_(NAME) host_matrix_float_##NAME #define nerv_matrix_(NAME) nerv_matrix_host_float_##NAME const char *nerv_matrix_(tname) = "nerv.MMatrixFloat"; +#define MMATRIX_INIT(L) host_matrix_(init_extra)(L) + +static const luaL_Reg nerv_matrix_(extra_methods_int)[]; +static void host_matrix_(init_extra)(lua_State *L) { + luaN_append_methods(L, nerv_matrix_(extra_methods_int)); +} + #include "generic/mmatrix.c" +#include "../lib/matrix/mmatrix.h" + +static int nerv_matrix_(lua_perm_gen)(lua_State *L) { + Status status; + int i, ncol = luaL_checkinteger(L, 1); + Matrix *self = nerv_matrix_(perm_gen)(ncol, &status); + NERV_LUA_CHECK_STATUS(L, status); + luaT_pushudata(L, self, nerv_matrix_(tname)); + return 1; +} + +static const luaL_Reg nerv_matrix_(extra_methods_int)[] = { + {"perm_gen", nerv_matrix_(lua_perm_gen)}, + {NULL, NULL} +}; + #undef nerv_matrix_ #undef host_matrix_ #undef MATRIX_USE_FLOAT @@ -24,6 +47,7 @@ const char *nerv_matrix_(tname) = "nerv.MMatrixFloat"; #undef MATRIX_ELEM_PTR #undef MATRIX_ELEM_FMT #undef MATRIX_ELEM_WRITE_FMT +#undef MMATRIX_INIT #define NERV_GENERIC_MMATRIX #define MATRIX_USE_DOUBLE @@ -44,26 +68,4 @@ const char *nerv_matrix_(tname) = "nerv.MMatrixDouble"; #define host_matrix_(NAME) host_matrix_int_##NAME #define nerv_matrix_(NAME) nerv_matrix_host_int_##NAME const char *nerv_matrix_(tname) = "nerv.MMatrixInt"; -#define MMATRIX_INIT(L) host_matrix_(init_extra)(L) - -static const luaL_Reg nerv_matrix_(extra_methods_int)[]; -static void host_matrix_(init_extra)(lua_State *L) { - luaN_append_methods(L, nerv_matrix_(extra_methods_int)); -} - #include "generic/mmatrix.c" -#include "../lib/matrix/mmatrix.h" - -static int nerv_matrix_(lua_perm_gen)(lua_State *L) { - Status status; - int i, ncol = luaL_checkinteger(L, 1); - Matrix *self = nerv_matrix_(perm_gen)(ncol, &status); - NERV_LUA_CHECK_STATUS(L, status); - luaT_pushudata(L, self, nerv_matrix_(tname)); - return 1; -} - -static const luaL_Reg nerv_matrix_(extra_methods_int)[] = { - {"perm_gen", nerv_matrix_(lua_perm_gen)}, - {NULL, NULL} -}; |