diff options
Diffstat (limited to 'matrix/generic')
-rw-r--r-- | matrix/generic/cumatrix.c | 35
-rw-r--r-- | matrix/generic/matrix.c   | 10
2 files changed, 36 insertions, 9 deletions
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c index aa303d4..8de6c1b 100644 --- a/matrix/generic/cumatrix.c +++ b/matrix/generic/cumatrix.c @@ -43,8 +43,7 @@ static int nerv_matrix_(add)(lua_State *L) { if (!(a->nrow == b->nrow && a->ncol == b->ncol)) nerv_error(L, "Matrices should be of the same dimension"); nerv_matrix_(add_)(a, b, c, alpha, beta); - luaT_pushudata(L, c, nerv_matrix_(tname)); - return 1; + return 0; } static int nerv_matrix_(get_cublas_op)(char ch) { @@ -52,6 +51,9 @@ static int nerv_matrix_(get_cublas_op)(char ch) { } static int nerv_matrix_(mul)(lua_State *L) { +#define SWAP(a, b) \ + do { int t = (a); (a) = (b); (b) = t; } while (0) + Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname)); Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname)); @@ -62,23 +64,26 @@ static int nerv_matrix_(mul)(lua_State *L) { : CUBLAS_OP_N; int tb = nargs > 6 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 7)) \ : CUBLAS_OP_N; - printf("%d %d\n", ta, tb); - if (a->ncol != b->nrow) + int am = a->nrow, an = a->ncol; + int bm = b->nrow, bn = b->ncol; + if (ta == CUBLAS_OP_T) SWAP(am, an); + if (tb == CUBLAS_OP_T) SWAP(bm, bn); + if (an != bm) nerv_error(L, "Wrong dimension of multipliers"); /* MATRIX_ELEM alpha = 1.0f, beta = 0.0f; */ NERV_CUBLAS_(gemm)(cublas_handle, tb, ta, - b->ncol, a->nrow, b->nrow, + bn, am, bm, &alpha, MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM), MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM), &beta, MATRIX_ELEM_PTR(c), c->stride / sizeof(MATRIX_ELEM)); - luaT_pushudata(L, c, nerv_matrix_(tname)); - return 1; + return 0; } static int nerv_matrix_(create)(lua_State *L) { Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + fprintf(stderr, "create\n"); Matrix *b = nerv_matrix_(new_)(a->nrow, a->ncol); luaT_pushudata(L, b, nerv_matrix_(tname)); return 1; @@ -174,6 +179,21 @@ static int nerv_matrix_(copy_to)(lua_State *L) { return 0; } +static int 
nerv_matrix_(trans)(lua_State *L) { + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *b = nerv_matrix_(new_)(a->ncol, a->nrow); + MATRIX_ELEM alpha = 1, beta = 0; + NERV_CUBLAS_(geam)(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, + a->nrow, a->ncol, + &alpha, + MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM), + &beta, + MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM), + MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM)); + luaT_pushudata(L, b, nerv_matrix_(tname)); + return 1; +} + static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"create", nerv_matrix_(create)}, @@ -184,6 +204,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"rowmax", nerv_matrix_(rowmax)}, {"copy_from", nerv_matrix_(copy_from)}, {"copy_to", nerv_matrix_(copy_to)}, + {"trans", nerv_matrix_(trans)}, /* in-place calc */ {"add", nerv_matrix_(add)}, {"mul", nerv_matrix_(mul)}, diff --git a/matrix/generic/matrix.c b/matrix/generic/matrix.c index c3838d2..74c9f19 100644 --- a/matrix/generic/matrix.c +++ b/matrix/generic/matrix.c @@ -9,8 +9,14 @@ extern const char *nerv_matrix_(tname); extern const char *MATRIX_BASE_TNAME; void nerv_matrix_(data_free)(Matrix *self) { + assert(*self->data_ref > 0); if (--(*self->data_ref) == 0) + { + /* free matrix data */ MATRIX_DATA_FREE(MATRIX_ELEM_PTR(self)); + free(self->data_ref); + free(self); + } } void nerv_matrix_(data_retain)(Matrix *self) { @@ -40,7 +46,7 @@ int nerv_matrix_(new)(lua_State *L) { int nerv_matrix_(destroy)(lua_State *L) { Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname)); nerv_matrix_(data_free)(self); - return 0; + return 1; } int nerv_matrix_(get_elem)(lua_State *L); @@ -54,7 +60,7 @@ static Matrix *nerv_matrix_(getrow)(Matrix *self, int row) { prow->nmax = prow->ncol; MATRIX_ELEM_PTR(prow) = MATRIX_ROW_PTR(self, row); prow->data_ref = self->data_ref; - nerv_matrix_(data_retain)(self); + nerv_matrix_(data_retain)(prow); return prow; } |