From e21e2d9480c83fee13b2e721417cc04fe8036ced Mon Sep 17 00:00:00 2001 From: Determinant Date: Sun, 24 May 2015 15:39:24 +0800 Subject: add param file implementation --- Makefile | 6 +- class.lua | 250 ----------------------------------------------- io/init.c | 6 ++ io/param.c | 171 ++++++++++++++++++++++++++++++++ io/param.h | 22 +++++ matrix/generic/mmatrix.c | 3 + nerv.c | 2 + nerv.lua | 11 ++- 8 files changed, 218 insertions(+), 253 deletions(-) delete mode 100644 class.lua create mode 100644 io/init.c create mode 100644 io/param.c create mode 100644 io/param.h diff --git a/Makefile b/Makefile index fa323a1..83b0df8 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ .PHONY: all clean luajit -OBJS := oop_example.o nerv.o luaT.o common.o matrix/mmatrix.o matrix/cumatrix.o matrix/init.o matrix/cukernel.o +OBJS := oop_example.o nerv.o luaT.o common.o matrix/mmatrix.o matrix/cumatrix.o matrix/init.o matrix/cukernel.o io/init.o io/param.o LIBS := libnerv.so -LUA_LIBS := matrix/init.lua nerv.lua class.lua +LUA_LIBS := matrix/init.lua nerv.lua pl/class.lua pl/utils.lua pl/compat.lua INCLUDE := -I build/luajit-2.0/include/luajit-2.0/ -DLUA_USE_APICHECK CUDA_BASE := /usr/local/cuda-6.5 CUDA_INCLUDE := -I $(CUDA_BASE)/include/ @@ -24,6 +24,8 @@ $(OBJ_DIR): -mkdir -p $(OBJ_DIR) -mkdir -p $(OBJ_DIR)/matrix -mkdir -p $(LUA_DIR)/matrix + -mkdir -p $(OBJ_DIR)/io + -mkdir -p $(LUA_DIR)/pl $(LUA_DIR): -mkdir -p $(LUA_DIR) $(OBJ_DIR)/%.o: %.c diff --git a/class.lua b/class.lua deleted file mode 100644 index d260c31..0000000 --- a/class.lua +++ /dev/null @@ -1,250 +0,0 @@ ---- Provides a reuseable and convenient framework for creating classes in Lua. --- Two possible notations: --- --- B = class(A) --- class.B(A) --- --- The latter form creates a named class within the current environment. Note --- that this implicitly brings in `pl.utils` as a dependency. --- --- See the Guide for further @{01-introduction.md.Simplifying_Object_Oriented_Programming_in_Lua|discussion} --- @module pl.class - -local error, getmetatable, io, pairs, rawget, rawset, setmetatable, tostring, type = - _G.error, _G.getmetatable, _G.io, _G.pairs, _G.rawget, _G.rawset, _G.setmetatable, _G.tostring, _G.type -local compat - --- this trickery is necessary to prevent the inheritance of 'super' and --- the resulting recursive call problems. -local function call_ctor (c,obj,...) - -- nice alias for the base class ctor - local base = rawget(c,'_base') - if base then - local parent_ctor = rawget(base,'_init') - while not parent_ctor do - base = rawget(base,'_base') - if not base then break end - parent_ctor = rawget(base,'_init') - end - if parent_ctor then - rawset(obj,'super',function(obj,...) - call_ctor(base,obj,...) - end) - end - end - local res = c._init(obj,...) - rawset(obj,'super',nil) - return res -end - ---- initializes an __instance__ upon creation. --- @function class:_init --- @param ... parameters passed to the constructor --- @usage local Cat = class() --- function Cat:_init(name) --- --self:super(name) -- call the ancestor initializer if needed --- self.name = name --- end --- --- local pussycat = Cat("pussycat") --- print(pussycat.name) --> pussycat - ---- checks whether an __instance__ is derived from some class. --- Works the other way around as `class_of`. --- @function instance:is_a --- @param some_class class to check against --- @return `true` if `instance` is derived from `some_class` --- @usage local pussycat = Lion() -- assuming Lion derives from Cat --- if pussycat:is_a(Cat) then --- -- it's true --- end -local function is_a(self,klass) - local m = getmetatable(self) - if not m then return false end --*can't be an object! - while m do - if m == klass then return true end - m = rawget(m,'_base') - end - return false -end - ---- checks whether an __instance__ is derived from some class. --- Works the other way around as `is_a`. --- @function some_class:class_of --- @param some_instance instance to check against --- @return `true` if `some_instance` is derived from `some_class` --- @usage local pussycat = Lion() -- assuming Lion derives from Cat --- if Cat:class_of(pussycat) then --- -- it's true --- end -local function class_of(klass,obj) - if type(klass) ~= 'table' or not rawget(klass,'is_a') then return false end - return klass.is_a(obj,klass) -end - ---- cast an object to another class. --- It is not clever (or safe!) so use carefully. --- @param some_instance the object to be changed --- @function some_class:cast -local function cast (klass, obj) - return setmetatable(obj,klass) -end - - -local function _class_tostring (obj) - local mt = obj._class - local name = rawget(mt,'_name') - setmetatable(obj,nil) - local str = tostring(obj) - setmetatable(obj,mt) - if name then str = name ..str:gsub('table','') end - return str -end - -local function tupdate(td,ts,dont_override) - for k,v in pairs(ts) do - if not dont_override or td[k] == nil then - td[k] = v - end - end -end - -local function _class(base,c_arg,c) - -- the class `c` will be the metatable for all its objects, - -- and they will look up their methods in it. - local mt = {} -- a metatable for the class to support __call and _handler - -- can define class by passing it a plain table of methods - local plain = type(base) == 'table' and not getmetatable(base) - if plain then - c = base - base = c._base - else - c = c or {} - end - - if type(base) == 'table' then - -- our new class is a shallow copy of the base class! - -- but be careful not to wipe out any methods we have been given at this point! - tupdate(c,base,plain) - c._base = base - -- inherit the 'not found' handler, if present - if rawget(c,'_handler') then mt.__index = c._handler end - elseif base ~= nil then - error("must derive from a table type",3) - end - - c.__index = c - setmetatable(c,mt) - if not plain then - c._init = nil - end - - if base and rawget(base,'_class_init') then - base._class_init(c,c_arg) - end - - -- expose a ctor which can be called by () - mt.__call = function(class_tbl,...) - local obj - if rawget(c,'_create') then obj = c._create(...) end - if not obj then obj = {} end - setmetatable(obj,c) - - if rawget(c,'_init') then -- explicit constructor - local res = call_ctor(c,obj,...) - if res then -- _if_ a ctor returns a value, it becomes the object... - obj = res - setmetatable(obj,c) - end - elseif base and rawget(base,'_init') then -- default constructor - -- make sure that any stuff from the base class is initialized! - call_ctor(base,obj,...) - end - - if base and rawget(base,'_post_init') then - base._post_init(obj) - end - - if not rawget(c,'__tostring') then - c.__tostring = _class_tostring - end - return obj - end - -- Call Class.catch to set a handler for methods/properties not found in the class! - c.catch = function(self, handler) - if type(self) == "function" then - -- called using . instead of : - handler = self - end - c._handler = handler - mt.__index = handler - end - c.is_a = is_a - c.class_of = class_of - c.cast = cast - c._class = c - - return c -end - ---- create a new class, derived from a given base class. --- Supporting two class creation syntaxes: --- either `Name = class(base)` or `class.Name(base)`. --- The first form returns the class directly and does not set its `_name`. --- The second form creates a variable `Name` in the current environment set --- to the class, and also sets `_name`. --- @function class --- @param base optional base class --- @param c_arg optional parameter to class constructor --- @param c optional table to be used as class -local class -class = setmetatable({},{ - __call = function(fun,...) - return _class(...) - end, - __index = function(tbl,key) - if key == 'class' then - io.stderr:write('require("pl.class").class is deprecated. Use require("pl.class")\n') - return class - end - compat = compat or require 'pl.compat' - local env = compat.getfenv(2) - return function(...) - local c = _class(...) - c._name = key - rawset(env,key,c) - return c - end - end -}) - -class.properties = class() - -function class.properties._class_init(klass) - klass.__index = function(t,key) - -- normal class lookup! - local v = klass[key] - if v then return v end - -- is it a getter? - v = rawget(klass,'get_'..key) - if v then - return v(t) - end - -- is it a field? - return rawget(t,'_'..key) - end - klass.__newindex = function (t,key,value) - -- if there's a setter, use that, otherwise directly set table - local p = 'set_'..key - local setter = klass[p] - if setter then - setter(t,value) - else - rawset(t,key,value) - end - end -end - - -return class - diff --git a/io/init.c b/io/init.c new file mode 100644 index 0000000..d299d54 --- /dev/null +++ b/io/init.c @@ -0,0 +1,6 @@ +#include "../common.h" + +extern void nerv_param_file_init(lua_State *L); +void nerv_param_init(lua_State *L) { + nerv_param_file_init(L); +} diff --git a/io/param.c b/io/param.c new file mode 100644 index 0000000..de6ba7a --- /dev/null +++ b/io/param.c @@ -0,0 +1,171 @@ +#include +#include +#include "../common.h" +#include "param.h" + +#define INVALID_FORMAT_ERROR(fn) \ + nerv_error(L, "Invalid param file: %s", fn) +#define CHECK_FORMAT(exp, ret, fname) \ + do { \ + if ((exp) != (ret)) INVALID_FORMAT_ERROR(fn); \ + } while (0) + +const char *nerv_param_file_tname = "nerv.ParamFile"; +const char *nerv_param_file_handle_tname = "nerv.ParamFileHandle"; +const char *nerv_param_chunk_info_tname = "nerv.ParamChunkInfo"; +const char *nerv_param_chunk_data_tname = "nerv.ParamChunkData"; + +#define PARAM_HEADER_SIZE 16 + +enum { + NORMAL, + INVALID_FORMAT, + END_OF_FILE +}; + +size_t read_param_header_plain(FILE *fp, int *status) { + static char buff[PARAM_HEADER_SIZE]; + int i; + size_t size = 0; + *status = NORMAL; + if (fread(buff, 1, PARAM_HEADER_SIZE, fp) != PARAM_HEADER_SIZE) + { + if (feof(fp)) *status = END_OF_FILE; + else *status = INVALID_FORMAT; + } + for (i = 0; i < PARAM_HEADER_SIZE; i++) + if (isdigit(buff[i])) + size = size * 10 + buff[i] - '0'; + fprintf(stderr, "header: %d\n", size); + return size; +} + +ParamChunkData *get_param_chunk_data(FILE *fp, ParamChunkInfo *info) { + ParamChunkData *pcd = (ParamChunkData *)malloc(sizeof(ParamChunkData)); + pcd->data = (char *)malloc(info->length); + pcd->fp = fmemopen(pcd->data, info->length, "r"); + assert(fseeko(fp, info->offset, SEEK_SET) == 0); + assert(fread(pcd->data, 1, info->length, fp) == (size_t)info->length); + return pcd; +} + +const char *read_param_metadata(lua_State *L, FILE *fp, const char *fn) { +#define LINEBUFF_SIZE 1024 + static char buff[7 + LINEBUFF_SIZE] = "return "; + CHECK_FORMAT(fgets(buff + 7, LINEBUFF_SIZE, fp), buff + 7, fn); + fprintf(stderr, "metadata: %s\n", buff); + return buff; +} + +int nerv_param_file_new(lua_State *L) { + + const char *fn = luaL_checkstring(L, 1); + FILE *fp = fopen(fn, "r"); + int i, status; + size_t param_len; + off_t offset; + ParamFileHandle *lfp; + + if (!fp) nerv_error(L, "Error while opening param file: %s", fn); + offset = ftello(fp); + lua_newtable(L); + lua_newtable(L); + fprintf(stderr, "%d\n", (int)offset); + for (i = 0;; offset += param_len, i++) + { + ParamChunkInfo *pci; + fprintf(stderr, "reading param chunk %d from %d\n", i, (int)offset); + /* skip to the begining of param chunk i */ + CHECK_FORMAT(fseeko(fp, offset, SEEK_SET), 0, fn); + /* read header */ + param_len = read_param_header_plain(fp, &status); + if (status == END_OF_FILE) break; + else if (status == INVALID_FORMAT) + INVALID_FORMAT_ERROR(fn); + /* read metadata */ + luaL_loadstring(L, read_param_metadata(L, fp, fn)); + CHECK_FORMAT(lua_pcall(L, 0, 1, 0), 0, fn); + CHECK_FORMAT(lua_istable(L, -1), 1, fn); + /* chunk info */ + pci = (ParamChunkInfo *)malloc(sizeof(ParamChunkInfo)); + pci->offset = ftello(fp); + pci->length = param_len - (pci->offset - offset); + fprintf(stderr, "%d + %d (skip %d)\n", (int)pci->offset, + (int)pci->length, param_len); + luaT_pushudata(L, pci, nerv_param_chunk_info_tname); + lua_setfield(L, -2, "chunk"); + lua_rawseti(L, -2, i); + } + lua_setfield(L, -2, "metadata"); + lfp = (ParamFileHandle *)malloc(sizeof(ParamFileHandle)); + lfp->fp = fp; + luaT_pushudata(L, lfp, nerv_param_file_handle_tname); + lua_setfield(L, -2, "handle"); + luaT_pushmetatable(L, nerv_param_file_tname); + lua_setmetatable(L, -2); + return 1; +} + +int nerv_param_file_get_chunkdata(lua_State *L) { + ParamFileHandle *pfh; + ParamChunkInfo *pci; + int k = luaL_checkinteger(L, 2); + + lua_getfield(L, 1, "handle"); + pfh = luaT_checkudata(L, -1, nerv_param_file_handle_tname); + lua_pop(L, 1); /* pop handle */ + + lua_getfield(L, 1, "metadata"); + /* now stack: self, k, metadata */ + lua_rawgeti(L, -1, k); + /* now stack: self, k, metadata, ith{} */ + lua_getfield(L, -1, "chunk"); + pci = luaT_checkudata(L, -1, nerv_param_chunk_info_tname); + + luaT_pushudata(L, get_param_chunk_data(pfh->fp, pci), + nerv_param_chunk_data_tname); + lua_setfield(L, -2, "data"); + return 1; +} + +int nerv_param_file_handle_destroy(lua_State *L) { + ParamFileHandle *pfh = luaT_checkudata(L, 1, + nerv_param_file_handle_tname); + fclose(pfh->fp); + free(pfh); + return 0; +} + +static int nerv_param_chunk_destroy(lua_State *L) { + ParamChunkInfo *pci = luaT_checkudata(L, 1, nerv_param_chunk_info_tname); + free(pci); + return 0; +} + +static int nerv_param_chunk_data_destroy(lua_State *L) { + ParamChunkData *pcd = luaT_checkudata(L, 1, nerv_param_chunk_data_tname); + fclose(pcd->fp); + free(pcd->data); + free(pcd); + return 0; +} + +static const luaL_Reg nerv_param_file_methods[] = { + {"get_chunkdata", nerv_param_file_get_chunkdata}, + {NULL, NULL} +}; + +void nerv_param_file_init(lua_State *L) { + luaT_newmetatable(L, nerv_param_file_tname, NULL, + nerv_param_file_new, + NULL, NULL); + luaL_register(L, NULL, nerv_param_file_methods); + lua_pop(L, 1); + luaT_newmetatable(L, nerv_param_file_handle_tname, NULL, + NULL, nerv_param_file_handle_destroy, NULL); + luaT_newmetatable(L, nerv_param_chunk_info_tname, NULL, + NULL, nerv_param_chunk_destroy, NULL); + luaT_newmetatable(L, nerv_param_chunk_data_tname, NULL, + NULL, nerv_param_chunk_data_destroy, NULL); +} + diff --git a/io/param.h b/io/param.h new file mode 100644 index 0000000..e5841b8 --- /dev/null +++ b/io/param.h @@ -0,0 +1,22 @@ +#ifndef NERV_LAYER_FILE_H +#define NERV_LAYER_FILE_H + +extern const char *nerv_param_file_tname; +extern const char *nerv_param_file_handle_tname; +extern const char *nerv_param_chunk_info_tname; +extern const char *nerv_param_chunk_data_tname; + +typedef struct ParamFileHandle { + FILE *fp; +} ParamFileHandle; + +typedef struct ParamChunkInfo { + off_t offset, length; +} ParamChunkInfo; + +typedef struct ParamChunkData { + FILE *fp; + char *data; +} ParamChunkData; + +#endif diff --git a/matrix/generic/mmatrix.c b/matrix/generic/mmatrix.c index c301e23..6edac69 100644 --- a/matrix/generic/mmatrix.c +++ b/matrix/generic/mmatrix.c @@ -17,6 +17,9 @@ static void host_matrix_(alloc)(MATRIX_ELEM **dptr, size_t *stride, *stride = width; } +static const luaL_Reg nerv_matrix_(extra_methods)[] = { +}; + int nerv_matrix_(get_elem)(lua_State *L) { Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname)); int idx = luaL_checkinteger(L, 2); diff --git a/nerv.c b/nerv.c index 3dc9895..55ae5b6 100644 --- a/nerv.c +++ b/nerv.c @@ -4,11 +4,13 @@ extern void nerv_point_init(lua_State *L); extern void nerv_matrix_init(lua_State *L); +extern void nerv_param_init(lua_State *L); int luaopen_libnerv(lua_State *L) { lua_newtable(L); lua_setfield(L, LUA_GLOBALSINDEX, "nerv"); nerv_point_init(L); nerv_matrix_init(L); + nerv_param_init(L); return 1; } diff --git a/nerv.lua b/nerv.lua index de2e701..ccff0a0 100644 --- a/nerv.lua +++ b/nerv.lua @@ -1,3 +1,12 @@ require 'libnerv' require 'matrix.init' -nerv.class = require 'class' +nerv.class = require 'pl.class' +nerv.utils = require 'pl.utils' + +function nerv.error(fmt, ...) + error(nerv.utils.printf("Nerv internal error: " .. fmt, ...)) +end + +function nerv.error_method_not_implement() + nerv.error("method not implemented"); +end -- cgit v1.2.3