From 3362020a6bc43766a92882abe6d127c8bb98a628 Mon Sep 17 00:00:00 2001 From: Determinant Date: Mon, 15 Feb 2016 15:04:13 +0800 Subject: try a basic merge --- nerv/matrix/generic/matrix.c | 213 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) (limited to 'nerv/matrix/generic/matrix.c') diff --git a/nerv/matrix/generic/matrix.c b/nerv/matrix/generic/matrix.c index 8efe608..c1da774 100644 --- a/nerv/matrix/generic/matrix.c +++ b/nerv/matrix/generic/matrix.c @@ -125,4 +125,217 @@ void nerv_matrix_(lua_init)(lua_State *L) { #endif lua_pop(L, 1); } + +static int nerv_matrix_(lua_add)(lua_State *L) { + Status status; + Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname)); + const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); + const Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname)); + MATRIX_ELEM alpha = luaL_checknumber(L, 4); + MATRIX_ELEM beta = luaL_checknumber(L, 5); + nerv_matrix_(add)(c, a, b, alpha, beta, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + +static int nerv_matrix_(lua_mul)(lua_State *L) { + Status status; + Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); + Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname)); + MATRIX_ELEM alpha = luaL_checknumber(L, 4); + MATRIX_ELEM beta = luaL_checknumber(L, 5); + int nargs = lua_gettop(L); + int ta = nargs > 5 ? nerv_matrix_(lua_get_blas_op)(*luaL_checkstring(L, 6)) \ + : BLAS_OP_N; + int tb = nargs > 6 ? nerv_matrix_(lua_get_blas_op)(*luaL_checkstring(L, 7)) \ + : BLAS_OP_N; + nerv_matrix_(mul)(c, a, b, alpha, beta, ta, tb, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + +static int nerv_matrix_(lua_sigmoid)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); + nerv_matrix_(sigmoid)(a, b, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + +static int nerv_matrix_(lua_sigmoid_grad)(lua_State *L) { + Status status; + Matrix *nerr = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *err = luaT_checkudata(L, 2, nerv_matrix_(tname)); + Matrix *output = luaT_checkudata(L, 3, nerv_matrix_(tname)); + nerv_matrix_(sigmoid_grad)(nerr, err, output, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + +static int nerv_matrix_(lua_softmax)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); + Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *max_idx = nerv_matrix_(softmax)(b, a, &status); + NERV_LUA_CHECK_STATUS(L, status); + luaT_pushudata(L, max_idx, nerv_matrix_(tname)); + return 1; +} + +static int nerv_matrix_(lua_rowsum)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *b = nerv_matrix_(rowsum)(a, &status); + NERV_LUA_CHECK_STATUS(L, status); + luaT_pushudata(L, b, nerv_matrix_(tname)); + return 1; +} + +static int nerv_matrix_(lua_colsum)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *b = nerv_matrix_(colsum)(a, &status); + NERV_LUA_CHECK_STATUS(L, status); + luaT_pushudata(L, b, nerv_matrix_(tname)); + return 1; +} + +static int nerv_matrix_(lua_colsame)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + const Matrix *ref = luaT_checkudata(L, 2, nerv_matrix_(tname)); + Matrix *b = nerv_matrix_(colsame)(a, ref, &status); + NERV_LUA_CHECK_STATUS(L, status); + luaT_pushudata(L, b, nerv_matrix_(tname)); + return 1; +} + +static int nerv_matrix_(lua_rowmax)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *b = nerv_matrix_(rowmax)(a, &status); + NERV_LUA_CHECK_STATUS(L, status); + luaT_pushudata(L, b, nerv_matrix_(tname)); + return 1; +} + +static int nerv_matrix_(lua_rowmax_idx)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *b; + Matrix *idx; + nerv_matrix_(rowmax_idx)(a, &b, &idx, &status); + NERV_LUA_CHECK_STATUS(L, status); + luaT_pushudata(L, b, nerv_matrix_(tname)); + luaT_pushudata(L, idx, nerv_matrix_(tname)); + return 2; +} + +static int nerv_matrix_(lua_add_row)(lua_State *L) { + Status status; + const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); + Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname)); + double beta = luaL_checknumber(L, 3); + nerv_matrix_(add_row)(b, a, beta, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + +static int nerv_matrix_(lua_fill)(lua_State *L) { + Status status; + Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname)); + double val = luaL_checknumber(L, 2); + nerv_matrix_(fill)(self, val, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + +static int nerv_matrix_(lua_clip)(lua_State *L) { + Status status; + Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname)); + double val_1 = luaL_checknumber(L, 2); + double val_2 = luaL_checknumber(L, 3); + nerv_matrix_(clip)(self, val_1, val_2, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + +static int nerv_matrix_(lua_trans)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *b = nerv_matrix_(trans)(a, &status); + NERV_LUA_CHECK_STATUS(L, status); + luaT_pushudata(L, b, nerv_matrix_(tname)); + return 1; +} + +static int nerv_matrix_(lua_mul_elem)(lua_State *L) { + Status status; + const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); + const Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname)); + Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname)); + nerv_matrix_(mul_elem)(c, a, b, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + +static int nerv_matrix_(lua_log_elem)(lua_State *L) { + Status status; + const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); + Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname)); + nerv_matrix_(log_elem)(b, a, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + +static int nerv_matrix_(lua_decompress)(lua_State *L) { + Status status; + const Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + int orig_col = luaL_checkinteger(L, 2); + Matrix *b = nerv_matrix_(decompress)(a, orig_col, &status); + NERV_LUA_CHECK_STATUS(L, status); + luaT_pushudata(L, b, nerv_matrix_(tname)); + return 1; +} + +static int nerv_matrix_(lua_expand_frm)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); + int context = luaL_checkinteger(L, 3); + nerv_matrix_(expand_frm)(a, b, context, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + +static int nerv_matrix_(lua_rearrange_frm)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); + int step = luaL_checkinteger(L, 3); + nerv_matrix_(rearrange_frm)(a, b, step, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + +static int nerv_matrix_(lua_scale_rows_by_col)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); + nerv_matrix_(scale_rows_by_col)(a, b, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + +static int nerv_matrix_(lua_scale_rows_by_row)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); + nerv_matrix_(scale_rows_by_row)(a, b, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + #endif -- cgit v1.2.3