summaryrefslogtreecommitdiff
path: root/matrix/generic/cumatrix.c
diff options
context:
space:
mode:
Diffstat (limited to 'matrix/generic/cumatrix.c')
-rw-r--r--matrix/generic/cumatrix.c27
1 files changed, 24 insertions, 3 deletions
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c
index 2deb7a3..ed64bbf 100644
--- a/matrix/generic/cumatrix.c
+++ b/matrix/generic/cumatrix.c
@@ -48,6 +48,7 @@ static int nerv_matrix_(add)(lua_State *L) {
MATRIX_ELEM alpha = luaL_checknumber(L, 4); /* alpha */
MATRIX_ELEM beta = luaL_checknumber(L, 5); /* alpha */
CHECK_SAME_DIMENSION(a, b);
+ CHECK_SAME_DIMENSION(a, c);
nerv_matrix_(add_)(a, b, c, alpha, beta);
return 0;
}
@@ -118,6 +119,7 @@ static int nerv_matrix_(softmax)(lua_State *L) {
Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *max = nerv_matrix_(new_)(a->nrow, 1);
Matrix *dno = nerv_matrix_(new_)(a->nrow, 1);
+ CHECK_SAME_DIMENSION(a, b);
cudak_(cuda_rowmax)(a, max);
cudak_(cuda_softmax_denominator)(a, max, dno);
cudak_(cuda_softmax_final)(a, max, dno, b);
@@ -230,25 +232,44 @@ static int nerv_matrix_(trans)(lua_State *L) {
return 1;
}
+static int nerv_matrix_(mul_elem)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
+ Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ CHECK_SAME_DIMENSION(a, b);
+ CHECK_SAME_DIMENSION(a, c);
+ cudak_(cuda_mul_elem)(a, b, c);
+ return 0;
+}
+
+static int nerv_matrix_(log_elem)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ CHECK_SAME_DIMENSION(a, b);
+ cudak_(cuda_log_elem)(a, b);
+ return 0;
+}
static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"create", nerv_matrix_(create)},
- {"softmax", nerv_matrix_(softmax)},
{"colsum", nerv_matrix_(colsum)},
{"rowsum", nerv_matrix_(rowsum)},
{"rowmax", nerv_matrix_(rowmax)},
+ {"trans", nerv_matrix_(trans)},
+ /* in-place calc */
{"copy_fromh", nerv_matrix_(copy_fromh)},
{"copy_fromd", nerv_matrix_(copy_fromd)},
{"copy_toh", nerv_matrix_(copy_toh)},
{"copy_tod", nerv_matrix_(copy_tod)},
- {"trans", nerv_matrix_(trans)},
- /* in-place calc */
{"add", nerv_matrix_(add)},
{"mul", nerv_matrix_(mul)},
{"add_row", nerv_matrix_(add_row)},
{"fill", nerv_matrix_(fill)},
{"sigmoid", nerv_matrix_(sigmoid)},
{"sigmoid_grad", nerv_matrix_(sigmoid_grad)},
+ {"softmax", nerv_matrix_(softmax)},
+ {"mul_elem", nerv_matrix_(mul_elem)},
+ {"log_elem", nerv_matrix_(log_elem)},
{NULL, NULL}
};