summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--io/init.lua21
-rw-r--r--io/param.c8
-rw-r--r--layer/affine.lua11
-rw-r--r--layer/init.lua41
-rw-r--r--matrix/init.lua13
-rw-r--r--nerv.lua2
6 files changed, 91 insertions, 5 deletions
diff --git a/io/init.lua b/io/init.lua
index 1288bd4..2fa38e6 100644
--- a/io/init.lua
+++ b/io/init.lua
@@ -5,3 +5,24 @@ function nerv.ParamFile:write_chunkdata(metadata, writer)
end
return self:_write_chunkdata(table.tostring(metadata), writer)
end
+
+function nerv.ParamFile:write_param(param)
+ local id = param.id
+ local type = param.__typename
+ if id == nil then
+ nerv_error("id of param %s must be specified", type)
+ end
+ self:write_chunkdata({id = id,
+ type = type,
+ info = param:get_info()}, param)
+end
+
+function nerv.ParamFile:read_param(id)
+ local metadata = self.metadata[id]
+ if metadata == nil then
+ nerv_error("param with id %s does not exist", id)
+ end
+ local param = assert(loadstring("return " .. metadata.type .. "(" .. id .. ")"))()
+ param:set_info(metadata.info)
+ param:read(self:get_chunkdata(id))
+end
diff --git a/io/param.c b/io/param.c
index 627c815..477df28 100644
--- a/io/param.c
+++ b/io/param.c
@@ -220,12 +220,12 @@ int nerv_param_file_write_chunkdata(lua_State *L) {
write_param_metadata(pfh->fp, metadata_str, &status);
CHECK_WRITE(status);
lua_pushvalue(L, 3);
- lua_getfield(L, -1, "save");
+ lua_getfield(L, -1, "write");
if (!lua_isfunction(L, -1))
- nerv_error(L, "\"save\" method must be implemented");
+ nerv_error(L, "\"write\" method must be implemented");
lua_pushvalue(L, -2);
- lua_pushvalue(L, 4); /* pass handle as parameter to save() */
- lua_call(L, 2, 0); /* let the save() to write */
+ lua_pushvalue(L, 4); /* pass handle as parameter to write() */
+ lua_call(L, 2, 0); /* let the write() to write */
lua_pop(L, 1);
size = ftello(pfh->fp) - start;
fseeko(pfh->fp, start, SEEK_SET);
diff --git a/layer/affine.lua b/layer/affine.lua
new file mode 100644
index 0000000..d5c50fc
--- /dev/null
+++ b/layer/affine.lua
@@ -0,0 +1,11 @@
+local LinearTransParam = nerv.class('nerv.LinearTransParam', 'nerv.Param')
+local BiasParam = nerv.class('nerv.BiasParam', 'nerv.LinearTransParam')
+local AffineLayer = nerv.class('nerv.AffineLayer', 'nerv.Layer')
+
+function LinearTransParam:read(pcdata)
+ self.trans = nerv.CuMatrixFloat.new_from_host(nerv.MMatrixFloat.load(pcdata))
+end
+
+function LinearTransParam:write(pfhandle)
+ self.trans:new_to_host():save(pfhandle)
+end
diff --git a/layer/init.lua b/layer/init.lua
new file mode 100644
index 0000000..c57a405
--- /dev/null
+++ b/layer/init.lua
@@ -0,0 +1,41 @@
+-- The following methods must be implemented to let a layer work properly
+
+local Param = nerv.class('nerv.Param')
+
+function nerv.Param:__init(id)
+ self.id = id
+end
+
+function nerv.Param:get_info()
+ return self.info
+end
+
+function nerv.Param:set_info(info)
+ self.info = info
+end
+
+function nerv.Param:read(pfhandle)
+ nerv.error_method_not_implemented()
+end
+
+function nerv.Param:write(pfhandle)
+ nerv.error_method_not_implemented()
+end
+
+local Layer = nerv.class('nerv.Layer')
+
+function nerv.Layer:_init(param)
+ nerv.error_method_not_implemented()
+end
+
+function nerv.Layer:update(bp_err, input, output)
+ nerv.error_method_not_implemented()
+end
+
+function nerv.Layer:propagate(input, output)
+ nerv.error_method_not_implemented()
+end
+
+function nerv.Layer:back_propagate(next_bp_err, bp_err, input, output)
+ nerv.error_method_not_implemented()
+end
diff --git a/matrix/init.lua b/matrix/init.lua
index 8f626dc..08080a9 100644
--- a/matrix/init.lua
+++ b/matrix/init.lua
@@ -38,3 +38,16 @@ function nerv.CuMatrix:__mul__(b)
c:mul(self, b, 'N', 'N')
return c
end
+
+function nerv.CuMatrixFloat.new_from_host(mat)
+ local res = nerv.CuMatrixFloat(mat:nrow(), mat:ncol())
+ res:copy_from(mat)
+ print(res)
+ return res
+end
+
+function nerv.CuMatrixFloat:new_to_host()
+ local res = nerv.MMatrixFloat(self:nrow(), self:ncol())
+ self:copy_to(res)
+ return res
+end
diff --git a/nerv.lua b/nerv.lua
index 0b8943e..00042a7 100644
--- a/nerv.lua
+++ b/nerv.lua
@@ -5,7 +5,7 @@ function nerv.error(fmt, ...)
error(nerv.utils.printf("Nerv internal error: " .. fmt .. "\n", ...))
end
-function nerv.error_method_not_implement()
+function nerv.error_method_not_implemented()
nerv.error("method not implemented");
end