about summary refs log tree commit diff
path: root/matrix/generic/cumatrix.c
diff options
context:
space:
mode:
Diffstat (limited to 'matrix/generic/cumatrix.c')
-rw-r--r-- matrix/generic/cumatrix.c 25
1 files changed, 20 insertions, 5 deletions
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c
index ae57b21..aa303d4 100644
--- a/matrix/generic/cumatrix.c
+++ b/matrix/generic/cumatrix.c
@@ -55,15 +55,17 @@ static int nerv_matrix_(mul)(lua_State *L) {
Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
+ MATRIX_ELEM alpha = luaL_checknumber(L, 4);
+ MATRIX_ELEM beta = luaL_checknumber(L, 5);
int nargs = lua_gettop(L);
- int ta = nargs > 3 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 4)) \
+ int ta = nargs > 5 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 6)) \
: CUBLAS_OP_N;
- int tb = nargs > 4 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 5)) \
+ int tb = nargs > 6 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 7)) \
: CUBLAS_OP_N;
printf("%d %d\n", ta, tb);
if (a->ncol != b->nrow)
nerv_error(L, "Wrong dimension of multipliers");
- MATRIX_ELEM alpha = 1.0f, beta = 0.0f;
+/* MATRIX_ELEM alpha = 1.0f, beta = 0.0f; */
NERV_CUBLAS_(gemm)(cublas_handle, tb, ta,
b->ncol, a->nrow, b->nrow,
&alpha,
@@ -131,10 +133,22 @@ static int nerv_matrix_(add_row)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
double beta = luaL_checknumber(L, 3);
+ if (a->ncol != b->ncol)
+ nerv_error(L, "the number of columns is not the same");
+ if (a->nrow != 1)
+ nerv_error(L, "a row vector is expected");
cudak_(cuda_add_row)(a, b, beta);
return 0;
}
+static int nerv_matrix_(fill)(lua_State *L) {
+ Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ double val = luaL_checknumber(L, 2);
+ cudak_(cuda_fill)(self, val);
+ return 0;
+}
+
+
extern const char *MATRIX_CUMATRIX_HOST_TNAME;
static int nerv_matrix_(copy_from)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
@@ -162,8 +176,6 @@ static int nerv_matrix_(copy_to)(lua_State *L) {
static const luaL_Reg nerv_matrix_(extra_methods)[] = {
- {"add", nerv_matrix_(add)},
- {"mul", nerv_matrix_(mul)},
{"create", nerv_matrix_(create)},
{"sigmoid", nerv_matrix_(sigmoid)},
{"softmax", nerv_matrix_(softmax)},
@@ -173,7 +185,10 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"copy_from", nerv_matrix_(copy_from)},
{"copy_to", nerv_matrix_(copy_to)},
/* in-place calc */
+ {"add", nerv_matrix_(add)},
+ {"mul", nerv_matrix_(mul)},
{"add_row", nerv_matrix_(add_row)},
+ {"fill", nerv_matrix_(fill)},
{NULL, NULL}
};