diff options
-rw-r--r-- | io/init.lua | 21 | ||||
-rw-r--r-- | io/param.c | 8 | ||||
-rw-r--r-- | layer/affine.lua | 11 | ||||
-rw-r--r-- | layer/init.lua | 41 | ||||
-rw-r--r-- | matrix/init.lua | 13 | ||||
-rw-r--r-- | nerv.lua | 2 |
6 files changed, 91 insertions, 5 deletions
diff --git a/io/init.lua b/io/init.lua index 1288bd4..2fa38e6 100644 --- a/io/init.lua +++ b/io/init.lua @@ -5,3 +5,24 @@ function nerv.ParamFile:write_chunkdata(metadata, writer) end return self:_write_chunkdata(table.tostring(metadata), writer) end + +function nerv.ParamFile:write_param(param) + local id = param.id + local type = param.__typename + if id == nil then + nerv_error("id of param %s must be specified", type) + end + self:write_chunkdata({id = id, + type = type, + info = param:get_info()}, param) +end + +function nerv.ParamFile:read_param(id) + local metadata = self.metadata[id] + if metadata == nil then + nerv_error("param with id %s does not exist", id) + end + local param = assert(loadstring("return " .. metadata.type .. "(" .. id .. ")"))() + param:set_info(metadata.info) + param:read(self:get_chunkdata(id)) +end @@ -220,12 +220,12 @@ int nerv_param_file_write_chunkdata(lua_State *L) { write_param_metadata(pfh->fp, metadata_str, &status); CHECK_WRITE(status); lua_pushvalue(L, 3); - lua_getfield(L, -1, "save"); + lua_getfield(L, -1, "write"); if (!lua_isfunction(L, -1)) - nerv_error(L, "\"save\" method must be implemented"); + nerv_error(L, "\"write\" method must be implemented"); lua_pushvalue(L, -2); - lua_pushvalue(L, 4); /* pass handle as parameter to save() */ - lua_call(L, 2, 0); /* let the save() to write */ + lua_pushvalue(L, 4); /* pass handle as parameter to write() */ + lua_call(L, 2, 0); /* let the write() to write */ lua_pop(L, 1); size = ftello(pfh->fp) - start; fseeko(pfh->fp, start, SEEK_SET); diff --git a/layer/affine.lua b/layer/affine.lua new file mode 100644 index 0000000..d5c50fc --- /dev/null +++ b/layer/affine.lua @@ -0,0 +1,11 @@ +local LinearTransParam = nerv.class('nerv.LinearTransParam', 'nerv.Param') +local BiasParam = nerv.class('nerv.BiasParam', 'nerv.LinearTransParam') +local AffineLayer = nerv.class('nerv.AffineLayer', 'nerv.Layer') + +function LinearTransParam:read(pcdata) + self.trans = nerv.CuMatrixFloat.new_from_host(nerv.MMatrixFloat.load(pcdata)) +end + +function LinearTransParam:write(pfhandle) + self.trans:new_to_host():save(pfhandle) +end diff --git a/layer/init.lua b/layer/init.lua new file mode 100644 index 0000000..c57a405 --- /dev/null +++ b/layer/init.lua @@ -0,0 +1,41 @@ +-- The following methods must be implemented to let a layer work properly + +local Param = nerv.class('nerv.Param') + +function nerv.Param:__init(id) + self.id = id +end + +function nerv.Param:get_info() + return self.info +end + +function nerv.Param:set_info(info) + self.info = info +end + +function nerv.Param:read(pfhandle) + nerv.error_method_not_implemented() +end + +function nerv.Param:write(pfhandle) + nerv.error_method_not_implemented() +end + +local Layer = nerv.class('nerv.Layer') + +function nerv.Layer:_init(param) + nerv.error_method_not_implemented() +end + +function nerv.Layer:update(bp_err, input, output) + nerv.error_method_not_implemented() +end + +function nerv.Layer:propagate(input, output) + nerv.error_method_not_implemented() +end + +function nerv.Layer:back_propagate(next_bp_err, bp_err, input, output) + nerv.error_method_not_implemented() +end diff --git a/matrix/init.lua b/matrix/init.lua index 8f626dc..08080a9 100644 --- a/matrix/init.lua +++ b/matrix/init.lua @@ -38,3 +38,16 @@ function nerv.CuMatrix:__mul__(b) c:mul(self, b, 'N', 'N') return c end + +function nerv.CuMatrixFloat.new_from_host(mat) + local res = nerv.CuMatrixFloat(mat:nrow(), mat:ncol()) + res:copy_from(mat) + print(res) + return res +end + +function nerv.CuMatrixFloat:new_to_host() + local res = nerv.MMatrixFloat(self:nrow(), self:ncol()) + self:copy_to(res) + return res +end @@ -5,7 +5,7 @@ function nerv.error(fmt, ...) error(nerv.utils.printf("Nerv internal error: " .. fmt .. "\n", ...)) end -function nerv.error_method_not_implement() +function nerv.error_method_not_implemented() nerv.error("method not implemented"); end |