aboutsummaryrefslogtreecommitdiff
path: root/matrix
diff options
context:
space:
mode:
Diffstat (limited to 'matrix')
-rw-r--r--matrix/generic/cumatrix.c19
-rw-r--r--matrix/init.lua6
2 files changed, 20 insertions, 5 deletions
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c
index 90c6d6a..d98c559 100644
--- a/matrix/generic/cumatrix.c
+++ b/matrix/generic/cumatrix.c
@@ -47,15 +47,24 @@ static int nerv_matrix_(add)(lua_State *L) {
return 1;
}
+static int nerv_matrix_(get_cublas_op)(char ch) {
+ return (ch == 'T' || ch == 't') ? CUBLAS_OP_T : CUBLAS_OP_N;
+}
+
static int nerv_matrix_(mul)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- Matrix *c;
+ Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
+ int nargs = lua_gettop(L);
+ int ta = nargs > 3 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 4)) \
+ : CUBLAS_OP_N;
+ int tb = nargs > 4 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 5)) \
+ : CUBLAS_OP_N;
+ printf("%d %d\n", ta, tb);
if (a->ncol != b->nrow)
nerv_error(L, "Wrong dimension of multipliers");
- c = nerv_matrix_(new_)(a->nrow, b->ncol);
MATRIX_ELEM alpha = 1.0f, beta = 0.0f;
- NERV_CUBLAS_(gemm)(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
+ NERV_CUBLAS_(gemm)(cublas_handle, tb, ta,
b->ncol, a->nrow, b->nrow,
&alpha,
MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM),
diff --git a/matrix/init.lua b/matrix/init.lua
index c33cf7e..8f626dc 100644
--- a/matrix/init.lua
+++ b/matrix/init.lua
@@ -32,3 +32,9 @@ function nerv.CuMatrix:__sub__(b)
c:add(self, b, 1.0, -1.0)
return c
end
+
+function nerv.CuMatrix:__mul__(b)
+ c = self:create()
+ c:mul(self, b, 'N', 'N')
+ return c
+end