From 0b128e097d425418499ab2257c5448f14fec3215 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 22 May 2015 12:13:37 +0800 Subject: ... --- cumatrix_example.lua | 7 +++++-- matrix/generic/cumatrix.c | 19 ++++++++++++++----- matrix/init.lua | 6 ++++++ 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/cumatrix_example.lua b/cumatrix_example.lua index 26e1dc4..f8235eb 100644 --- a/cumatrix_example.lua +++ b/cumatrix_example.lua @@ -10,11 +10,14 @@ for i = 0, m - 1 do dm[i][j] = t end end --- print(fm) +print(fm) fs = fm:softmax() -- print(fs) --- print(dm) +print(dm) ds = dm:softmax() -- print(ds) +print(fs) print(fs + fs) +print(ds + ds) print(fs - fs) +print(ds - ds) diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c index 90c6d6a..d98c559 100644 --- a/matrix/generic/cumatrix.c +++ b/matrix/generic/cumatrix.c @@ -47,15 +47,24 @@ static int nerv_matrix_(add)(lua_State *L) { return 1; } +static int nerv_matrix_(get_cublas_op)(char ch) { + return (ch == 'T' || ch == 't') ? CUBLAS_OP_T : CUBLAS_OP_N; +} + static int nerv_matrix_(mul)(lua_State *L) { - Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); - Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); - Matrix *c; + Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); + Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname)); + int nargs = lua_gettop(L); + int ta = nargs > 3 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 4)) \ + : CUBLAS_OP_N; + int tb = nargs > 4 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 5)) \ + : CUBLAS_OP_N; + printf("%d %d\n", ta, tb); if (a->ncol != b->nrow) nerv_error(L, "Wrong dimension of multipliers"); - c = nerv_matrix_(new_)(a->nrow, b->ncol); MATRIX_ELEM alpha = 1.0f, beta = 0.0f; - NERV_CUBLAS_(gemm)(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, + NERV_CUBLAS_(gemm)(cublas_handle, tb, ta, b->ncol, a->nrow, b->nrow, &alpha, MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM), diff --git a/matrix/init.lua b/matrix/init.lua index c33cf7e..8f626dc 100644 --- a/matrix/init.lua +++ b/matrix/init.lua @@ -32,3 +32,9 @@ function nerv.CuMatrix:__sub__(b) c:add(self, b, 1.0, -1.0) return c end + +function nerv.CuMatrix:__mul__(b) + c = self:create() + c:mul(self, b, 'N', 'N') + return c +end -- cgit v1.2.3-70-g09d2