aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2015-05-24 19:02:09 +0800
committerDeterminant <[email protected]>2015-05-24 19:02:09 +0800
commit39e1834c449a55a44e95f2cfb6b24887fd3cec70 (patch)
tree272cbd5da6ad2208a86b2b7e6c960127b3cfecbe
parent63b529dc50ef0fc39e9279a976ab805ea9b11de7 (diff)
add nerv.class and try to let Lua class inherit from ParamFile
-rw-r--r--io/param.c47
-rw-r--r--matrix/cukernel.cu1
-rw-r--r--matrix/cumatrix.c1
-rw-r--r--matrix/generic/elem_type.h2
-rw-r--r--matrix/generic/matrix.c6
-rw-r--r--matrix/generic/mmatrix.c43
-rw-r--r--matrix/mmatrix.c1
-rw-r--r--nerv.c17
-rw-r--r--nerv.lua27
9 files changed, 130 insertions, 15 deletions
diff --git a/io/param.c b/io/param.c
index de6ba7a..dbd1438 100644
--- a/io/param.c
+++ b/io/param.c
@@ -1,5 +1,6 @@
#include <stdio.h>
#include <ctype.h>
+#include <string.h>
#include "../common.h"
#include "param.h"
@@ -57,9 +58,14 @@ const char *read_param_metadata(lua_State *L, FILE *fp, const char *fn) {
return buff;
}
-int nerv_param_file_new(lua_State *L) {
+int nerv_param_file_open_write(lua_State *L, const char *fn) {
+ FILE *fp = fopen(fn, "w");
+ if (!fp) nerv_error(L, "Error while opening param file: %s", fn);
+ lua_newtable(L);
+ return 1;
+}
- const char *fn = luaL_checkstring(L, 1);
+int nerv_param_file_open_read(lua_State *L, const char *fn) {
FILE *fp = fopen(fn, "r");
int i, status;
size_t param_len;
@@ -69,7 +75,6 @@ int nerv_param_file_new(lua_State *L) {
if (!fp) nerv_error(L, "Error while opening param file: %s", fn);
offset = ftello(fp);
lua_newtable(L);
- lua_newtable(L);
fprintf(stderr, "%d\n", (int)offset);
for (i = 0;; offset += param_len, i++)
{
@@ -106,6 +111,40 @@ int nerv_param_file_new(lua_State *L) {
return 1;
}
+int nerv_param_file___init(lua_State *L) {
+ const char *fn = luaL_checkstring(L, 2);
+ const char *mode = luaL_checkstring(L, 3);
+ int rd = 1, bin = 0;
+ size_t i, len = strlen(mode);
+ lua_pushvalue(L, 1);
+ for (i = 0; i < len; i++)
+ switch (mode[i])
+ {
+ case 'r': rd = 1; break;
+ case 'w': rd = 0; break;
+ case 'b': bin = 1; break;
+ }
+ return rd ? nerv_param_file_open_read(L, fn) : \
+ nerv_param_file_open_write(L, fn);
+}
+
+int nerv_param_file_new(lua_State *L) {
+ const char *fn = luaL_checkstring(L, 1);
+ const char *mode = luaL_checkstring(L, 2);
+ int rd = 1, bin = 0;
+ size_t i, len = strlen(mode);
+ for (i = 0; i < len; i++)
+ switch (mode[i])
+ {
+ case 'r': rd = 1; break;
+ case 'w': rd = 0; break;
+ case 'b': bin = 1; break;
+ }
+ lua_newtable(L);
+ return rd ? nerv_param_file_open_read(L, fn) : \
+ nerv_param_file_open_write(L, fn);
+}
+
int nerv_param_file_get_chunkdata(lua_State *L) {
ParamFileHandle *pfh;
ParamChunkInfo *pci;
@@ -124,7 +163,6 @@ int nerv_param_file_get_chunkdata(lua_State *L) {
luaT_pushudata(L, get_param_chunk_data(pfh->fp, pci),
nerv_param_chunk_data_tname);
- lua_setfield(L, -2, "data");
return 1;
}
@@ -152,6 +190,7 @@ static int nerv_param_chunk_data_destroy(lua_State *L) {
static const luaL_Reg nerv_param_file_methods[] = {
{"get_chunkdata", nerv_param_file_get_chunkdata},
+ {"__init", nerv_param_file___init},
{NULL, NULL}
};
diff --git a/matrix/cukernel.cu b/matrix/cukernel.cu
index e71ae49..fbac369 100644
--- a/matrix/cukernel.cu
+++ b/matrix/cukernel.cu
@@ -8,6 +8,7 @@
#undef MATRIX_USE_FLOAT
#undef MATRIX_ELEM
#undef MATRIX_ELEM_PTR
+#undef MATRIX_ELEM_FMT
#define cudak_(NAME) cudak_double_ ## NAME
#define MATRIX_USE_DOUBLE
diff --git a/matrix/cumatrix.c b/matrix/cumatrix.c
index 838183a..db4c784 100644
--- a/matrix/cumatrix.c
+++ b/matrix/cumatrix.c
@@ -14,6 +14,7 @@ const char *nerv_matrix_(tname) = "nerv.CuMatrixFloat";
#undef MATRIX_USE_FLOAT
#undef MATRIX_ELEM
#undef MATRIX_ELEM_PTR
+#undef MATRIX_ELEM_FMT
#define MATRIX_USE_DOUBLE
#define cuda_matrix_(NAME) cuda_matrix_double_##NAME
diff --git a/matrix/generic/elem_type.h b/matrix/generic/elem_type.h
index 8f80306..78233a3 100644
--- a/matrix/generic/elem_type.h
+++ b/matrix/generic/elem_type.h
@@ -1,11 +1,13 @@
#ifdef MATRIX_USE_FLOAT
#define MATRIX_ELEM float
+#define MATRIX_ELEM_FMT "%f"
#define MATRIX_ELEM_PTR(self) ((self)->data.f)
#elif defined(MATRIX_USE_DOUBLE)
#define MATRIX_ELEM double
+#define MATRIX_ELEM_FMT "%lf"
#define MATRIX_ELEM_PTR(self) ((self)->data.d)
#endif
diff --git a/matrix/generic/matrix.c b/matrix/generic/matrix.c
index b06ed89..417c534 100644
--- a/matrix/generic/matrix.c
+++ b/matrix/generic/matrix.c
@@ -2,6 +2,9 @@
#include "../../common.h"
#include "matrix.h"
+#define MATRIX_ROW_PTR(self, row) \
+ (MATRIX_ELEM *)((char *)MATRIX_ELEM_PTR(self) + (row) * (self)->stride)
+
extern const char *nerv_matrix_(tname);
extern const char *MATRIX_BASE_TNAME;
@@ -49,8 +52,7 @@ static Matrix *nerv_matrix_(getrow)(Matrix *self, int row) {
prow->nrow = 1;
prow->stride = self->stride;
prow->nmax = prow->ncol;
- MATRIX_ELEM_PTR(prow) = \
- (MATRIX_ELEM *)((char *)MATRIX_ELEM_PTR(self) + row * self->stride);
+ MATRIX_ELEM_PTR(prow) = MATRIX_ROW_PTR(self, row);
prow->data_ref = self->data_ref;
nerv_matrix_(data_retain)(self);
return prow;
diff --git a/matrix/generic/mmatrix.c b/matrix/generic/mmatrix.c
index 6edac69..d37cd80 100644
--- a/matrix/generic/mmatrix.c
+++ b/matrix/generic/mmatrix.c
@@ -7,9 +7,11 @@
#define MATRIX_DATA_STRIDE(ncol) (sizeof(MATRIX_ELEM) * (ncol))
#define MATRIX_DATA_WRITE(data, idx, val) (data[idx] = val)
#define MATRIX_DATA_READ(data, idx) (data[idx])
+#define MATRIX_INIT(L) host_matrix_(init)(L)
#define MATRIX_BASE_TNAME nerv_matrix_host_tname
#define NERV_GENERIC_MATRIX
#include "../../common.h"
+#include "../../io/param.h"
static void host_matrix_(alloc)(MATRIX_ELEM **dptr, size_t *stride,
long width, long height) {
@@ -17,15 +19,12 @@ static void host_matrix_(alloc)(MATRIX_ELEM **dptr, size_t *stride,
*stride = width;
}
-static const luaL_Reg nerv_matrix_(extra_methods)[] = {
-};
-
int nerv_matrix_(get_elem)(lua_State *L) {
Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
int idx = luaL_checkinteger(L, 2);
if (idx < 0 || idx >= self->nmax)
nerv_error(L, "index must be within range [0, %d)", self->nmax);
- lua_pushnumber(L, self->data.f[idx]);
+ lua_pushnumber(L, MATRIX_ELEM_PTR(self)[idx]);
return 1;
}
@@ -35,9 +34,43 @@ int nerv_matrix_(set_elem)(lua_State *L) {
MATRIX_ELEM v = luaL_checknumber(L, 3);
if (idx < 0 || idx >= self->nmax)
nerv_error(L, "index must be within range [0, %d)", self->nmax);
- self->data.f[idx] = v;
+ MATRIX_ELEM_PTR(self)[idx] = v;
return 0;
}
+static const luaL_Reg nerv_matrix_(extra_methods)[];
+static void host_matrix_(init)(lua_State *L) {
+ luaN_append_methods(L, nerv_matrix_(extra_methods));
+}
+
#include "matrix.c"
+
+int nerv_matrix_(load)(lua_State *L) {
+ ParamChunkData *chunk = luaT_checkudata(L, 1, nerv_param_chunk_data_tname);
+ Matrix *self;
+ int i, j;
+ long nrow, ncol;
+ FILE *fp = chunk->fp;
+ if (fscanf(fp, "%ld %ld", &nrow, &ncol) != 2)
+ return 0;
+ self = nerv_matrix_(new_)(nrow, ncol);
+ for (i = 0; i < nrow; i++)
+ {
+ MATRIX_ELEM *row = MATRIX_ROW_PTR(self, i);
+ for (j = 0; j < ncol; j++)
+ if (fscanf(fp, MATRIX_ELEM_FMT, row + j) != 1)
+ {
+ free(self);
+ return 0;
+ }
+ }
+ luaT_pushudata(L, self, nerv_matrix_(tname));
+ return 1;
+}
+
+static const luaL_Reg nerv_matrix_(extra_methods)[] = {
+ {"load", nerv_matrix_(load)},
+ {NULL, NULL}
+};
+
#endif
diff --git a/matrix/mmatrix.c b/matrix/mmatrix.c
index ffb02ac..b7d7dae 100644
--- a/matrix/mmatrix.c
+++ b/matrix/mmatrix.c
@@ -9,6 +9,7 @@ const char *nerv_matrix_(tname) = "nerv.MMatrixFloat";
#undef MATRIX_USE_FLOAT
#undef MATRIX_ELEM
#undef MATRIX_ELEM_PTR
+#undef MATRIX_ELEM_FMT
#define NERV_GENERIC_MMATRIX
#define MATRIX_USE_DOUBLE
diff --git a/nerv.c b/nerv.c
index 55ae5b6..d586867 100644
--- a/nerv.c
+++ b/nerv.c
@@ -1,14 +1,25 @@
-#include "lua.h"
-#include "lauxlib.h"
-#include "lualib.h"
+#include "common.h"
extern void nerv_point_init(lua_State *L);
extern void nerv_matrix_init(lua_State *L);
extern void nerv_param_init(lua_State *L);
+static const luaL_Reg nerv_utils_methods[] = {
+ {"setmetatable", luaT_lua_setmetatable},
+ {"getmetatable", luaT_lua_getmetatable},
+ {"newmetatable", luaT_lua_newmetatable},
+ {NULL, NULL}
+};
+
+void nerv_utils_init(lua_State *L) {
+ luaL_register(L, NULL, nerv_utils_methods);
+}
+
int luaopen_libnerv(lua_State *L) {
lua_newtable(L);
+ lua_pushvalue(L, -1);
lua_setfield(L, LUA_GLOBALSINDEX, "nerv");
+ nerv_utils_init(L);
nerv_point_init(L);
nerv_matrix_init(L);
nerv_param_init(L);
diff --git a/nerv.lua b/nerv.lua
index ccff0a0..8e03cb2 100644
--- a/nerv.lua
+++ b/nerv.lua
@@ -1,6 +1,6 @@
require 'libnerv'
require 'matrix.init'
-nerv.class = require 'pl.class'
+-- nerv.class = require 'pl.class'
nerv.utils = require 'pl.utils'
function nerv.error(fmt, ...)
@@ -10,3 +10,28 @@ end
function nerv.error_method_not_implement()
nerv.error("method not implemented");
end
+
+function nerv.class(tname, parenttname)
+
+ local function constructor(...)
+ local self = {}
+ nerv.setmetatable(self, tname)
+ if self.__init then
+ self:__init(...)
+ end
+ return self
+ end
+
+ local function factory()
+ local self = {}
+ nerv.setmetatable(self, tname)
+ return self
+ end
+
+ local mt = nerv.newmetatable(tname, parenttname, constructor, nil, factory)
+ local mpt
+ if parenttname then
+ mpt = nerv.getmetatable(parenttname)
+ end
+ return mt, mpt
+end