aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2015-05-24 15:39:24 +0800
committerDeterminant <[email protected]>2015-05-24 15:39:24 +0800
commite21e2d9480c83fee13b2e721417cc04fe8036ced (patch)
treec45e19379badce0815bdbdbd58bc8df27cc5da7d
parent0e250c43b62b7593edc163d0510d229010361707 (diff)
add param file implementation
-rw-r--r--Makefile6
-rw-r--r--class.lua250
-rw-r--r--io/init.c6
-rw-r--r--io/param.c171
-rw-r--r--io/param.h22
-rw-r--r--matrix/generic/mmatrix.c3
-rw-r--r--nerv.c2
-rw-r--r--nerv.lua11
8 files changed, 218 insertions, 253 deletions
diff --git a/Makefile b/Makefile
index fa323a1..83b0df8 100644
--- a/Makefile
+++ b/Makefile
@@ -1,7 +1,7 @@
.PHONY: all clean luajit
-OBJS := oop_example.o nerv.o luaT.o common.o matrix/mmatrix.o matrix/cumatrix.o matrix/init.o matrix/cukernel.o
+OBJS := oop_example.o nerv.o luaT.o common.o matrix/mmatrix.o matrix/cumatrix.o matrix/init.o matrix/cukernel.o io/init.o io/param.o
LIBS := libnerv.so
-LUA_LIBS := matrix/init.lua nerv.lua class.lua
+LUA_LIBS := matrix/init.lua nerv.lua pl/class.lua pl/utils.lua pl/compat.lua
INCLUDE := -I build/luajit-2.0/include/luajit-2.0/ -DLUA_USE_APICHECK
CUDA_BASE := /usr/local/cuda-6.5
CUDA_INCLUDE := -I $(CUDA_BASE)/include/
@@ -24,6 +24,8 @@ $(OBJ_DIR):
-mkdir -p $(OBJ_DIR)
-mkdir -p $(OBJ_DIR)/matrix
-mkdir -p $(LUA_DIR)/matrix
+ -mkdir -p $(OBJ_DIR)/io
+ -mkdir -p $(LUA_DIR)/pl
$(LUA_DIR):
-mkdir -p $(LUA_DIR)
$(OBJ_DIR)/%.o: %.c
diff --git a/class.lua b/class.lua
deleted file mode 100644
index d260c31..0000000
--- a/class.lua
+++ /dev/null
@@ -1,250 +0,0 @@
---- Provides a reuseable and convenient framework for creating classes in Lua.
--- Two possible notations:
---
--- B = class(A)
--- class.B(A)
---
--- The latter form creates a named class within the current environment. Note
--- that this implicitly brings in `pl.utils` as a dependency.
---
--- See the Guide for further @{01-introduction.md.Simplifying_Object_Oriented_Programming_in_Lua|discussion}
--- @module pl.class
-
-local error, getmetatable, io, pairs, rawget, rawset, setmetatable, tostring, type =
- _G.error, _G.getmetatable, _G.io, _G.pairs, _G.rawget, _G.rawset, _G.setmetatable, _G.tostring, _G.type
-local compat
-
--- this trickery is necessary to prevent the inheritance of 'super' and
--- the resulting recursive call problems.
-local function call_ctor (c,obj,...)
- -- nice alias for the base class ctor
- local base = rawget(c,'_base')
- if base then
- local parent_ctor = rawget(base,'_init')
- while not parent_ctor do
- base = rawget(base,'_base')
- if not base then break end
- parent_ctor = rawget(base,'_init')
- end
- if parent_ctor then
- rawset(obj,'super',function(obj,...)
- call_ctor(base,obj,...)
- end)
- end
- end
- local res = c._init(obj,...)
- rawset(obj,'super',nil)
- return res
-end
-
---- initializes an __instance__ upon creation.
--- @function class:_init
--- @param ... parameters passed to the constructor
--- @usage local Cat = class()
--- function Cat:_init(name)
--- --self:super(name) -- call the ancestor initializer if needed
--- self.name = name
--- end
---
--- local pussycat = Cat("pussycat")
--- print(pussycat.name) --> pussycat
-
---- checks whether an __instance__ is derived from some class.
--- Works the other way around as `class_of`.
--- @function instance:is_a
--- @param some_class class to check against
--- @return `true` if `instance` is derived from `some_class`
--- @usage local pussycat = Lion() -- assuming Lion derives from Cat
--- if pussycat:is_a(Cat) then
--- -- it's true
--- end
-local function is_a(self,klass)
- local m = getmetatable(self)
- if not m then return false end --*can't be an object!
- while m do
- if m == klass then return true end
- m = rawget(m,'_base')
- end
- return false
-end
-
---- checks whether an __instance__ is derived from some class.
--- Works the other way around as `is_a`.
--- @function some_class:class_of
--- @param some_instance instance to check against
--- @return `true` if `some_instance` is derived from `some_class`
--- @usage local pussycat = Lion() -- assuming Lion derives from Cat
--- if Cat:class_of(pussycat) then
--- -- it's true
--- end
-local function class_of(klass,obj)
- if type(klass) ~= 'table' or not rawget(klass,'is_a') then return false end
- return klass.is_a(obj,klass)
-end
-
---- cast an object to another class.
--- It is not clever (or safe!) so use carefully.
--- @param some_instance the object to be changed
--- @function some_class:cast
-local function cast (klass, obj)
- return setmetatable(obj,klass)
-end
-
-
-local function _class_tostring (obj)
- local mt = obj._class
- local name = rawget(mt,'_name')
- setmetatable(obj,nil)
- local str = tostring(obj)
- setmetatable(obj,mt)
- if name then str = name ..str:gsub('table','') end
- return str
-end
-
-local function tupdate(td,ts,dont_override)
- for k,v in pairs(ts) do
- if not dont_override or td[k] == nil then
- td[k] = v
- end
- end
-end
-
-local function _class(base,c_arg,c)
- -- the class `c` will be the metatable for all its objects,
- -- and they will look up their methods in it.
- local mt = {} -- a metatable for the class to support __call and _handler
- -- can define class by passing it a plain table of methods
- local plain = type(base) == 'table' and not getmetatable(base)
- if plain then
- c = base
- base = c._base
- else
- c = c or {}
- end
-
- if type(base) == 'table' then
- -- our new class is a shallow copy of the base class!
- -- but be careful not to wipe out any methods we have been given at this point!
- tupdate(c,base,plain)
- c._base = base
- -- inherit the 'not found' handler, if present
- if rawget(c,'_handler') then mt.__index = c._handler end
- elseif base ~= nil then
- error("must derive from a table type",3)
- end
-
- c.__index = c
- setmetatable(c,mt)
- if not plain then
- c._init = nil
- end
-
- if base and rawget(base,'_class_init') then
- base._class_init(c,c_arg)
- end
-
- -- expose a ctor which can be called by <classname>(<args>)
- mt.__call = function(class_tbl,...)
- local obj
- if rawget(c,'_create') then obj = c._create(...) end
- if not obj then obj = {} end
- setmetatable(obj,c)
-
- if rawget(c,'_init') then -- explicit constructor
- local res = call_ctor(c,obj,...)
- if res then -- _if_ a ctor returns a value, it becomes the object...
- obj = res
- setmetatable(obj,c)
- end
- elseif base and rawget(base,'_init') then -- default constructor
- -- make sure that any stuff from the base class is initialized!
- call_ctor(base,obj,...)
- end
-
- if base and rawget(base,'_post_init') then
- base._post_init(obj)
- end
-
- if not rawget(c,'__tostring') then
- c.__tostring = _class_tostring
- end
- return obj
- end
- -- Call Class.catch to set a handler for methods/properties not found in the class!
- c.catch = function(self, handler)
- if type(self) == "function" then
- -- called using . instead of :
- handler = self
- end
- c._handler = handler
- mt.__index = handler
- end
- c.is_a = is_a
- c.class_of = class_of
- c.cast = cast
- c._class = c
-
- return c
-end
-
---- create a new class, derived from a given base class.
--- Supporting two class creation syntaxes:
--- either `Name = class(base)` or `class.Name(base)`.
--- The first form returns the class directly and does not set its `_name`.
--- The second form creates a variable `Name` in the current environment set
--- to the class, and also sets `_name`.
--- @function class
--- @param base optional base class
--- @param c_arg optional parameter to class constructor
--- @param c optional table to be used as class
-local class
-class = setmetatable({},{
- __call = function(fun,...)
- return _class(...)
- end,
- __index = function(tbl,key)
- if key == 'class' then
- io.stderr:write('require("pl.class").class is deprecated. Use require("pl.class")\n')
- return class
- end
- compat = compat or require 'pl.compat'
- local env = compat.getfenv(2)
- return function(...)
- local c = _class(...)
- c._name = key
- rawset(env,key,c)
- return c
- end
- end
-})
-
-class.properties = class()
-
-function class.properties._class_init(klass)
- klass.__index = function(t,key)
- -- normal class lookup!
- local v = klass[key]
- if v then return v end
- -- is it a getter?
- v = rawget(klass,'get_'..key)
- if v then
- return v(t)
- end
- -- is it a field?
- return rawget(t,'_'..key)
- end
- klass.__newindex = function (t,key,value)
- -- if there's a setter, use that, otherwise directly set table
- local p = 'set_'..key
- local setter = klass[p]
- if setter then
- setter(t,value)
- else
- rawset(t,key,value)
- end
- end
-end
-
-
-return class
-
diff --git a/io/init.c b/io/init.c
new file mode 100644
index 0000000..d299d54
--- /dev/null
+++ b/io/init.c
@@ -0,0 +1,6 @@
+#include "../common.h"
+
+extern void nerv_param_file_init(lua_State *L);
+void nerv_param_init(lua_State *L) {
+ nerv_param_file_init(L);
+}
diff --git a/io/param.c b/io/param.c
new file mode 100644
index 0000000..de6ba7a
--- /dev/null
+++ b/io/param.c
@@ -0,0 +1,171 @@
+#include <stdio.h>
+#include <ctype.h>
+#include "../common.h"
+#include "param.h"
+
+#define INVALID_FORMAT_ERROR(fn) \
+ nerv_error(L, "Invalid param file: %s", fn)
+#define CHECK_FORMAT(exp, ret, fname) \
+ do { \
+ if ((exp) != (ret)) INVALID_FORMAT_ERROR(fn); \
+ } while (0)
+
+const char *nerv_param_file_tname = "nerv.ParamFile";
+const char *nerv_param_file_handle_tname = "nerv.ParamFileHandle";
+const char *nerv_param_chunk_info_tname = "nerv.ParamChunkInfo";
+const char *nerv_param_chunk_data_tname = "nerv.ParamChunkData";
+
+#define PARAM_HEADER_SIZE 16
+
+enum {
+ NORMAL,
+ INVALID_FORMAT,
+ END_OF_FILE
+};
+
+size_t read_param_header_plain(FILE *fp, int *status) {
+ static char buff[PARAM_HEADER_SIZE];
+ int i;
+ size_t size = 0;
+ *status = NORMAL;
+ if (fread(buff, 1, PARAM_HEADER_SIZE, fp) != PARAM_HEADER_SIZE)
+ {
+ if (feof(fp)) *status = END_OF_FILE;
+ else *status = INVALID_FORMAT;
+ }
+ for (i = 0; i < PARAM_HEADER_SIZE; i++)
+ if (isdigit(buff[i]))
+ size = size * 10 + buff[i] - '0';
+ fprintf(stderr, "header: %d\n", size);
+ return size;
+}
+
+ParamChunkData *get_param_chunk_data(FILE *fp, ParamChunkInfo *info) {
+ ParamChunkData *pcd = (ParamChunkData *)malloc(sizeof(ParamChunkData));
+ pcd->data = (char *)malloc(info->length);
+ pcd->fp = fmemopen(pcd->data, info->length, "r");
+ assert(fseeko(fp, info->offset, SEEK_SET) == 0);
+ assert(fread(pcd->data, 1, info->length, fp) == (size_t)info->length);
+ return pcd;
+}
+
+const char *read_param_metadata(lua_State *L, FILE *fp, const char *fn) {
+#define LINEBUFF_SIZE 1024
+ static char buff[7 + LINEBUFF_SIZE] = "return ";
+ CHECK_FORMAT(fgets(buff + 7, LINEBUFF_SIZE, fp), buff + 7, fn);
+ fprintf(stderr, "metadata: %s\n", buff);
+ return buff;
+}
+
+int nerv_param_file_new(lua_State *L) {
+
+ const char *fn = luaL_checkstring(L, 1);
+ FILE *fp = fopen(fn, "r");
+ int i, status;
+ size_t param_len;
+ off_t offset;
+ ParamFileHandle *lfp;
+
+ if (!fp) nerv_error(L, "Error while opening param file: %s", fn);
+ offset = ftello(fp);
+ lua_newtable(L);
+ lua_newtable(L);
+ fprintf(stderr, "%d\n", (int)offset);
+ for (i = 0;; offset += param_len, i++)
+ {
+ ParamChunkInfo *pci;
+ fprintf(stderr, "reading param chunk %d from %d\n", i, (int)offset);
+ /* skip to the begining of param chunk i */
+ CHECK_FORMAT(fseeko(fp, offset, SEEK_SET), 0, fn);
+ /* read header */
+ param_len = read_param_header_plain(fp, &status);
+ if (status == END_OF_FILE) break;
+ else if (status == INVALID_FORMAT)
+ INVALID_FORMAT_ERROR(fn);
+ /* read metadata */
+ luaL_loadstring(L, read_param_metadata(L, fp, fn));
+ CHECK_FORMAT(lua_pcall(L, 0, 1, 0), 0, fn);
+ CHECK_FORMAT(lua_istable(L, -1), 1, fn);
+ /* chunk info */
+ pci = (ParamChunkInfo *)malloc(sizeof(ParamChunkInfo));
+ pci->offset = ftello(fp);
+ pci->length = param_len - (pci->offset - offset);
+ fprintf(stderr, "%d + %d (skip %d)\n", (int)pci->offset,
+ (int)pci->length, param_len);
+ luaT_pushudata(L, pci, nerv_param_chunk_info_tname);
+ lua_setfield(L, -2, "chunk");
+ lua_rawseti(L, -2, i);
+ }
+ lua_setfield(L, -2, "metadata");
+ lfp = (ParamFileHandle *)malloc(sizeof(ParamFileHandle));
+ lfp->fp = fp;
+ luaT_pushudata(L, lfp, nerv_param_file_handle_tname);
+ lua_setfield(L, -2, "handle");
+ luaT_pushmetatable(L, nerv_param_file_tname);
+ lua_setmetatable(L, -2);
+ return 1;
+}
+
+int nerv_param_file_get_chunkdata(lua_State *L) {
+ ParamFileHandle *pfh;
+ ParamChunkInfo *pci;
+ int k = luaL_checkinteger(L, 2);
+
+ lua_getfield(L, 1, "handle");
+ pfh = luaT_checkudata(L, -1, nerv_param_file_handle_tname);
+ lua_pop(L, 1); /* pop handle */
+
+ lua_getfield(L, 1, "metadata");
+ /* now stack: self, k, metadata */
+ lua_rawgeti(L, -1, k);
+ /* now stack: self, k, metadata, ith{} */
+ lua_getfield(L, -1, "chunk");
+ pci = luaT_checkudata(L, -1, nerv_param_chunk_info_tname);
+
+ luaT_pushudata(L, get_param_chunk_data(pfh->fp, pci),
+ nerv_param_chunk_data_tname);
+ lua_setfield(L, -2, "data");
+ return 1;
+}
+
+int nerv_param_file_handle_destroy(lua_State *L) {
+ ParamFileHandle *pfh = luaT_checkudata(L, 1,
+ nerv_param_file_handle_tname);
+ fclose(pfh->fp);
+ free(pfh);
+ return 0;
+}
+
+static int nerv_param_chunk_destroy(lua_State *L) {
+ ParamChunkInfo *pci = luaT_checkudata(L, 1, nerv_param_chunk_info_tname);
+ free(pci);
+ return 0;
+}
+
+static int nerv_param_chunk_data_destroy(lua_State *L) {
+ ParamChunkData *pcd = luaT_checkudata(L, 1, nerv_param_chunk_data_tname);
+ fclose(pcd->fp);
+ free(pcd->data);
+ free(pcd);
+ return 0;
+}
+
+static const luaL_Reg nerv_param_file_methods[] = {
+ {"get_chunkdata", nerv_param_file_get_chunkdata},
+ {NULL, NULL}
+};
+
+void nerv_param_file_init(lua_State *L) {
+ luaT_newmetatable(L, nerv_param_file_tname, NULL,
+ nerv_param_file_new,
+ NULL, NULL);
+ luaL_register(L, NULL, nerv_param_file_methods);
+ lua_pop(L, 1);
+ luaT_newmetatable(L, nerv_param_file_handle_tname, NULL,
+ NULL, nerv_param_file_handle_destroy, NULL);
+ luaT_newmetatable(L, nerv_param_chunk_info_tname, NULL,
+ NULL, nerv_param_chunk_destroy, NULL);
+ luaT_newmetatable(L, nerv_param_chunk_data_tname, NULL,
+ NULL, nerv_param_chunk_data_destroy, NULL);
+}
+
diff --git a/io/param.h b/io/param.h
new file mode 100644
index 0000000..e5841b8
--- /dev/null
+++ b/io/param.h
@@ -0,0 +1,22 @@
+#ifndef NERV_LAYER_FILE_H
+#define NERV_LAYER_FILE_H
+
+extern const char *nerv_param_file_tname;
+extern const char *nerv_param_file_handle_tname;
+extern const char *nerv_param_chunk_info_tname;
+extern const char *nerv_param_chunk_data_tname;
+
+typedef struct ParamFileHandle {
+ FILE *fp;
+} ParamFileHandle;
+
+typedef struct ParamChunkInfo {
+ off_t offset, length;
+} ParamChunkInfo;
+
+typedef struct ParamChunkData {
+ FILE *fp;
+ char *data;
+} ParamChunkData;
+
+#endif
diff --git a/matrix/generic/mmatrix.c b/matrix/generic/mmatrix.c
index c301e23..6edac69 100644
--- a/matrix/generic/mmatrix.c
+++ b/matrix/generic/mmatrix.c
@@ -17,6 +17,9 @@ static void host_matrix_(alloc)(MATRIX_ELEM **dptr, size_t *stride,
*stride = width;
}
+static const luaL_Reg nerv_matrix_(extra_methods)[] = {
+};
+
int nerv_matrix_(get_elem)(lua_State *L) {
Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
int idx = luaL_checkinteger(L, 2);
diff --git a/nerv.c b/nerv.c
index 3dc9895..55ae5b6 100644
--- a/nerv.c
+++ b/nerv.c
@@ -4,11 +4,13 @@
extern void nerv_point_init(lua_State *L);
extern void nerv_matrix_init(lua_State *L);
+extern void nerv_param_init(lua_State *L);
int luaopen_libnerv(lua_State *L) {
lua_newtable(L);
lua_setfield(L, LUA_GLOBALSINDEX, "nerv");
nerv_point_init(L);
nerv_matrix_init(L);
+ nerv_param_init(L);
return 1;
}
diff --git a/nerv.lua b/nerv.lua
index de2e701..ccff0a0 100644
--- a/nerv.lua
+++ b/nerv.lua
@@ -1,3 +1,12 @@
require 'libnerv'
require 'matrix.init'
-nerv.class = require 'class'
+nerv.class = require 'pl.class'
+nerv.utils = require 'pl.utils'
+
+function nerv.error(fmt, ...)
+ error(nerv.utils.printf("Nerv internal error: " .. fmt, ...))
+end
+
+function nerv.error_method_not_implement()
+ nerv.error("method not implemented");
+end