diff options
Diffstat (limited to 'matrix/cumatrix.c')
-rw-r--r-- | matrix/cumatrix.c | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/matrix/cumatrix.c b/matrix/cumatrix.c index 49b7fbf..aa10571 100644 --- a/matrix/cumatrix.c +++ b/matrix/cumatrix.c @@ -66,10 +66,30 @@ static int nerv_float_matrix_(sigmoid)(lua_State *L) { return 1; } -static int nerv_float_matrix_(rowsum)(lua_State *L) { +static int nerv_float_matrix_(softmax)(lua_State *L) { + Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname)); + Matrix *max = nerv_float_matrix_(new_)(a->nrow, 1); + Matrix *dno = nerv_float_matrix_(new_)(a->nrow, 1); + Matrix *b = nerv_float_matrix_(new_)(a->nrow, a->ncol); + cuda_colmax(a, max); + cuda_softmax_denominator(a, max, dno); + cuda_softmax_final(a, max, dno, b); + luaT_pushudata(L, b, nerv_float_matrix_(tname)); + return 1; +} + +static int nerv_float_matrix_(colsum)(lua_State *L) { + Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname)); + Matrix *b = nerv_float_matrix_(new_)(a->nrow, 1); + cuda_colsum(a, b); + luaT_pushudata(L, b, nerv_float_matrix_(tname)); + return 1; +} + +static int nerv_float_matrix_(colmax)(lua_State *L) { Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname)); Matrix *b = nerv_float_matrix_(new_)(a->nrow, 1); - cuda_rowsum(a, b); + cuda_colmax(a, b); luaT_pushudata(L, b, nerv_float_matrix_(tname)); return 1; } @@ -78,7 +98,9 @@ static const luaL_Reg nerv_float_matrix_(extra_methods)[] = { {"__add__", nerv_float_matrix_(add)}, {"__mul__", nerv_float_matrix_(mul)}, {"sigmoid", nerv_float_matrix_(sigmoid)}, - {"rowsum", nerv_float_matrix_(rowsum)}, + {"softmax", nerv_float_matrix_(softmax)}, + {"colsum", nerv_float_matrix_(colsum)}, + {"colmax", nerv_float_matrix_(colmax)}, {NULL, NULL} }; |