aboutsummaryrefslogtreecommitdiff
path: root/matrix/generic/mmatrix.c
diff options
context:
space:
mode:
Diffstat (limited to 'matrix/generic/mmatrix.c')
-rw-r--r--matrix/generic/mmatrix.c43
1 files changed, 38 insertions, 5 deletions
diff --git a/matrix/generic/mmatrix.c b/matrix/generic/mmatrix.c
index 6edac69..d37cd80 100644
--- a/matrix/generic/mmatrix.c
+++ b/matrix/generic/mmatrix.c
@@ -7,9 +7,11 @@
#define MATRIX_DATA_STRIDE(ncol) (sizeof(MATRIX_ELEM) * (ncol))
#define MATRIX_DATA_WRITE(data, idx, val) (data[idx] = val)
#define MATRIX_DATA_READ(data, idx) (data[idx])
+#define MATRIX_INIT(L) host_matrix_(init)(L)
#define MATRIX_BASE_TNAME nerv_matrix_host_tname
#define NERV_GENERIC_MATRIX
#include "../../common.h"
+#include "../../io/param.h"
static void host_matrix_(alloc)(MATRIX_ELEM **dptr, size_t *stride,
long width, long height) {
@@ -17,15 +19,12 @@ static void host_matrix_(alloc)(MATRIX_ELEM **dptr, size_t *stride,
*stride = width;
}
-static const luaL_Reg nerv_matrix_(extra_methods)[] = {
-};
-
int nerv_matrix_(get_elem)(lua_State *L) {
Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
int idx = luaL_checkinteger(L, 2);
if (idx < 0 || idx >= self->nmax)
nerv_error(L, "index must be within range [0, %d)", self->nmax);
- lua_pushnumber(L, self->data.f[idx]);
+ lua_pushnumber(L, MATRIX_ELEM_PTR(self)[idx]);
return 1;
}
@@ -35,9 +34,43 @@ int nerv_matrix_(set_elem)(lua_State *L) {
MATRIX_ELEM v = luaL_checknumber(L, 3);
if (idx < 0 || idx >= self->nmax)
nerv_error(L, "index must be within range [0, %d)", self->nmax);
- self->data.f[idx] = v;
+ MATRIX_ELEM_PTR(self)[idx] = v;
return 0;
}
+static const luaL_Reg nerv_matrix_(extra_methods)[];
+static void host_matrix_(init)(lua_State *L) {
+ luaN_append_methods(L, nerv_matrix_(extra_methods));
+}
+
#include "matrix.c"
+
+int nerv_matrix_(load)(lua_State *L) {
+ ParamChunkData *chunk = luaT_checkudata(L, 1, nerv_param_chunk_data_tname);
+ Matrix *self;
+ int i, j;
+ long nrow, ncol;
+ FILE *fp = chunk->fp;
+ if (fscanf(fp, "%ld %ld", &nrow, &ncol) != 2)
+ return 0;
+ self = nerv_matrix_(new_)(nrow, ncol);
+ for (i = 0; i < nrow; i++)
+ {
+ MATRIX_ELEM *row = MATRIX_ROW_PTR(self, i);
+ for (j = 0; j < ncol; j++)
+ if (fscanf(fp, MATRIX_ELEM_FMT, row + j) != 1)
+ {
+ free(self);
+ return 0;
+ }
+ }
+ luaT_pushudata(L, self, nerv_matrix_(tname));
+ return 1;
+}
+
+static const luaL_Reg nerv_matrix_(extra_methods)[] = {
+ {"load", nerv_matrix_(load)},
+ {NULL, NULL}
+};
+
#endif