diff options
Diffstat (limited to 'matrix/generic/cumatrix.c')
-rw-r--r-- | matrix/generic/cumatrix.c | 35 |
1 files changed, 28 insertions, 7 deletions
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c index aa303d4..8de6c1b 100644 --- a/matrix/generic/cumatrix.c +++ b/matrix/generic/cumatrix.c @@ -43,8 +43,7 @@ static int nerv_matrix_(add)(lua_State *L) { if (!(a->nrow == b->nrow && a->ncol == b->ncol)) nerv_error(L, "Matrices should be of the same dimension"); nerv_matrix_(add_)(a, b, c, alpha, beta); - luaT_pushudata(L, c, nerv_matrix_(tname)); - return 1; + return 0; } static int nerv_matrix_(get_cublas_op)(char ch) { @@ -52,6 +51,9 @@ static int nerv_matrix_(get_cublas_op)(char ch) { } static int nerv_matrix_(mul)(lua_State *L) { +#define SWAP(a, b) \ + do { int t = (a); (a) = (b); (b) = t; } while (0) + Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname)); Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname)); @@ -62,23 +64,26 @@ static int nerv_matrix_(mul)(lua_State *L) { : CUBLAS_OP_N; int tb = nargs > 6 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 7)) \ : CUBLAS_OP_N; - printf("%d %d\n", ta, tb); - if (a->ncol != b->nrow) + int am = a->nrow, an = a->ncol; + int bm = b->nrow, bn = b->ncol; + if (ta == CUBLAS_OP_T) SWAP(am, an); + if (tb == CUBLAS_OP_T) SWAP(bm, bn); + if (an != bm) nerv_error(L, "Wrong dimension of multipliers"); /* MATRIX_ELEM alpha = 1.0f, beta = 0.0f; */ NERV_CUBLAS_(gemm)(cublas_handle, tb, ta, - b->ncol, a->nrow, b->nrow, + bn, am, bm, &alpha, MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM), MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM), &beta, MATRIX_ELEM_PTR(c), c->stride / sizeof(MATRIX_ELEM)); - luaT_pushudata(L, c, nerv_matrix_(tname)); - return 1; + return 0; } static int nerv_matrix_(create)(lua_State *L) { Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + fprintf(stderr, "create\n"); Matrix *b = nerv_matrix_(new_)(a->nrow, a->ncol); luaT_pushudata(L, b, nerv_matrix_(tname)); return 1; @@ -174,6 +179,21 @@ static int nerv_matrix_(copy_to)(lua_State *L) { return 0; } +static int nerv_matrix_(trans)(lua_State *L) { + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *b = nerv_matrix_(new_)(a->ncol, a->nrow); + MATRIX_ELEM alpha = 1, beta = 0; + NERV_CUBLAS_(geam)(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, + a->nrow, a->ncol, + &alpha, + MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM), + &beta, + MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM), + MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM)); + luaT_pushudata(L, b, nerv_matrix_(tname)); + return 1; +} + static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"create", nerv_matrix_(create)}, @@ -184,6 +204,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"rowmax", nerv_matrix_(rowmax)}, {"copy_from", nerv_matrix_(copy_from)}, {"copy_to", nerv_matrix_(copy_to)}, + {"trans", nerv_matrix_(trans)}, /* in-place calc */ {"add", nerv_matrix_(add)}, {"mul", nerv_matrix_(mul)}, |