summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Makefile3
-rw-r--r--examples/oop_example.c26
-rw-r--r--examples/oop_example.lua15
-rw-r--r--io/param.c40
-rw-r--r--nerv.lua6
5 files changed, 45 insertions, 45 deletions
diff --git a/Makefile b/Makefile
index c562e63..cadbc77 100644
--- a/Makefile
+++ b/Makefile
@@ -4,7 +4,7 @@ OBJS := nerv.o luaT.o common.o \
io/init.o io/param.o \
examples/oop_example.o
LIBS := libnerv.so
-LUA_LIBS := matrix/init.lua io/init.lua nerv.lua pl/utils.lua pl/compat.lua
+LUA_LIBS := matrix/init.lua io/init.lua nerv.lua pl/utils.lua pl/compat.lua layer/init.lua layer/affine.lua
INCLUDE := -I build/luajit-2.0/include/luajit-2.0/ -DLUA_USE_APICHECK
CUDA_BASE := /usr/local/cuda-6.5
CUDA_INCLUDE := -I $(CUDA_BASE)/include/
@@ -30,6 +30,7 @@ $(OBJ_DIR):
-mkdir -p $(OBJ_DIR)/io
-mkdir -p $(LUA_DIR)/io
-mkdir -p $(LUA_DIR)/pl
+ -mkdir -p $(LUA_DIR)/layer
-mkdir -p $(OBJ_DIR)/examples
$(LUA_DIR):
-mkdir -p $(LUA_DIR)
diff --git a/examples/oop_example.c b/examples/oop_example.c
index 859bd03..59dfc5a 100644
--- a/examples/oop_example.c
+++ b/examples/oop_example.c
@@ -46,22 +46,9 @@ int point_new(lua_State *L) {
return 1;
}
-int point___init(lua_State *L) {
- /* The difference between this function and `_new` function is that this
- * one is called by subclass of Point implemented in Lua, although it
- * basically does the same thing as `_new`. Also, it can read the empty
- * object (table) from the stack. (In this example, the table is ignored.) */
- Point *self = (Point *)malloc(sizeof(Point));
- point_new_(self, luaL_checknumber(L, 2), luaL_checknumber(L, 3));
- luaT_pushudata(L, self, point_tname);
- fprintf(stderr, "[example] A subclass has invoked `__init`\n");
- return 1;
-}
-
static const luaL_Reg point_methods[] = {
{"set_x", point_set_x},
{"set_y", point_set_y},
- {"__init", point___init},
{"norm", point_norm},
{NULL, NULL}
};
@@ -84,21 +71,8 @@ int better_point_new(lua_State *L) {
return 1;
}
-int better_point___init(lua_State *L) {
- /* The difference between this function and `_new` function is that this
- * one is called by subclass of Point implemented in Lua, although it
- * basically does the same thing as `_new`. Also, it can read the empty
- * object (table) from the stack. (In this example, the table is ignored.) */
- Point *self = (Point *)malloc(sizeof(Point));
- point_new_(self, luaL_checknumber(L, 2), luaL_checknumber(L, 3));
- luaT_pushudata(L, self, better_point_tname);
- fprintf(stderr, "[example] A subclass has invoked `__init`\n");
- return 1;
-}
-
static const luaL_Reg better_point_methods[] = {
{"norm", better_point_norm},
- {"__init", better_point___init},
{NULL, NULL}
};
diff --git a/examples/oop_example.lua b/examples/oop_example.lua
index 712387b..b753288 100644
--- a/examples/oop_example.lua
+++ b/examples/oop_example.lua
@@ -5,11 +5,12 @@ p:set_x(1.0)
p:set_y(2.0)
print(p:norm()) -- get 2-norm of the Point
-p = nerv.BetterPoint(1, 2)
-print(p)
-print(p:norm()) --get 1-norm of the Point
+bp = nerv.BetterPoint(1, 2)
+-- use methods from base class
+bp:set_x(1.0)
+bp:set_y(2.0)
+print(bp)
+print(bp:norm()) --get 1-norm of the Point
--- create a subclass using lua
-local EvenBetterPoint = nerv.class('nerv.EvenBetterPoint', 'nerv.BetterPoint')
-bp = nerv.EvenBetterPoint(1, 2)
-print(p:norm())
+print(p.__typename)
+print(bp.__typename)
diff --git a/io/param.c b/io/param.c
index e10944f..627c815 100644
--- a/io/param.c
+++ b/io/param.c
@@ -141,6 +141,7 @@ int nerv_param_file_open_read(lua_State *L, const char *fn) {
luaL_loadstring(L, read_param_metadata(L, fp, fn));
CHECK_FORMAT(lua_pcall(L, 0, 1, 0), 0, fn);
CHECK_FORMAT(lua_istable(L, -1), 1, fn);
+ /* stack: obj_table, metadata */
/* chunk info */
pci = (ParamChunkInfo *)malloc(sizeof(ParamChunkInfo));
pci->offset = ftello(fp);
@@ -149,7 +150,25 @@ int nerv_param_file_open_read(lua_State *L, const char *fn) {
(int)pci->length, param_len);
luaT_pushudata(L, pci, nerv_param_chunk_info_tname);
lua_setfield(L, -2, "chunk");
- lua_rawseti(L, -2, i);
+ /* stack: obj_table, metadata */
+ /* get id */
+ lua_getfield(L, -1, "id");
+ /* stack: obj_table, metadata, id */
+ if (!lua_isstring(L, -1))
+ nerv_error(L, "id field in metadata must be a string");
+ lua_pushvalue(L, -1);
+ /* stack: obj_table, metadata, id, id */
+ lua_gettable(L, -4);
+ /* stack: obj_table, metadata, id, obj[id] */
+ if (!lua_isnil(L, -1))
+ nerv_error(L, "conflicting id");
+ lua_pop(L, 1);
+ /* stack: obj_table, metadata, id */
+ lua_pushvalue(L, -2);
+ /* stack: obj_table, metadata, id, metadata */
+ lua_settable(L, -4);
+ /* stack: obj_table, metadata */
+ lua_pop(L, 1);
}
lua_setfield(L, -2, "metadata");
lfp = (ParamFileHandle *)malloc(sizeof(ParamFileHandle));
@@ -200,12 +219,14 @@ int nerv_param_file_write_chunkdata(lua_State *L) {
CHECK_WRITE(status);
write_param_metadata(pfh->fp, metadata_str, &status);
CHECK_WRITE(status);
- lua_getfield(L, 3, "save");
- if (lua_type(L, -1) != LUA_TFUNCTION)
- nerv_error(L, "\"save\" method must be implemented");
lua_pushvalue(L, 3);
- lua_pushvalue(L, -3); /* pass handle as parameter to save() */
+ lua_getfield(L, -1, "save");
+ if (!lua_isfunction(L, -1))
+ nerv_error(L, "\"save\" method must be implemented");
+ lua_pushvalue(L, -2);
+ lua_pushvalue(L, 4); /* pass handle as parameter to save() */
lua_call(L, 2, 0); /* let the save() to write */
+ lua_pop(L, 1);
size = ftello(pfh->fp) - start;
fseeko(pfh->fp, start, SEEK_SET);
/* write the calced size */
@@ -218,16 +239,17 @@ int nerv_param_file_write_chunkdata(lua_State *L) {
int nerv_param_file_get_chunkdata(lua_State *L) {
ParamFileHandle *pfh;
ParamChunkInfo *pci;
- int k = luaL_checkinteger(L, 2);
+ const char *id = luaL_checkstring(L, 2);
lua_getfield(L, 1, "handle");
pfh = luaT_checkudata(L, -1, nerv_param_file_handle_tname);
lua_pop(L, 1); /* pop handle */
-
lua_getfield(L, 1, "metadata");
/* now stack: self, k, metadata */
- lua_rawgeti(L, -1, k);
- /* now stack: self, k, metadata, ith{} */
+ lua_getfield(L, -1, id);
+ /* now stack: self, k, metadata, kth{} */
+ if (lua_isnil(L, -1)) /* no chunck with the id */
+ return 0;
lua_getfield(L, -1, "chunk");
pci = luaT_checkudata(L, -1, nerv_param_chunk_info_tname);
diff --git a/nerv.lua b/nerv.lua
index d7f7e91..0b8943e 100644
--- a/nerv.lua
+++ b/nerv.lua
@@ -1,6 +1,4 @@
require 'libnerv'
-require 'matrix.init'
-require 'io.init'
nerv.utils = require 'pl.utils'
function nerv.error(fmt, ...)
@@ -72,3 +70,7 @@ function table.tostring(tbl)
end
return "{" .. table.concat(result, ",") .. "}"
end
+
+require 'matrix.init'
+require 'io.init'
+require 'layer.init'