aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Makefile3
-rw-r--r--io/init.lua7
-rw-r--r--io/param.c83
-rw-r--r--matrix/generic/mmatrix.c29
-rw-r--r--nerv.lua37
5 files changed, 156 insertions, 3 deletions
diff --git a/Makefile b/Makefile
index d766cd4..30c91ce 100644
--- a/Makefile
+++ b/Makefile
@@ -1,7 +1,7 @@
.PHONY: all clean luajit
OBJS := oop_example.o nerv.o luaT.o common.o matrix/mmatrix.o matrix/cumatrix.o matrix/init.o matrix/cukernel.o io/init.o io/param.o
LIBS := libnerv.so
-LUA_LIBS := matrix/init.lua nerv.lua pl/utils.lua pl/compat.lua
+LUA_LIBS := matrix/init.lua io/init.lua nerv.lua pl/utils.lua pl/compat.lua
INCLUDE := -I build/luajit-2.0/include/luajit-2.0/ -DLUA_USE_APICHECK
CUDA_BASE := /usr/local/cuda-6.5
CUDA_INCLUDE := -I $(CUDA_BASE)/include/
@@ -25,6 +25,7 @@ $(OBJ_DIR):
-mkdir -p $(OBJ_DIR)/matrix
-mkdir -p $(LUA_DIR)/matrix
-mkdir -p $(OBJ_DIR)/io
+ -mkdir -p $(LUA_DIR)/io
-mkdir -p $(LUA_DIR)/pl
$(LUA_DIR):
-mkdir -p $(LUA_DIR)
diff --git a/io/init.lua b/io/init.lua
new file mode 100644
index 0000000..1288bd4
--- /dev/null
+++ b/io/init.lua
@@ -0,0 +1,7 @@
+function nerv.ParamFile:write_chunkdata(metadata, writer)
+ if type(metadata) ~= "table" then
+ nerv.error("metadata should be a Lua table")
+ return
+ end
+ return self:_write_chunkdata(table.tostring(metadata), writer)
+end
diff --git a/io/param.c b/io/param.c
index dbd1438..91c4d26 100644
--- a/io/param.c
+++ b/io/param.c
@@ -21,7 +21,9 @@ const char *nerv_param_chunk_data_tname = "nerv.ParamChunkData";
enum {
NORMAL,
INVALID_FORMAT,
- END_OF_FILE
+ END_OF_FILE,
+ SECTION_OVERFLOW,
+ WRITE_ERROR
};
size_t read_param_header_plain(FILE *fp, int *status) {
@@ -41,6 +43,35 @@ size_t read_param_header_plain(FILE *fp, int *status) {
return size;
}
+#define CHECK_WRITE(status) \
+ do { \
+ if (status == SECTION_OVERFLOW) \
+ nerv_error(L, "section overflowed"); \
+ else if (status == WRITE_ERROR) \
+ nerv_error(L, "error while writing"); \
+ } while (0)
+
+void write_param_header_plain(FILE *fp, size_t size, int *status) {
+ static char buff[PARAM_HEADER_SIZE];
+ int i;
+ *status = NORMAL;
+ for (i = PARAM_HEADER_SIZE - 3; i > 0; i--, size /= 10)
+ buff[i] = size % 10 + '0';
+ if (size)
+ {
+ *status = SECTION_OVERFLOW;
+ return;
+ }
+ buff[0] = '[';
+ buff[PARAM_HEADER_SIZE - 2] = ']';
+ buff[PARAM_HEADER_SIZE - 1] = '\n';
+ if (fwrite(buff, 1, PARAM_HEADER_SIZE, fp) != PARAM_HEADER_SIZE)
+ {
+ *status = WRITE_ERROR;
+ return;
+ }
+}
+
ParamChunkData *get_param_chunk_data(FILE *fp, ParamChunkInfo *info) {
ParamChunkData *pcd = (ParamChunkData *)malloc(sizeof(ParamChunkData));
pcd->data = (char *)malloc(info->length);
@@ -58,10 +89,29 @@ const char *read_param_metadata(lua_State *L, FILE *fp, const char *fn) {
return buff;
}
+void write_param_metadata(FILE *fp, const char *metadata_str, int *status) {
+ size_t size = strlen(metadata_str);
+ *status = NORMAL;
+ if (fwrite(metadata_str, 1, size, fp) != size ||
+ fprintf(fp, "\n") < 0)
+ {
+ *status = WRITE_ERROR;
+ return;
+ }
+ fprintf(stderr, "metadata: %s\n", metadata_str);
+}
+
+
int nerv_param_file_open_write(lua_State *L, const char *fn) {
FILE *fp = fopen(fn, "w");
+ ParamFileHandle *lfp;
if (!fp) nerv_error(L, "Error while opening param file: %s", fn);
- lua_newtable(L);
+ lfp = (ParamFileHandle *)malloc(sizeof(ParamFileHandle));
+ lfp->fp = fp;
+ luaT_pushudata(L, lfp, nerv_param_file_handle_tname);
+ lua_setfield(L, -2, "handle");
+ luaT_pushmetatable(L, nerv_param_file_tname);
+ lua_setmetatable(L, -2);
return 1;
}
@@ -145,6 +195,34 @@ int nerv_param_file_new(lua_State *L) {
nerv_param_file_open_write(L, fn);
}
+int nerv_param_file_write_chunkdata(lua_State *L) {
+ ParamFileHandle *pfh;
+ int status;
+ off_t start;
+ size_t size;
+ const char *metadata_str = lua_tolstring(L, 2, NULL);
+ lua_getfield(L, 1, "handle");
+ pfh = luaT_checkudata(L, -1, nerv_param_file_handle_tname);
+ start = ftello(pfh->fp);
+ write_param_header_plain(pfh->fp, 0, &status); /* fill zeros */
+ CHECK_WRITE(status);
+ write_param_metadata(pfh->fp, metadata_str, &status);
+ CHECK_WRITE(status);
+ lua_getfield(L, 3, "save");
+ if (lua_type(L, -1) != LUA_TFUNCTION)
+ nerv_error(L, "\"save\" method must be implemented");
+ lua_pushvalue(L, 3);
+ lua_pushvalue(L, -3); /* pass handle as parameter to save() */
+ lua_call(L, 2, 0); /* let the save() to write */
+ size = ftello(pfh->fp) - start;
+ fseeko(pfh->fp, start, SEEK_SET);
+ /* write the calced size */
+ write_param_header_plain(pfh->fp, size, &status);
+ CHECK_WRITE(status);
+ fseeko(pfh->fp, 0, SEEK_END);
+ return 0;
+}
+
int nerv_param_file_get_chunkdata(lua_State *L) {
ParamFileHandle *pfh;
ParamChunkInfo *pci;
@@ -190,6 +268,7 @@ static int nerv_param_chunk_data_destroy(lua_State *L) {
static const luaL_Reg nerv_param_file_methods[] = {
{"get_chunkdata", nerv_param_file_get_chunkdata},
+ {"_write_chunkdata", nerv_param_file_write_chunkdata},
{"__init", nerv_param_file___init},
{NULL, NULL}
};
diff --git a/matrix/generic/mmatrix.c b/matrix/generic/mmatrix.c
index d37cd80..c25b9f7 100644
--- a/matrix/generic/mmatrix.c
+++ b/matrix/generic/mmatrix.c
@@ -68,8 +68,37 @@ int nerv_matrix_(load)(lua_State *L) {
return 1;
}
+int nerv_matrix_(save)(lua_State *L) {
+ ParamFileHandle *chunk = luaT_checkudata(L, 2,
+ nerv_param_file_handle_tname);
+ Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ int i, j;
+ long nrow = self->nrow, ncol = self->ncol;
+ FILE *fp = chunk->fp;
+ if (fprintf(fp, "%ld %ld\n", nrow, ncol) < 0)
+ return 0;
+ for (i = 0; i < nrow; i++)
+ {
+ MATRIX_ELEM *row = MATRIX_ROW_PTR(self, i);
+ for (j = 0; j < ncol; j++)
+ if (fprintf(fp, MATRIX_ELEM_FMT " ", row[j]) < 0)
+ {
+ free(self);
+ return 0;
+ }
+ if (fprintf(fp, "\n") < 0)
+ {
+ free(self);
+ return 0;
+ }
+ }
+ return 0;
+}
+
+
static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"load", nerv_matrix_(load)},
+ {"save", nerv_matrix_(save)},
{NULL, NULL}
};
diff --git a/nerv.lua b/nerv.lua
index 33b1aff..1c4ba39 100644
--- a/nerv.lua
+++ b/nerv.lua
@@ -1,5 +1,6 @@
require 'libnerv'
require 'matrix.init'
+require 'io.init'
-- nerv.class = require 'pl.class'
nerv.utils = require 'pl.utils'
@@ -35,3 +36,39 @@ function nerv.class(tname, parenttname)
end
return mt, mpt
end
+
+function table.val_to_str(v)
+ if "string" == type(v) then
+ v = string.gsub(v, "\n", "\\n")
+ if string.match(string.gsub(v,"[^'\"]",""), '^"+$') then
+ return "'" .. v .. "'"
+ end
+ return '"' .. string.gsub(v,'"', '\\"') .. '"'
+ else
+ return "table" == type(v) and table.tostring(v) or
+ tostring(v)
+ end
+end
+
+function table.key_to_str (k)
+ if "string" == type(k) and string.match(k, "^[_%a][_%a%d]*$") then
+ return k
+ else
+ return "[" .. table.val_to_str(k) .. "]"
+ end
+end
+
+function table.tostring(tbl)
+ local result, done = {}, {}
+ for k, v in ipairs(tbl) do
+ table.insert(result, table.val_to_str(v))
+ done[k] = true
+ end
+ for k, v in pairs(tbl) do
+ if not done[k] then
+ table.insert(result,
+ table.key_to_str(k) .. "=" .. table.val_to_str(v))
+ end
+ end
+ return "{" .. table.concat(result, ",") .. "}"
+end