diff options
Diffstat (limited to 'matrix/generic/cumatrix.c')
-rw-r--r-- | matrix/generic/cumatrix.c | 25 |
1 files changed, 20 insertions, 5 deletions
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c index ae57b21..aa303d4 100644 --- a/matrix/generic/cumatrix.c +++ b/matrix/generic/cumatrix.c @@ -55,15 +55,17 @@ static int nerv_matrix_(mul)(lua_State *L) { Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname)); Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname)); + MATRIX_ELEM alpha = luaL_checknumber(L, 4); + MATRIX_ELEM beta = luaL_checknumber(L, 5); int nargs = lua_gettop(L); - int ta = nargs > 3 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 4)) \ + int ta = nargs > 5 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 6)) \ : CUBLAS_OP_N; - int tb = nargs > 4 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 5)) \ + int tb = nargs > 6 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 7)) \ : CUBLAS_OP_N; printf("%d %d\n", ta, tb); if (a->ncol != b->nrow) nerv_error(L, "Wrong dimension of multipliers"); - MATRIX_ELEM alpha = 1.0f, beta = 0.0f; +/* MATRIX_ELEM alpha = 1.0f, beta = 0.0f; */ NERV_CUBLAS_(gemm)(cublas_handle, tb, ta, b->ncol, a->nrow, b->nrow, &alpha, @@ -131,10 +133,22 @@ static int nerv_matrix_(add_row)(lua_State *L) { Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname)); double beta = luaL_checknumber(L, 3); + if (a->ncol != b->ncol) + nerv_error(L, "the number of columns is not the same"); + if (a->nrow != 1) + nerv_error(L, "a row vector is expected"); cudak_(cuda_add_row)(a, b, beta); return 0; } +static int nerv_matrix_(fill)(lua_State *L) { + Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname)); + double val = luaL_checknumber(L, 2); + cudak_(cuda_fill)(self, val); + return 0; +} + + extern const char *MATRIX_CUMATRIX_HOST_TNAME; static int nerv_matrix_(copy_from)(lua_State *L) { Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); @@ -162,8 +176,6 @@ static int nerv_matrix_(copy_to)(lua_State *L) { static const luaL_Reg nerv_matrix_(extra_methods)[] = { - {"add", nerv_matrix_(add)}, - {"mul", nerv_matrix_(mul)}, {"create", nerv_matrix_(create)}, {"sigmoid", nerv_matrix_(sigmoid)}, {"softmax", nerv_matrix_(softmax)}, @@ -173,7 +185,10 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"copy_from", nerv_matrix_(copy_from)}, {"copy_to", nerv_matrix_(copy_to)}, /* in-place calc */ + {"add", nerv_matrix_(add)}, + {"mul", nerv_matrix_(mul)}, {"add_row", nerv_matrix_(add_row)}, + {"fill", nerv_matrix_(fill)}, {NULL, NULL} }; |