aboutsummaryrefslogtreecommitdiff
path: root/matrix/generic/cumatrix.c
diff options
context:
space:
mode:
Diffstat (limited to 'matrix/generic/cumatrix.c')
-rw-r--r--matrix/generic/cumatrix.c24
1 files changed, 19 insertions, 5 deletions
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c
index a8e18e0..b5d1a35 100644
--- a/matrix/generic/cumatrix.c
+++ b/matrix/generic/cumatrix.c
@@ -37,8 +37,8 @@ static int nerv_matrix_(add)(lua_State *L) {
Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
- MATRIX_ELEM alpha = luaL_checknumber(L, 4); /* alpha */
- MATRIX_ELEM beta = luaL_checknumber(L, 5); /* alpha */
+ MATRIX_ELEM alpha = luaL_checknumber(L, 4);
+ MATRIX_ELEM beta = luaL_checknumber(L, 5);
CHECK_SAME_DIMENSION(a, b);
CHECK_SAME_DIMENSION(a, c);
nerv_matrix_(add_)(L, a, b, c, alpha, beta);
@@ -396,7 +396,20 @@ static int nerv_matrix_(rearrange_frm)(lua_State *L) {
return 0;
}
-static int nerv_matrix_(scale_row)(lua_State *L) {
+static int nerv_matrix_(scale_rows_by_col)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ if (a->nrow != b->nrow)
+ nerv_error(L, "the number of rows is not the same");
+ if (b->ncol != 1)
+ nerv_error(L, "a column vector is expected");
+ PROFILE_START
+ cudak_(cuda_scale_rows_by_col)(b, a);
+ PROFILE_STOP
+ return 0;
+}
+
+static int nerv_matrix_(scale_rows_by_row)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
if (a->ncol != b->ncol)
@@ -404,7 +417,7 @@ static int nerv_matrix_(scale_row)(lua_State *L) {
if (b->nrow != 1)
nerv_error(L, "a row vector is expected");
PROFILE_START
- cudak_(cuda_scale_row)(b, a);
+ cudak_(cuda_scale_rows_by_row)(b, a);
PROFILE_STOP
return 0;
}
@@ -434,7 +447,8 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"copy_rows_fromh_by_idx", nerv_matrix_(copy_rows_fromh_by_idx)},
{"expand_frm", nerv_matrix_(expand_frm)},
{"rearrange_frm", nerv_matrix_(rearrange_frm)},
- {"scale_row", nerv_matrix_(scale_row)},
+ {"scale_rows_by_row", nerv_matrix_(scale_rows_by_row)},
+ {"scale_rows_by_col", nerv_matrix_(scale_rows_by_col)},
{NULL, NULL}
};