summaryrefslogtreecommitdiff
path: root/matrix/generic/cumatrix.c
diff options
context:
space:
mode:
Diffstat (limited to 'matrix/generic/cumatrix.c')
-rw-r--r--matrix/generic/cumatrix.c35
1 files changed, 28 insertions, 7 deletions
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c
index aa303d4..8de6c1b 100644
--- a/matrix/generic/cumatrix.c
+++ b/matrix/generic/cumatrix.c
@@ -43,8 +43,7 @@ static int nerv_matrix_(add)(lua_State *L) {
if (!(a->nrow == b->nrow && a->ncol == b->ncol))
nerv_error(L, "Matrices should be of the same dimension");
nerv_matrix_(add_)(a, b, c, alpha, beta);
- luaT_pushudata(L, c, nerv_matrix_(tname));
- return 1;
+ return 0;
}
static int nerv_matrix_(get_cublas_op)(char ch) {
@@ -52,6 +51,9 @@ static int nerv_matrix_(get_cublas_op)(char ch) {
}
static int nerv_matrix_(mul)(lua_State *L) {
+#define SWAP(a, b) \
+ do { int t = (a); (a) = (b); (b) = t; } while (0)
+
Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
@@ -62,23 +64,26 @@ static int nerv_matrix_(mul)(lua_State *L) {
: CUBLAS_OP_N;
int tb = nargs > 6 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 7)) \
: CUBLAS_OP_N;
- printf("%d %d\n", ta, tb);
- if (a->ncol != b->nrow)
+ int am = a->nrow, an = a->ncol;
+ int bm = b->nrow, bn = b->ncol;
+ if (ta == CUBLAS_OP_T) SWAP(am, an);
+ if (tb == CUBLAS_OP_T) SWAP(bm, bn);
+ if (an != bm)
nerv_error(L, "Wrong dimension of multipliers");
/* MATRIX_ELEM alpha = 1.0f, beta = 0.0f; */
NERV_CUBLAS_(gemm)(cublas_handle, tb, ta,
- b->ncol, a->nrow, b->nrow,
+ bn, am, bm,
&alpha,
MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM),
MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM),
&beta,
MATRIX_ELEM_PTR(c), c->stride / sizeof(MATRIX_ELEM));
- luaT_pushudata(L, c, nerv_matrix_(tname));
- return 1;
+ return 0;
}
static int nerv_matrix_(create)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ fprintf(stderr, "create\n");
Matrix *b = nerv_matrix_(new_)(a->nrow, a->ncol);
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
@@ -174,6 +179,21 @@ static int nerv_matrix_(copy_to)(lua_State *L) {
return 0;
}
+static int nerv_matrix_(trans)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = nerv_matrix_(new_)(a->ncol, a->nrow);
+ MATRIX_ELEM alpha = 1, beta = 0;
+ NERV_CUBLAS_(geam)(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T,
+ a->nrow, a->ncol,
+ &alpha,
+ MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM),
+ &beta,
+ MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM),
+ MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM));
+ luaT_pushudata(L, b, nerv_matrix_(tname));
+ return 1;
+}
+
static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"create", nerv_matrix_(create)},
@@ -184,6 +204,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"rowmax", nerv_matrix_(rowmax)},
{"copy_from", nerv_matrix_(copy_from)},
{"copy_to", nerv_matrix_(copy_to)},
+ {"trans", nerv_matrix_(trans)},
/* in-place calc */
{"add", nerv_matrix_(add)},
{"mul", nerv_matrix_(mul)},