summaryrefslogtreecommitdiff
path: root/matrix/generic/cumatrix.c
diff options
context:
space:
mode:
Diffstat (limited to 'matrix/generic/cumatrix.c')
-rw-r--r--matrix/generic/cumatrix.c37
1 files changed, 37 insertions, 0 deletions
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c
index 3bc58d7..58f3679 100644
--- a/matrix/generic/cumatrix.c
+++ b/matrix/generic/cumatrix.c
@@ -282,6 +282,40 @@ static int nerv_matrix_(copy_rows_fromh_by_idx)(lua_State *L) {
return 0;
}
+static int nerv_matrix_(expand_frm)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ int context = luaL_checkinteger(L, 3);
+ if (a->nrow != b->nrow)
+ nerv_error(L, "mismatching number of frames");
+ if (a->ncol != b->ncol * (context * 2 + 1))
+ nerv_error(L, "the width should be 2 * context + 1");
+ cudak_(cuda_expand_frm)(b, a, context);
+ return 0;
+}
+
+static int nerv_matrix_(rearrange_frm)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ int step = luaL_checkinteger(L, 3);
+ CHECK_SAME_DIMENSION(a, b);
+ if (b->ncol % step)
+ nerv_error(L, "the dimension of columns is not divisible by step");
+ cudak_(cuda_rearrange_frm)(b, a, step);
+ return 0;
+}
+
+static int nerv_matrix_(scale_row)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ if (a->ncol != b->ncol)
+ nerv_error(L, "the number of columns is not the same");
+ if (b->nrow != 1)
+ nerv_error(L, "a row vector is expected");
+ cudak_(cuda_scale_row)(b, a);
+ return 0;
+}
+
static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"create", nerv_matrix_(create)},
{"colsum", nerv_matrix_(colsum)},
@@ -303,6 +337,9 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"mul_elem", nerv_matrix_(mul_elem)},
{"log_elem", nerv_matrix_(log_elem)},
{"copy_rows_fromh_by_idx", nerv_matrix_(copy_rows_fromh_by_idx)},
+ {"expand_frm", nerv_matrix_(expand_frm)},
+ {"rearrange_frm", nerv_matrix_(rearrange_frm)},
+ {"scale_row", nerv_matrix_(scale_row)},
{NULL, NULL}
};