about summary refs log tree commit diff
path: root/matrix/generic
diff options
context:
space:
mode:
Diffstat (limited to 'matrix/generic')
-rw-r--r-- matrix/generic/cumatrix.c 35
-rw-r--r-- matrix/generic/matrix.c 10
2 files changed, 36 insertions, 9 deletions
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c
index aa303d4..8de6c1b 100644
--- a/matrix/generic/cumatrix.c
+++ b/matrix/generic/cumatrix.c
@@ -43,8 +43,7 @@ static int nerv_matrix_(add)(lua_State *L) {
if (!(a->nrow == b->nrow && a->ncol == b->ncol))
nerv_error(L, "Matrices should be of the same dimension");
nerv_matrix_(add_)(a, b, c, alpha, beta);
- luaT_pushudata(L, c, nerv_matrix_(tname));
- return 1;
+ return 0;
}
static int nerv_matrix_(get_cublas_op)(char ch) {
@@ -52,6 +51,9 @@ static int nerv_matrix_(get_cublas_op)(char ch) {
}
static int nerv_matrix_(mul)(lua_State *L) {
+#define SWAP(a, b) \
+ do { int t = (a); (a) = (b); (b) = t; } while (0)
+
Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
@@ -62,23 +64,26 @@ static int nerv_matrix_(mul)(lua_State *L) {
: CUBLAS_OP_N;
int tb = nargs > 6 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 7)) \
: CUBLAS_OP_N;
- printf("%d %d\n", ta, tb);
- if (a->ncol != b->nrow)
+ int am = a->nrow, an = a->ncol;
+ int bm = b->nrow, bn = b->ncol;
+ if (ta == CUBLAS_OP_T) SWAP(am, an);
+ if (tb == CUBLAS_OP_T) SWAP(bm, bn);
+ if (an != bm)
nerv_error(L, "Wrong dimension of multipliers");
/* MATRIX_ELEM alpha = 1.0f, beta = 0.0f; */
NERV_CUBLAS_(gemm)(cublas_handle, tb, ta,
- b->ncol, a->nrow, b->nrow,
+ bn, am, bm,
&alpha,
MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM),
MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM),
&beta,
MATRIX_ELEM_PTR(c), c->stride / sizeof(MATRIX_ELEM));
- luaT_pushudata(L, c, nerv_matrix_(tname));
- return 1;
+ return 0;
}
static int nerv_matrix_(create)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ fprintf(stderr, "create\n");
Matrix *b = nerv_matrix_(new_)(a->nrow, a->ncol);
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
@@ -174,6 +179,21 @@ static int nerv_matrix_(copy_to)(lua_State *L) {
return 0;
}
+static int nerv_matrix_(trans)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = nerv_matrix_(new_)(a->ncol, a->nrow);
+ MATRIX_ELEM alpha = 1, beta = 0;
+ NERV_CUBLAS_(geam)(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T,
+ a->nrow, a->ncol,
+ &alpha,
+ MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM),
+ &beta,
+ MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM),
+ MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM));
+ luaT_pushudata(L, b, nerv_matrix_(tname));
+ return 1;
+}
+
static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"create", nerv_matrix_(create)},
@@ -184,6 +204,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"rowmax", nerv_matrix_(rowmax)},
{"copy_from", nerv_matrix_(copy_from)},
{"copy_to", nerv_matrix_(copy_to)},
+ {"trans", nerv_matrix_(trans)},
/* in-place calc */
{"add", nerv_matrix_(add)},
{"mul", nerv_matrix_(mul)},
diff --git a/matrix/generic/matrix.c b/matrix/generic/matrix.c
index c3838d2..74c9f19 100644
--- a/matrix/generic/matrix.c
+++ b/matrix/generic/matrix.c
@@ -9,8 +9,14 @@ extern const char *nerv_matrix_(tname);
extern const char *MATRIX_BASE_TNAME;
void nerv_matrix_(data_free)(Matrix *self) {
+ assert(*self->data_ref > 0);
if (--(*self->data_ref) == 0)
+ {
+ /* free matrix data */
MATRIX_DATA_FREE(MATRIX_ELEM_PTR(self));
+ free(self->data_ref);
+ free(self);
+ }
}
void nerv_matrix_(data_retain)(Matrix *self) {
@@ -40,7 +46,7 @@ int nerv_matrix_(new)(lua_State *L) {
int nerv_matrix_(destroy)(lua_State *L) {
Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
nerv_matrix_(data_free)(self);
- return 0;
+ return 1;
}
int nerv_matrix_(get_elem)(lua_State *L);
@@ -54,7 +60,7 @@ static Matrix *nerv_matrix_(getrow)(Matrix *self, int row) {
prow->nmax = prow->ncol;
MATRIX_ELEM_PTR(prow) = MATRIX_ROW_PTR(self, row);
prow->data_ref = self->data_ref;
- nerv_matrix_(data_retain)(self);
+ nerv_matrix_(data_retain)(prow);
return prow;
}