From 3ccca6ba1eb7b6732036f4128c977c0b02ef3836 Mon Sep 17 00:00:00 2001 From: Determinant Date: Sun, 24 May 2015 20:35:40 +0800 Subject: add write functionality to ParamFile --- Makefile | 3 +- io/init.lua | 7 ++++ io/param.c | 83 ++++++++++++++++++++++++++++++++++++++++++++++-- matrix/generic/mmatrix.c | 29 +++++++++++++++++ nerv.lua | 37 +++++++++++++++++++++ 5 files changed, 156 insertions(+), 3 deletions(-) create mode 100644 io/init.lua diff --git a/Makefile b/Makefile index d766cd4..30c91ce 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ .PHONY: all clean luajit OBJS := oop_example.o nerv.o luaT.o common.o matrix/mmatrix.o matrix/cumatrix.o matrix/init.o matrix/cukernel.o io/init.o io/param.o LIBS := libnerv.so -LUA_LIBS := matrix/init.lua nerv.lua pl/utils.lua pl/compat.lua +LUA_LIBS := matrix/init.lua io/init.lua nerv.lua pl/utils.lua pl/compat.lua INCLUDE := -I build/luajit-2.0/include/luajit-2.0/ -DLUA_USE_APICHECK CUDA_BASE := /usr/local/cuda-6.5 CUDA_INCLUDE := -I $(CUDA_BASE)/include/ @@ -25,6 +25,7 @@ $(OBJ_DIR): -mkdir -p $(OBJ_DIR)/matrix -mkdir -p $(LUA_DIR)/matrix -mkdir -p $(OBJ_DIR)/io + -mkdir -p $(LUA_DIR)/io -mkdir -p $(LUA_DIR)/pl $(LUA_DIR): -mkdir -p $(LUA_DIR) diff --git a/io/init.lua b/io/init.lua new file mode 100644 index 0000000..1288bd4 --- /dev/null +++ b/io/init.lua @@ -0,0 +1,7 @@ +function nerv.ParamFile:write_chunkdata(metadata, writer) + if type(metadata) ~= "table" then + nerv.error("metadata should be a Lua table") + return + end + return self:_write_chunkdata(table.tostring(metadata), writer) +end diff --git a/io/param.c b/io/param.c index dbd1438..91c4d26 100644 --- a/io/param.c +++ b/io/param.c @@ -21,7 +21,9 @@ const char *nerv_param_chunk_data_tname = "nerv.ParamChunkData"; enum { NORMAL, INVALID_FORMAT, - END_OF_FILE + END_OF_FILE, + SECTION_OVERFLOW, + WRITE_ERROR }; size_t read_param_header_plain(FILE *fp, int *status) { @@ -41,6 +43,35 @@ size_t read_param_header_plain(FILE *fp, int *status) { return size; } +#define CHECK_WRITE(status) \ + do { \ + if (status == SECTION_OVERFLOW) \ + nerv_error(L, "section overflowed"); \ + else if (status == WRITE_ERROR) \ + nerv_error(L, "error while writing"); \ + } while (0) + +void write_param_header_plain(FILE *fp, size_t size, int *status) { + static char buff[PARAM_HEADER_SIZE]; + int i; + *status = NORMAL; + for (i = PARAM_HEADER_SIZE - 3; i > 0; i--, size /= 10) + buff[i] = size % 10 + '0'; + if (size) + { + *status = SECTION_OVERFLOW; + return; + } + buff[0] = '['; + buff[PARAM_HEADER_SIZE - 2] = ']'; + buff[PARAM_HEADER_SIZE - 1] = '\n'; + if (fwrite(buff, 1, PARAM_HEADER_SIZE, fp) != PARAM_HEADER_SIZE) + { + *status = WRITE_ERROR; + return; + } +} + ParamChunkData *get_param_chunk_data(FILE *fp, ParamChunkInfo *info) { ParamChunkData *pcd = (ParamChunkData *)malloc(sizeof(ParamChunkData)); pcd->data = (char *)malloc(info->length); @@ -58,10 +89,29 @@ const char *read_param_metadata(lua_State *L, FILE *fp, const char *fn) { return buff; } +void write_param_metadata(FILE *fp, const char *metadata_str, int *status) { + size_t size = strlen(metadata_str); + *status = NORMAL; + if (fwrite(metadata_str, 1, size, fp) != size || + fprintf(fp, "\n") < 0) + { + *status = WRITE_ERROR; + return; + } + fprintf(stderr, "metadata: %s\n", metadata_str); +} + + int nerv_param_file_open_write(lua_State *L, const char *fn) { FILE *fp = fopen(fn, "w"); + ParamFileHandle *lfp; if (!fp) nerv_error(L, "Error while opening param file: %s", fn); - lua_newtable(L); + lfp = (ParamFileHandle *)malloc(sizeof(ParamFileHandle)); + lfp->fp = fp; + luaT_pushudata(L, lfp, nerv_param_file_handle_tname); + lua_setfield(L, -2, "handle"); + luaT_pushmetatable(L, nerv_param_file_tname); + lua_setmetatable(L, -2); return 1; } @@ -145,6 +195,34 @@ int nerv_param_file_new(lua_State *L) { nerv_param_file_open_write(L, fn); } +int nerv_param_file_write_chunkdata(lua_State *L) { + ParamFileHandle *pfh; + int status; + off_t start; + size_t size; + const char *metadata_str = lua_tolstring(L, 2, NULL); + lua_getfield(L, 1, "handle"); + pfh = luaT_checkudata(L, -1, nerv_param_file_handle_tname); + start = ftello(pfh->fp); + write_param_header_plain(pfh->fp, 0, &status); /* fill zeros */ + CHECK_WRITE(status); + write_param_metadata(pfh->fp, metadata_str, &status); + CHECK_WRITE(status); + lua_getfield(L, 3, "save"); + if (lua_type(L, -1) != LUA_TFUNCTION) + nerv_error(L, "\"save\" method must be implemented"); + lua_pushvalue(L, 3); + lua_pushvalue(L, -3); /* pass handle as parameter to save() */ + lua_call(L, 2, 0); /* let the save() to write */ + size = ftello(pfh->fp) - start; + fseeko(pfh->fp, start, SEEK_SET); + /* write the calced size */ + write_param_header_plain(pfh->fp, size, &status); + CHECK_WRITE(status); + fseeko(pfh->fp, 0, SEEK_END); + return 0; +} + int nerv_param_file_get_chunkdata(lua_State *L) { ParamFileHandle *pfh; ParamChunkInfo *pci; @@ -190,6 +268,7 @@ static int nerv_param_chunk_data_destroy(lua_State *L) { static const luaL_Reg nerv_param_file_methods[] = { {"get_chunkdata", nerv_param_file_get_chunkdata}, + {"_write_chunkdata", nerv_param_file_write_chunkdata}, {"__init", nerv_param_file___init}, {NULL, NULL} }; diff --git a/matrix/generic/mmatrix.c b/matrix/generic/mmatrix.c index d37cd80..c25b9f7 100644 --- a/matrix/generic/mmatrix.c +++ b/matrix/generic/mmatrix.c @@ -68,8 +68,37 @@ int nerv_matrix_(load)(lua_State *L) { return 1; } +int nerv_matrix_(save)(lua_State *L) { + ParamFileHandle *chunk = luaT_checkudata(L, 2, + nerv_param_file_handle_tname); + Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname)); + int i, j; + long nrow = self->nrow, ncol = self->ncol; + FILE *fp = chunk->fp; + if (fprintf(fp, "%ld %ld\n", nrow, ncol) < 0) + return 0; + for (i = 0; i < nrow; i++) + { + MATRIX_ELEM *row = MATRIX_ROW_PTR(self, i); + for (j = 0; j < ncol; j++) + if (fprintf(fp, MATRIX_ELEM_FMT " ", row[j]) < 0) + { + free(self); + return 0; + } + if (fprintf(fp, "\n") < 0) + { + free(self); + return 0; + } + } + return 0; +} + + static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"load", nerv_matrix_(load)}, + {"save", nerv_matrix_(save)}, {NULL, NULL} }; diff --git a/nerv.lua b/nerv.lua index 33b1aff..1c4ba39 100644 --- a/nerv.lua +++ b/nerv.lua @@ -1,5 +1,6 @@ require 'libnerv' require 'matrix.init' +require 'io.init' -- nerv.class = require 'pl.class' nerv.utils = require 'pl.utils' @@ -35,3 +36,39 @@ function nerv.class(tname, parenttname) end return mt, mpt end + +function table.val_to_str(v) + if "string" == type(v) then + v = string.gsub(v, "\n", "\\n") + if string.match(string.gsub(v,"[^'\"]",""), '^"+$') then + return "'" .. v .. "'" + end + return '"' .. string.gsub(v,'"', '\\"') .. '"' + else + return "table" == type(v) and table.tostring(v) or + tostring(v) + end +end + +function table.key_to_str (k) + if "string" == type(k) and string.match(k, "^[_%a][_%a%d]*$") then + return k + else + return "[" .. table.val_to_str(k) .. "]" + end +end + +function table.tostring(tbl) + local result, done = {}, {} + for k, v in ipairs(tbl) do + table.insert(result, table.val_to_str(v)) + done[k] = true + end + for k, v in pairs(tbl) do + if not done[k] then + table.insert(result, + table.key_to_str(k) .. "=" .. table.val_to_str(v)) + end + end + return "{" .. table.concat(result, ",") .. "}" +end -- cgit v1.2.3