diff options
author | cloudygoose <[email protected]> | 2015-06-02 11:14:18 +0800 |
---|---|---|
committer | cloudygoose <[email protected]> | 2015-06-02 11:14:18 +0800 |
commit | 2bb2076fe12deb3bf6a38bd2e192dca06c7736e0 (patch) | |
tree | 8d0ae95e5474eb70d86a000ccc0b38df017af3e0 /matrix/generic/cumatrix.c | |
parent | 5e7fcdf4e5be450927764254d492d87349e4114e (diff) | |
parent | d0a3e02d1a25a681ac78fd66aedf63f96636f6d2 (diff) |
...
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'matrix/generic/cumatrix.c')
-rw-r--r-- | matrix/generic/cumatrix.c | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c index 3bc58d7..58f3679 100644 --- a/matrix/generic/cumatrix.c +++ b/matrix/generic/cumatrix.c @@ -282,6 +282,40 @@ static int nerv_matrix_(copy_rows_fromh_by_idx)(lua_State *L) { return 0; } +static int nerv_matrix_(expand_frm)(lua_State *L) { + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); + int context = luaL_checkinteger(L, 3); + if (a->nrow != b->nrow) + nerv_error(L, "mismatching number of frames"); + if (a->ncol != b->ncol * (context * 2 + 1)) + nerv_error(L, "the width should be 2 * context + 1"); + cudak_(cuda_expand_frm)(b, a, context); + return 0; +} + +static int nerv_matrix_(rearrange_frm)(lua_State *L) { + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); + int step = luaL_checkinteger(L, 3); + CHECK_SAME_DIMENSION(a, b); + if (b->ncol % step) + nerv_error(L, "the dimension of columns is not divisible by step"); + cudak_(cuda_rearrange_frm)(b, a, step); + return 0; +} + +static int nerv_matrix_(scale_row)(lua_State *L) { + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); + if (a->ncol != b->ncol) + nerv_error(L, "the number of columns is not the same"); + if (b->nrow != 1) + nerv_error(L, "a row vector is expected"); + cudak_(cuda_scale_row)(b, a); + return 0; +} + static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"create", nerv_matrix_(create)}, {"colsum", nerv_matrix_(colsum)}, @@ -303,6 +337,9 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"mul_elem", nerv_matrix_(mul_elem)}, {"log_elem", nerv_matrix_(log_elem)}, {"copy_rows_fromh_by_idx", nerv_matrix_(copy_rows_fromh_by_idx)}, + {"expand_frm", nerv_matrix_(expand_frm)}, + {"rearrange_frm", nerv_matrix_(rearrange_frm)}, + {"scale_row", nerv_matrix_(scale_row)}, {NULL, NULL} }; |