diff options
author | Determinant <[email protected]> | 2015-05-15 02:36:55 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2015-05-15 02:36:55 +0800 |
commit | efb786d716363dde8f90ef0672f479790befc79c (patch) | |
tree | 1d7a6f07db9cae3fd65a437d3ca6ff398a2b684b /matrix | |
parent | b03471e2b0d604806773b540551cd047979b7b3b (diff) |
use C macro to implement matrix template
Diffstat (limited to 'matrix')
-rw-r--r-- | matrix/generic/matrix.c | 132 | ||||
-rw-r--r-- | matrix/generic/matrix.h | 9 | ||||
-rw-r--r-- | matrix/init.c | 19 | ||||
-rw-r--r-- | matrix/matrix.c | 26 | ||||
-rw-r--r-- | matrix/matrix.lua | 15 |
5 files changed, 201 insertions, 0 deletions
diff --git a/matrix/generic/matrix.c b/matrix/generic/matrix.c new file mode 100644 index 0000000..3bb12ef --- /dev/null +++ b/matrix/generic/matrix.c @@ -0,0 +1,132 @@ +#ifdef MATRIX_GENERIC +#include "../../common.h" +#include "matrix.h" + +extern const char *nerv_matrix_tname; +const char *nerv_float_matrix_tname = "nerv.FloatMatrix"; + +void nerv_float_matrix_(data_free)(Matrix *self) { + if (--(*self->data_ref) == 0) + MATRIX_DATA_FREE(self->data.f); +} + +void nerv_float_matrix_(data_retain)(Matrix *self) { + (*self->data_ref)++; +} + +int nerv_float_matrix_(new)(lua_State *L) { + Matrix *self = (Matrix *)malloc(sizeof(Matrix)); + self->nrow = luaL_checkinteger(L, 1); + self->ncol = luaL_checkinteger(L, 2); + self->nmax = self->nrow * self->ncol; + self->stride = MATRIX_DATA_STRIDE(self->ncol); + self->data.f = MATRIX_DATA_ALLOC(self->stride * self->nrow); + self->data_ref = (long *)malloc(sizeof(long)); + *self->data_ref = 0; + nerv_float_matrix_(data_retain)(self); + luaT_pushudata(L, self, nerv_float_matrix_tname); + return 1; +} + +int nerv_float_matrix_(destroy)(lua_State *L) { + Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_tname); + nerv_float_matrix_(data_free)(self); + return 0; +} + +int nerv_float_matrix_(get_elem)(lua_State *L); +int nerv_float_matrix_(set_elem)(lua_State *L); + +static Matrix *nerv_float_matrix_(getrow)(Matrix *self, int row) { + Matrix *prow = (Matrix *)malloc(sizeof(Matrix)); + prow->ncol = self->ncol; + prow->nrow = 1; + prow->stride = self->stride; + prow->nmax = prow->ncol; + prow->data.f = (float *)((char *)self->data.f + row * self->stride); + prow->data_ref = self->data_ref; + nerv_float_matrix_(data_retain)(self); + return prow; +} + +static int nerv_float_matrix_(newindex)(lua_State *L) { + Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_tname); + if (lua_isnumber(L, 2)) + { + int idx = luaL_checkinteger(L, 2); + if (self->nrow == 1) + { + if (idx < 0 || idx >= self->ncol) + nerv_error(L, "index must be within range [0, %d)", self->ncol); + self->data.f[idx] = luaL_checknumber(L, 3); + } + else + nerv_error(L, "cannot assign a scalar to row vector"); + lua_pushboolean(L, 1); + return 2; + } + else + { + lua_pushboolean(L, 0); + return 1; + } +} + + +static int nerv_float_matrix_(index)(lua_State *L) { + Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_tname); + if (lua_isnumber(L, 2)) + { + int idx = luaL_checkinteger(L, 2); + if (self->nrow == 1) + { + if (idx < 0 || idx >= self->ncol) + nerv_error(L, "index must be within range [0, %d)", self->ncol); + lua_pushnumber(L, self->data.f[idx]); + } + else + { + if (idx < 0 || idx >= self->nrow) + nerv_error(L, "index must be within range [0, %d)", self->nrow); + luaT_pushudata(L, nerv_float_matrix_(getrow)(self, idx), nerv_float_matrix_tname); + } + lua_pushboolean(L, 1); + return 2; + } + else + { + lua_pushboolean(L, 0); + return 1; + } +} + +static int nerv_float_matrix_(ncol)(lua_State *L) { + Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_tname); + lua_pushinteger(L, self->ncol); + return 1; +} + +static int nerv_float_matrix_(nrow)(lua_State *L) { + Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_tname); + lua_pushinteger(L, self->nrow); + return 1; +} + + +static const luaL_Reg nerv_float_matrix_(methods)[] = { + {"get_elem", nerv_float_matrix_(get_elem)}, + {"set_elem", nerv_float_matrix_(set_elem)}, + {"ncol", nerv_float_matrix_(ncol)}, + {"nrow", nerv_float_matrix_(nrow)}, + {"__index__", nerv_float_matrix_(index)}, + {"__newindex__", nerv_float_matrix_(newindex)}, + {NULL, NULL} +}; + +void nerv_float_matrix_(init)(lua_State *L) { + luaT_newmetatable(L, nerv_float_matrix_tname, nerv_matrix_tname, + nerv_float_matrix_(new), nerv_float_matrix_(destroy), NULL); + luaL_register(L, NULL, nerv_float_matrix_(methods)); + lua_pop(L, 1); +} +#endif diff --git a/matrix/generic/matrix.h b/matrix/generic/matrix.h new file mode 100644 index 0000000..d02b56e --- /dev/null +++ b/matrix/generic/matrix.h @@ -0,0 +1,9 @@ +typedef struct Matrix { + long stride; /* size of a row */ + long ncol, nrow, nmax; /* dimension of the matrix */ + union { + float *f; + double *d; + } data; /* pointer to actual storage */ + long *data_ref; +} Matrix; diff --git a/matrix/init.c b/matrix/init.c new file mode 100644 index 0000000..e251628 --- /dev/null +++ b/matrix/init.c @@ -0,0 +1,19 @@ +#include "../common.h" +#include "generic/matrix.h" + +const char *nerv_matrix_tname = "nerv.Matrix"; +static const luaL_Reg matrix_methods[] = { + {"__tostring__", nerv_error_method_not_implemented }, + {"__add__", nerv_error_method_not_implemented }, + {"__sub__", nerv_error_method_not_implemented }, + {"__mul__", nerv_error_method_not_implemented }, + {NULL, NULL} +}; + +void nerv_matrix_init(lua_State *L) { + /* abstract class */ + luaT_newmetatable(L, nerv_matrix_tname, NULL, NULL, NULL, NULL); + luaL_register(L, NULL, matrix_methods); + lua_pop(L, 1); + nerv_float_matrix_host_init(L); +} diff --git a/matrix/matrix.c b/matrix/matrix.c new file mode 100644 index 0000000..3a593e5 --- /dev/null +++ b/matrix/matrix.c @@ -0,0 +1,26 @@ +#define MATRIX_DATA_FREE(ptr) free(ptr) +#define MATRIX_DATA_ALLOC(size) malloc(size) +#define MATRIX_DATA_STRIDE(ncol) (sizeof(float) * (ncol)) +#define MATRIX_GENERIC +#define nerv_float_matrix_(NAME) nerv_float_matrix_host_ ## NAME +#include "generic/matrix.c" + +int nerv_float_matrix_(get_elem)(lua_State *L) { + Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_tname); + int idx = luaL_checkinteger(L, 2); + if (idx < 0 || idx >= self->nmax) + nerv_error(L, "index must be within range [0, %d)", self->nmax); + lua_pushnumber(L, self->data.f[idx]); + return 1; +} + +int nerv_float_matrix_(set_elem)(lua_State *L) { + Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_tname); + int idx = luaL_checkinteger(L, 2); + float v = luaL_checknumber(L, 3); + long upper = self->nrow * self->ncol; + if (idx < 0 || idx >= self->nmax) + nerv_error(L, "index must be within range [0, %d)", self->nmax); + self->data.f[idx] = v; + return 0; +} diff --git a/matrix/matrix.lua b/matrix/matrix.lua new file mode 100644 index 0000000..7aa1f12 --- /dev/null +++ b/matrix/matrix.lua @@ -0,0 +1,15 @@ +function nerv.FloatMatrix:__tostring__() + local ncol = self:ncol() + local nrow = self:nrow() + local i = 0 + local strt = {} + for row = 0, nrow - 1 do + for col = 0, ncol - 1 do + table.insert(strt, string.format("%f ", self:get_elem(i))) + i = i + 1 + end + table.insert(strt, "\n") + end + table.insert(strt, string.format("[Float Matrix %d x %d]", nrow, ncol)) + return table.concat(strt) +end |