From 81bf2d653902860c5d28ccade19ac6e1fd56acaf Mon Sep 17 00:00:00 2001 From: Determinant Date: Tue, 26 May 2015 14:06:52 +0800 Subject: add layer and param --- io/init.lua | 21 +++++++++++++++++++++ io/param.c | 8 ++++---- layer/affine.lua | 11 +++++++++++ layer/init.lua | 41 +++++++++++++++++++++++++++++++++++++++++ matrix/init.lua | 13 +++++++++++++ nerv.lua | 2 +- 6 files changed, 91 insertions(+), 5 deletions(-) create mode 100644 layer/affine.lua create mode 100644 layer/init.lua diff --git a/io/init.lua b/io/init.lua index 1288bd4..2fa38e6 100644 --- a/io/init.lua +++ b/io/init.lua @@ -5,3 +5,24 @@ function nerv.ParamFile:write_chunkdata(metadata, writer) end return self:_write_chunkdata(table.tostring(metadata), writer) end + +function nerv.ParamFile:write_param(param) + local id = param.id + local type = param.__typename + if id == nil then + nerv_error("id of param %s must be specified", type) + end + self:write_chunkdata({id = id, + type = type, + info = param:get_info()}, param) +end + +function nerv.ParamFile:read_param(id) + local metadata = self.metadata[id] + if metadata == nil then + nerv_error("param with id %s does not exist", id) + end + local param = assert(loadstring("return " .. metadata.type .. "(" .. id .. ")"))() + param:set_info(metadata.info) + param:read(self:get_chunkdata(id)) +end diff --git a/io/param.c b/io/param.c index 627c815..477df28 100644 --- a/io/param.c +++ b/io/param.c @@ -220,12 +220,12 @@ int nerv_param_file_write_chunkdata(lua_State *L) { write_param_metadata(pfh->fp, metadata_str, &status); CHECK_WRITE(status); lua_pushvalue(L, 3); - lua_getfield(L, -1, "save"); + lua_getfield(L, -1, "write"); if (!lua_isfunction(L, -1)) - nerv_error(L, "\"save\" method must be implemented"); + nerv_error(L, "\"write\" method must be implemented"); lua_pushvalue(L, -2); - lua_pushvalue(L, 4); /* pass handle as parameter to save() */ - lua_call(L, 2, 0); /* let the save() to write */ + lua_pushvalue(L, 4); /* pass handle as parameter to write() */ + lua_call(L, 2, 0); /* let the write() to write */ lua_pop(L, 1); size = ftello(pfh->fp) - start; fseeko(pfh->fp, start, SEEK_SET); diff --git a/layer/affine.lua b/layer/affine.lua new file mode 100644 index 0000000..d5c50fc --- /dev/null +++ b/layer/affine.lua @@ -0,0 +1,11 @@ +local LinearTransParam = nerv.class('nerv.LinearTransParam', 'nerv.Param') +local BiasParam = nerv.class('nerv.BiasParam', 'nerv.LinearTransParam') +local AffineLayer = nerv.class('nerv.AffineLayer', 'nerv.Layer') + +function LinearTransParam:read(pcdata) + self.trans = nerv.CuMatrixFloat.new_from_host(nerv.MMatrixFloat.load(pcdata)) +end + +function LinearTransParam:write(pfhandle) + self.trans:new_to_host():save(pfhandle) +end diff --git a/layer/init.lua b/layer/init.lua new file mode 100644 index 0000000..c57a405 --- /dev/null +++ b/layer/init.lua @@ -0,0 +1,41 @@ +-- The following methods must be implemented to let a layer work properly + +local Param = nerv.class('nerv.Param') + +function nerv.Param:__init(id) + self.id = id +end + +function nerv.Param:get_info() + return self.info +end + +function nerv.Param:set_info(info) + self.info = info +end + +function nerv.Param:read(pfhandle) + nerv.error_method_not_implemented() +end + +function nerv.Param:write(pfhandle) + nerv.error_method_not_implemented() +end + +local Layer = nerv.class('nerv.Layer') + +function nerv.Layer:_init(param) + nerv.error_method_not_implemented() +end + +function nerv.Layer:update(bp_err, input, output) + nerv.error_method_not_implemented() +end + +function nerv.Layer:propagate(input, output) + nerv.error_method_not_implemented() +end + +function nerv.Layer:back_propagate(next_bp_err, bp_err, input, output) + nerv.error_method_not_implemented() +end diff --git a/matrix/init.lua b/matrix/init.lua index 8f626dc..08080a9 100644 --- a/matrix/init.lua +++ b/matrix/init.lua @@ -38,3 +38,16 @@ function nerv.CuMatrix:__mul__(b) c:mul(self, b, 'N', 'N') return c end + +function nerv.CuMatrixFloat.new_from_host(mat) + local res = nerv.CuMatrixFloat(mat:nrow(), mat:ncol()) + res:copy_from(mat) + print(res) + return res +end + +function nerv.CuMatrixFloat:new_to_host() + local res = nerv.MMatrixFloat(self:nrow(), self:ncol()) + self:copy_to(res) + return res +end diff --git a/nerv.lua b/nerv.lua index 0b8943e..00042a7 100644 --- a/nerv.lua +++ b/nerv.lua @@ -5,7 +5,7 @@ function nerv.error(fmt, ...) error(nerv.utils.printf("Nerv internal error: " .. fmt .. "\n", ...)) end -function nerv.error_method_not_implement() +function nerv.error_method_not_implemented() nerv.error("method not implemented"); end -- cgit v1.2.3-70-g09d2