diff options
author | Determinant <ted.sybil@gmail.com> | 2015-08-28 13:21:52 +0800 |
---|---|---|
committer | Determinant <ted.sybil@gmail.com> | 2015-08-28 13:21:52 +0800 |
commit | 1a9f63e351582f54fec7817927168cb1dbb0c1d6 (patch) | |
tree | c340b648c60d93b956be5956fa03233383e37e5d /nerv/matrix/generic/cumatrix.c | |
parent | 8bf9c7575ffeeabb3924e9e02a35afe187071fe2 (diff) |
support gpu buffering
Diffstat (limited to 'nerv/matrix/generic/cumatrix.c')
-rw-r--r-- | nerv/matrix/generic/cumatrix.c | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/nerv/matrix/generic/cumatrix.c b/nerv/matrix/generic/cumatrix.c index ab7f7c4..08cb4c2 100644 --- a/nerv/matrix/generic/cumatrix.c +++ b/nerv/matrix/generic/cumatrix.c @@ -228,12 +228,12 @@ static int nerv_matrix_(lua_decompress)(lua_State *L) { return 1; } -extern const char *nerv_matrix_host_int_tname; +extern const char *nerv_matrix_host_float_tname; static int nerv_matrix_(lua_copy_rows_fromh_by_idx)(lua_State *L) { Status status; Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); const Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME); - const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_host_int_tname); + const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_host_float_tname); long nrow = a->nrow; int b_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0; nerv_matrix_(copy_rows_fromh_by_idx)(a, b, idx, b_begin, &status); @@ -241,6 +241,18 @@ static int nerv_matrix_(lua_copy_rows_fromh_by_idx)(lua_State *L) { return 0; } +static int nerv_matrix_(lua_copy_rows_fromd_by_idx)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); + const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname)); + long nrow = a->nrow; + int b_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0; + nerv_matrix_(copy_rows_fromd_by_idx)(a, b, idx, b_begin, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + static int nerv_matrix_(lua_expand_frm)(lua_State *L) { Status status; Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); @@ -290,6 +302,8 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { /* in-place calc */ {"copy_fromh", nerv_matrix_(lua_copy_fromh)}, {"copy_fromd", nerv_matrix_(lua_copy_fromd)}, + /* alias for copy_fromd */ + {"copy_from", nerv_matrix_(lua_copy_fromd)}, {"copy_toh", nerv_matrix_(lua_copy_toh)}, {"add", nerv_matrix_(lua_add)}, {"mul", nerv_matrix_(lua_mul)}, @@ -302,6 +316,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"mul_elem", nerv_matrix_(lua_mul_elem)}, {"log_elem", nerv_matrix_(lua_log_elem)}, {"copy_rows_fromh_by_idx", nerv_matrix_(lua_copy_rows_fromh_by_idx)}, + {"copy_rows_fromd_by_idx", nerv_matrix_(lua_copy_rows_fromd_by_idx)}, {"expand_frm", nerv_matrix_(lua_expand_frm)}, {"rearrange_frm", nerv_matrix_(lua_rearrange_frm)}, {"scale_rows_by_row", nerv_matrix_(lua_scale_rows_by_row)}, @@ -311,6 +326,9 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { static void cuda_matrix_(init)(lua_State *L) { luaN_append_methods(L, nerv_matrix_(extra_methods)); +#ifdef CUMATRIX_INIT + CUMATRIX_INIT(L); +#endif } int nerv_matrix_(lua_get_elem)(lua_State *L) { |