aboutsummaryrefslogtreecommitdiff
path: root/matrix/cumatrix.c
diff options
context:
space:
mode:
Diffstat (limited to 'matrix/cumatrix.c')
-rw-r--r--matrix/cumatrix.c28
1 files changed, 25 insertions, 3 deletions
diff --git a/matrix/cumatrix.c b/matrix/cumatrix.c
index 49b7fbf..aa10571 100644
--- a/matrix/cumatrix.c
+++ b/matrix/cumatrix.c
@@ -66,10 +66,30 @@ static int nerv_float_matrix_(sigmoid)(lua_State *L) {
return 1;
}
-static int nerv_float_matrix_(rowsum)(lua_State *L) {
+static int nerv_float_matrix_(softmax)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
+ Matrix *max = nerv_float_matrix_(new_)(a->nrow, 1);
+ Matrix *dno = nerv_float_matrix_(new_)(a->nrow, 1);
+ Matrix *b = nerv_float_matrix_(new_)(a->nrow, a->ncol);
+ cuda_colmax(a, max);
+ cuda_softmax_denominator(a, max, dno);
+ cuda_softmax_final(a, max, dno, b);
+ luaT_pushudata(L, b, nerv_float_matrix_(tname));
+ return 1;
+}
+
+static int nerv_float_matrix_(colsum)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
+ Matrix *b = nerv_float_matrix_(new_)(a->nrow, 1);
+ cuda_colsum(a, b);
+ luaT_pushudata(L, b, nerv_float_matrix_(tname));
+ return 1;
+}
+
+static int nerv_float_matrix_(colmax)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
Matrix *b = nerv_float_matrix_(new_)(a->nrow, 1);
- cuda_rowsum(a, b);
+ cuda_colmax(a, b);
luaT_pushudata(L, b, nerv_float_matrix_(tname));
return 1;
}
@@ -78,7 +98,9 @@ static const luaL_Reg nerv_float_matrix_(extra_methods)[] = {
{"__add__", nerv_float_matrix_(add)},
{"__mul__", nerv_float_matrix_(mul)},
{"sigmoid", nerv_float_matrix_(sigmoid)},
- {"rowsum", nerv_float_matrix_(rowsum)},
+ {"softmax", nerv_float_matrix_(softmax)},
+ {"colsum", nerv_float_matrix_(colsum)},
+ {"colmax", nerv_float_matrix_(colmax)},
{NULL, NULL}
};