author    Determinant <ted.sybil@gmail.com>  2015-05-18 23:34:08 +0800
committer Determinant <ted.sybil@gmail.com>  2015-05-18 23:34:08 +0800
commit    186cf4f39e1c753a6056101f654d2939f812d285
tree      da6bbc97c8f553ad473c820484b879a4b6a00968  /matrix/cumatrix.c
parent    23fd2694723ab3f2203e6cd040c5e6633cb989c7
add softmax for cumatrix
Diffstat (limited to 'matrix/cumatrix.c')
-rw-r--r--  matrix/cumatrix.c  28
1 file changed, 25 insertions(+), 3 deletions(-)
diff --git a/matrix/cumatrix.c b/matrix/cumatrix.c
index 49b7fbf..aa10571 100644
--- a/matrix/cumatrix.c
+++ b/matrix/cumatrix.c
@@ -66,10 +66,30 @@ static int nerv_float_matrix_(sigmoid)(lua_State *L) {
return 1;
}
-static int nerv_float_matrix_(rowsum)(lua_State *L) {
+static int nerv_float_matrix_(softmax)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
+ Matrix *max = nerv_float_matrix_(new_)(a->nrow, 1);
+ Matrix *dno = nerv_float_matrix_(new_)(a->nrow, 1);
+ Matrix *b = nerv_float_matrix_(new_)(a->nrow, a->ncol);
+ cuda_colmax(a, max);
+ cuda_softmax_denominator(a, max, dno);
+ cuda_softmax_final(a, max, dno, b);
+ luaT_pushudata(L, b, nerv_float_matrix_(tname));
+ return 1;
+}
+
+static int nerv_float_matrix_(colsum)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
+ Matrix *b = nerv_float_matrix_(new_)(a->nrow, 1);
+ cuda_colsum(a, b);
+ luaT_pushudata(L, b, nerv_float_matrix_(tname));
+ return 1;
+}
+
+static int nerv_float_matrix_(colmax)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
Matrix *b = nerv_float_matrix_(new_)(a->nrow, 1);
- cuda_rowsum(a, b);
+ cuda_colmax(a, b);
luaT_pushudata(L, b, nerv_float_matrix_(tname));
return 1;
}
@@ -78,7 +98,9 @@ static const luaL_Reg nerv_float_matrix_(extra_methods)[] = {
{"__add__", nerv_float_matrix_(add)},
{"__mul__", nerv_float_matrix_(mul)},
{"sigmoid", nerv_float_matrix_(sigmoid)},
- {"rowsum", nerv_float_matrix_(rowsum)},
+ {"softmax", nerv_float_matrix_(softmax)},
+ {"colsum", nerv_float_matrix_(colsum)},
+ {"colmax", nerv_float_matrix_(colmax)},
{NULL, NULL}
};
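
The cuda_colmax, cuda_softmax_denominator, and cuda_softmax_final kernels invoked above are declared elsewhere and are not part of this diff. As a rough host-side sketch only, the three calls appear to implement the standard numerically stable softmax: take each row's maximum, accumulate exp(a_ij - max_i) into a per-row denominator, then normalize. The SketchMatrix layout below is a hypothetical row-major stand-in for the real cumatrix struct, not NERV's actual API.

#include <math.h>
#include <stddef.h>

/* Hypothetical host-side stand-in for the device Matrix; row-major. */
typedef struct { size_t nrow, ncol; float *data; } SketchMatrix;

/* Step 1: per-row maximum (what cuda_colmax(a, max) fills into nrow x 1). */
static void sketch_colmax(const SketchMatrix *a, SketchMatrix *max) {
    for (size_t i = 0; i < a->nrow; i++) {
        float m = a->data[i * a->ncol];
        for (size_t j = 1; j < a->ncol; j++)
            if (a->data[i * a->ncol + j] > m)
                m = a->data[i * a->ncol + j];
        max->data[i] = m;
    }
}

/* Step 2: per-row denominator dno_i = sum_j exp(a_ij - max_i). */
static void sketch_softmax_denominator(const SketchMatrix *a,
                                       const SketchMatrix *max,
                                       SketchMatrix *dno) {
    for (size_t i = 0; i < a->nrow; i++) {
        float s = 0.0f;
        for (size_t j = 0; j < a->ncol; j++)
            s += expf(a->data[i * a->ncol + j] - max->data[i]);
        dno->data[i] = s;
    }
}

/* Step 3: b_ij = exp(a_ij - max_i) / dno_i, so each row of b sums to 1. */
static void sketch_softmax_final(const SketchMatrix *a,
                                 const SketchMatrix *max,
                                 const SketchMatrix *dno,
                                 SketchMatrix *b) {
    for (size_t i = 0; i < a->nrow; i++)
        for (size_t j = 0; j < a->ncol; j++)
            b->data[i * b->ncol + j] =
                expf(a->data[i * a->ncol + j] - max->data[i]) / dno->data[i];
}

Subtracting the row maximum before exponentiation keeps expf within range for large activations; the shift cancels between numerator and denominator, so the result is unchanged. Note the naming: despite being called colsum and colmax, both bindings reduce across the columns of each row and return an nrow x 1 vector, matching the new_(a->nrow, 1) allocations in the diff.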