summaryrefslogtreecommitdiff
path: root/matrix
diff options
context:
space:
mode:
Diffstat (limited to 'matrix')
-rw-r--r--matrix/cukernel.cu1
-rw-r--r--matrix/cumatrix.c1
-rw-r--r--matrix/generic/elem_type.h2
-rw-r--r--matrix/generic/matrix.c6
-rw-r--r--matrix/generic/mmatrix.c43
-rw-r--r--matrix/mmatrix.c1
6 files changed, 47 insertions, 7 deletions
diff --git a/matrix/cukernel.cu b/matrix/cukernel.cu
index e71ae49..fbac369 100644
--- a/matrix/cukernel.cu
+++ b/matrix/cukernel.cu
@@ -8,6 +8,7 @@
#undef MATRIX_USE_FLOAT
#undef MATRIX_ELEM
#undef MATRIX_ELEM_PTR
+#undef MATRIX_ELEM_FMT
#define cudak_(NAME) cudak_double_ ## NAME
#define MATRIX_USE_DOUBLE
diff --git a/matrix/cumatrix.c b/matrix/cumatrix.c
index 838183a..db4c784 100644
--- a/matrix/cumatrix.c
+++ b/matrix/cumatrix.c
@@ -14,6 +14,7 @@ const char *nerv_matrix_(tname) = "nerv.CuMatrixFloat";
#undef MATRIX_USE_FLOAT
#undef MATRIX_ELEM
#undef MATRIX_ELEM_PTR
+#undef MATRIX_ELEM_FMT
#define MATRIX_USE_DOUBLE
#define cuda_matrix_(NAME) cuda_matrix_double_##NAME
diff --git a/matrix/generic/elem_type.h b/matrix/generic/elem_type.h
index 8f80306..78233a3 100644
--- a/matrix/generic/elem_type.h
+++ b/matrix/generic/elem_type.h
@@ -1,11 +1,13 @@
#ifdef MATRIX_USE_FLOAT
#define MATRIX_ELEM float
+#define MATRIX_ELEM_FMT "%f"
#define MATRIX_ELEM_PTR(self) ((self)->data.f)
#elif defined(MATRIX_USE_DOUBLE)
#define MATRIX_ELEM double
+#define MATRIX_ELEM_FMT "%lf"
#define MATRIX_ELEM_PTR(self) ((self)->data.d)
#endif
diff --git a/matrix/generic/matrix.c b/matrix/generic/matrix.c
index b06ed89..417c534 100644
--- a/matrix/generic/matrix.c
+++ b/matrix/generic/matrix.c
@@ -2,6 +2,9 @@
#include "../../common.h"
#include "matrix.h"
+#define MATRIX_ROW_PTR(self, row) \
+ (MATRIX_ELEM *)((char *)MATRIX_ELEM_PTR(self) + (row) * (self)->stride)
+
extern const char *nerv_matrix_(tname);
extern const char *MATRIX_BASE_TNAME;
@@ -49,8 +52,7 @@ static Matrix *nerv_matrix_(getrow)(Matrix *self, int row) {
prow->nrow = 1;
prow->stride = self->stride;
prow->nmax = prow->ncol;
- MATRIX_ELEM_PTR(prow) = \
- (MATRIX_ELEM *)((char *)MATRIX_ELEM_PTR(self) + row * self->stride);
+ MATRIX_ELEM_PTR(prow) = MATRIX_ROW_PTR(self, row);
prow->data_ref = self->data_ref;
nerv_matrix_(data_retain)(self);
return prow;
diff --git a/matrix/generic/mmatrix.c b/matrix/generic/mmatrix.c
index 6edac69..d37cd80 100644
--- a/matrix/generic/mmatrix.c
+++ b/matrix/generic/mmatrix.c
@@ -7,9 +7,11 @@
#define MATRIX_DATA_STRIDE(ncol) (sizeof(MATRIX_ELEM) * (ncol))
#define MATRIX_DATA_WRITE(data, idx, val) (data[idx] = val)
#define MATRIX_DATA_READ(data, idx) (data[idx])
+#define MATRIX_INIT(L) host_matrix_(init)(L)
#define MATRIX_BASE_TNAME nerv_matrix_host_tname
#define NERV_GENERIC_MATRIX
#include "../../common.h"
+#include "../../io/param.h"
static void host_matrix_(alloc)(MATRIX_ELEM **dptr, size_t *stride,
long width, long height) {
@@ -17,15 +19,12 @@ static void host_matrix_(alloc)(MATRIX_ELEM **dptr, size_t *stride,
*stride = width;
}
-static const luaL_Reg nerv_matrix_(extra_methods)[] = {
-};
-
int nerv_matrix_(get_elem)(lua_State *L) {
Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
int idx = luaL_checkinteger(L, 2);
if (idx < 0 || idx >= self->nmax)
nerv_error(L, "index must be within range [0, %d)", self->nmax);
- lua_pushnumber(L, self->data.f[idx]);
+ lua_pushnumber(L, MATRIX_ELEM_PTR(self)[idx]);
return 1;
}
@@ -35,9 +34,43 @@ int nerv_matrix_(set_elem)(lua_State *L) {
MATRIX_ELEM v = luaL_checknumber(L, 3);
if (idx < 0 || idx >= self->nmax)
nerv_error(L, "index must be within range [0, %d)", self->nmax);
- self->data.f[idx] = v;
+ MATRIX_ELEM_PTR(self)[idx] = v;
return 0;
}
+static const luaL_Reg nerv_matrix_(extra_methods)[];
+static void host_matrix_(init)(lua_State *L) {
+ luaN_append_methods(L, nerv_matrix_(extra_methods));
+}
+
#include "matrix.c"
+
+int nerv_matrix_(load)(lua_State *L) {
+ ParamChunkData *chunk = luaT_checkudata(L, 1, nerv_param_chunk_data_tname);
+ Matrix *self;
+ int i, j;
+ long nrow, ncol;
+ FILE *fp = chunk->fp;
+ if (fscanf(fp, "%ld %ld", &nrow, &ncol) != 2)
+ return 0;
+ self = nerv_matrix_(new_)(nrow, ncol);
+ for (i = 0; i < nrow; i++)
+ {
+ MATRIX_ELEM *row = MATRIX_ROW_PTR(self, i);
+ for (j = 0; j < ncol; j++)
+ if (fscanf(fp, MATRIX_ELEM_FMT, row + j) != 1)
+ {
+ free(self);
+ return 0;
+ }
+ }
+ luaT_pushudata(L, self, nerv_matrix_(tname));
+ return 1;
+}
+
+static const luaL_Reg nerv_matrix_(extra_methods)[] = {
+ {"load", nerv_matrix_(load)},
+ {NULL, NULL}
+};
+
#endif
diff --git a/matrix/mmatrix.c b/matrix/mmatrix.c
index ffb02ac..b7d7dae 100644
--- a/matrix/mmatrix.c
+++ b/matrix/mmatrix.c
@@ -9,6 +9,7 @@ const char *nerv_matrix_(tname) = "nerv.MMatrixFloat";
#undef MATRIX_USE_FLOAT
#undef MATRIX_ELEM
#undef MATRIX_ELEM_PTR
+#undef MATRIX_ELEM_FMT
#define NERV_GENERIC_MMATRIX
#define MATRIX_USE_DOUBLE