From a309ce5e33b22030bcac348c63576187676abee3 Mon Sep 17 00:00:00 2001 From: Determinant Date: Mon, 1 Jun 2015 17:37:13 +0800 Subject: add expand_frm, rearrange_frm, scale_row --- matrix/generic/cumatrix.c | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) (limited to 'matrix/generic/cumatrix.c') diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c index 3bc58d7..58f3679 100644 --- a/matrix/generic/cumatrix.c +++ b/matrix/generic/cumatrix.c @@ -282,6 +282,40 @@ static int nerv_matrix_(copy_rows_fromh_by_idx)(lua_State *L) { return 0; } +static int nerv_matrix_(expand_frm)(lua_State *L) { + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); + int context = luaL_checkinteger(L, 3); + if (a->nrow != b->nrow) + nerv_error(L, "mismatching number of frames"); + if (a->ncol != b->ncol * (context * 2 + 1)) + nerv_error(L, "the width should be 2 * context + 1"); + cudak_(cuda_expand_frm)(b, a, context); + return 0; +} + +static int nerv_matrix_(rearrange_frm)(lua_State *L) { + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); + int step = luaL_checkinteger(L, 3); + CHECK_SAME_DIMENSION(a, b); + if (b->ncol % step) + nerv_error(L, "the dimension of columns is not divisible by step"); + cudak_(cuda_rearrange_frm)(b, a, step); + return 0; +} + +static int nerv_matrix_(scale_row)(lua_State *L) { + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); + if (a->ncol != b->ncol) + nerv_error(L, "the number of columns is not the same"); + if (b->nrow != 1) + nerv_error(L, "a row vector is expected"); + cudak_(cuda_scale_row)(b, a); + return 0; +} + static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"create", nerv_matrix_(create)}, {"colsum", nerv_matrix_(colsum)}, @@ -303,6 +337,9 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"mul_elem", nerv_matrix_(mul_elem)}, {"log_elem", nerv_matrix_(log_elem)}, {"copy_rows_fromh_by_idx", nerv_matrix_(copy_rows_fromh_by_idx)}, + {"expand_frm", nerv_matrix_(expand_frm)}, + {"rearrange_frm", nerv_matrix_(rearrange_frm)}, + {"scale_row", nerv_matrix_(scale_row)}, {NULL, NULL} }; -- cgit v1.2.3