diff options
-rw-r--r-- | matrix.c | 93 | ||||
-rw-r--r-- | matrix.lua | 10 | ||||
-rw-r--r-- | matrix_example.lua | 9 |
3 files changed, 99 insertions, 13 deletions
@@ -1,32 +1,44 @@ #include "common.h" typedef struct Matrix { - long stride; /* size of a row */ + long stride; /* size of a row */ long ncol, nrow, nmax; /* dimension of the matrix */ union { float *f; double *d; - } storage; /* pointer to actual storage */ + } data; /* pointer to actual storage */ + long *data_ref; } Matrix; const char *float_matrix_tname = "nerv.FloatMatrix"; const char *matrix_tname = "nerv.Matrix"; +void float_matrix_data_free(Matrix *self) { + if (--(*self->data_ref) == 0) + free(self->data.f); +} + +void float_matrix_data_retain(Matrix *self) { + (*self->data_ref)++; +} + int float_matrix_new(lua_State *L) { Matrix *self = (Matrix *)malloc(sizeof(Matrix)); self->nrow = luaL_checkinteger(L, 1); self->ncol = luaL_checkinteger(L, 2); self->nmax = self->nrow * self->ncol; - self->stride = sizeof(float) * self->nrow; - self->storage.f = (float *)malloc(self->stride * self->ncol); + self->stride = sizeof(float) * self->ncol; + self->data.f = (float *)malloc(self->stride * self->nrow); + self->data_ref = (long *)malloc(sizeof(long)); + *self->data_ref = 0; + float_matrix_data_retain(self); luaT_pushudata(L, self, float_matrix_tname); return 1; } int float_matrix_destroy(lua_State *L) { Matrix *self = luaT_checkudata(L, 1, float_matrix_tname); - free(self->storage.f); - fprintf(stderr, "[debug] destroyted\n"); + float_matrix_data_free(self); return 0; } @@ -35,7 +47,7 @@ int nerv_float_matrix_get_elem(lua_State *L) { int idx = luaL_checkinteger(L, 2); if (idx < 0 || idx >= self->nmax) nerv_error(L, "index must be within range [0, %d)", self->nmax); - lua_pushnumber(L, self->storage.f[idx]); + lua_pushnumber(L, self->data.f[idx]); return 1; } @@ -46,10 +58,73 @@ int nerv_float_matrix_set_elem(lua_State *L) { long upper = self->nrow * self->ncol; if (idx < 0 || idx >= self->nmax) nerv_error(L, "index must be within range [0, %d)", self->nmax); - self->storage.f[idx] = v; + self->data.f[idx] = v; return 0; } +static Matrix *nerv_float_matrix_getrow(Matrix *self, int row) { + Matrix *prow = (Matrix *)malloc(sizeof(Matrix)); + prow->ncol = self->ncol; + prow->nrow = 1; + prow->stride = self->stride; + prow->nmax = prow->ncol; + prow->data.f = (float *)((char *)self->data.f + row * self->stride); + prow->data_ref = self->data_ref; + float_matrix_data_retain(self); + return prow; +} + +static int nerv_float_matrix_newindex(lua_State *L) { + Matrix *self = luaT_checkudata(L, 1, float_matrix_tname); + if (lua_isnumber(L, 2)) + { + int idx = luaL_checkinteger(L, 2); + if (self->nrow == 1) + { + if (idx < 0 || idx >= self->ncol) + nerv_error(L, "index must be within range [0, %d)", self->ncol); + self->data.f[idx] = luaL_checknumber(L, 3); + } + else + nerv_error(L, "cannot assign a scalar to row vector"); + lua_pushboolean(L, 1); + return 2; + } + else + { + lua_pushboolean(L, 0); + return 1; + } +} + + +static int nerv_float_matrix_index(lua_State *L) { + Matrix *self = luaT_checkudata(L, 1, float_matrix_tname); + if (lua_isnumber(L, 2)) + { + int idx = luaL_checkinteger(L, 2); + if (self->nrow == 1) + { + if (idx < 0 || idx >= self->ncol) + nerv_error(L, "index must be within range [0, %d)", self->ncol); + lua_pushnumber(L, self->data.f[idx]); + } + else + { + if (idx < 0 || idx >= self->nrow) + nerv_error(L, "index must be within range [0, %d)", self->nrow); + luaT_pushudata(L, nerv_float_matrix_getrow(self, idx), float_matrix_tname); + } + lua_pushboolean(L, 1); + return 2; + } + else + { + lua_pushboolean(L, 0); + return 1; + } +} + static int nerv_float_matrix_ncol(lua_State *L) { Matrix *self = luaT_checkudata(L, 1, float_matrix_tname); lua_pushinteger(L, self->ncol); @@ -68,6 +143,8 @@ static const luaL_Reg float_matrix_methods[] = { {"set_elem", nerv_float_matrix_set_elem}, {"ncol", nerv_float_matrix_ncol}, {"nrow", nerv_float_matrix_nrow}, + {"__index__", nerv_float_matrix_index}, + {"__newindex__", nerv_float_matrix_newindex}, {NULL, NULL} }; @@ -2,14 +2,14 @@ function nerv.FloatMatrix:__tostring__() local ncol = self:ncol() local nrow = self:nrow() local i = 0 - local res = "" + local strt = {} for row = 0, nrow - 1 do for col = 0, ncol - 1 do - res = res .. string.format("%f ", self:get_elem(i)) + table.insert(strt, string.format("%f ", self:get_elem(i))) i = i + 1 end - res = res .. "\n" + table.insert(strt, "\n") end - res = res .. string.format("[Float Matrix %d x %d]", nrow, ncol) - return res + table.insert(strt, string.format("[Float Matrix %d x %d]", nrow, ncol)) + return table.concat(strt) end diff --git a/matrix_example.lua b/matrix_example.lua index 1ff129d..361e126 100644 --- a/matrix_example.lua +++ b/matrix_example.lua @@ -4,4 +4,13 @@ t:set_elem(1, 3.23432); print(t:get_elem(1)) print(t) t = nerv.FloatMatrix(10, 20) +t:set_elem(1, 3.34); print(t) +a = t[1] +for i = 0, 9 do + for j = 0, 19 do + t[i][j] = i + j + end +end +print(t) +print(a) |