diff options
Diffstat (limited to 'nerv/nn/param_repo.lua')
-rw-r--r-- | nerv/nn/param_repo.lua | 63 |
1 files changed, 57 insertions, 6 deletions
diff --git a/nerv/nn/param_repo.lua b/nerv/nn/param_repo.lua index 932ed2a..a9eb0bd 100644 --- a/nerv/nn/param_repo.lua +++ b/nerv/nn/param_repo.lua @@ -1,10 +1,22 @@ +--- Implements a concept that stores a collection of parameter groups. + +--- The class for stroing a collection of parameter groups (`nerv.Param`). + local ParamRepo = nerv.class("nerv.ParamRepo") +--- The location constants for `loc_type`. +-- @field ON_DEVICE the storage is on device (GPU RAM) +-- @field ON_HOST the storage is on host (main RAM) + ParamRepo.LOC_TYPES = { ON_DEVICE = {}, ON_HOST = {} } +--- The constructor. +-- @param plist an array of parameters that will be initially in the collection +-- @param loc_type the type of storage location, see `nerv.ParamRepo.LOC_TYPES` + function ParamRepo:__init(plist, loc_type) self.params = {} self.loc_type = loc_type or ParamRepo.LOC_TYPES.ON_HOST @@ -37,6 +49,9 @@ function ParamRepo:__init(plist, loc_type) end end +--- Add a parameter to the collection. +-- @param p the parameter to be added + function ParamRepo:add(p) if self.params[p.id] ~= nil then nerv.error("duplicate params with the same id: %s", p.id) @@ -45,6 +60,9 @@ function ParamRepo:add(p) self.params[p.id] = p end +--- Remove a parameter from the collection. +-- @param pid the id of the parameter to be removed + function ParamRepo:remove(pid) if self.params[pid] == nil then nerv.error("param %s does not exit", pid) @@ -52,7 +70,16 @@ function ParamRepo:remove(pid) self.params[pid] = nil end +--- Merge two or more parameter collecitons. +-- @param repos an array of parameter repos to be merged +-- @param loc_type the type of storage location, see `nerv.ParamRepo.LOC_TYPES` +-- @return the merged parameter collection (repo) + function ParamRepo.merge(repos, loc_type) + +-- TODO: remove redundant `loc_type` and check the consistency of `loc_type` +-- from different merging param repos. + local self = nerv.ParamRepo(nil, loc_type) for i, repo in ipairs(repos) do if not nerv.is_type(repo, "nerv.ParamRepo") then @@ -65,6 +92,12 @@ function ParamRepo.merge(repos, loc_type) return self end +--- Import parameters from a NERV chunk file. +-- @param param_files an array of filenames of the files to be loaded from +-- @param gconf a table describing the computation state and providing +-- with some global settings +-- @param pids optional, an array of identifiers of the parameters to be imported + function ParamRepo:import(param_files, gconf, pids) if type(param_files) ~= "table" then nerv.error("param file table is need") @@ -83,24 +116,36 @@ function ParamRepo:import(param_files, gconf, pids) end end +--- Export the parameter collection to a NERV chunk file. +-- @param param_file the output filename +-- @param pids optional, the identifiers of the parameters to be exported + function ParamRepo:export(param_file, pids) cf = nerv.ChunkFile(param_file, "w") -if pids == nil then - for id, p in pairs(self.params) do - cf:write_chunk(p) - end -else - for i, pid in ipairs(pids) do + if pids == nil then + for id, p in pairs(self.params) do + cf:write_chunk(p) + end + else + for i, pid in ipairs(pids) do cf:write_chunk(self:get_param(pid)) end end cf:close() end +--- Test whether the collection has a parameter. +-- @param pid the identifier to be tested +-- @return true if a parameter with the identifier exists + function ParamRepo:has_param(pid) return self.params[pid] ~= nil end +--- Retrieve the parameter by the identifier. +-- @param pid the identifier of the parameter to be retrieved +-- @return the retrieved parameter + function ParamRepo:get_param(pid) local p = self.params[pid] if p == nil then @@ -109,6 +154,12 @@ function ParamRepo:get_param(pid) return p end +--- Create a copy of the current collection. +-- @param loc_type the storage location of the new copy +-- @param gconf a table describing the computation state and providing +-- with some global settings +-- @param pids optional, an array of identifiers of the parameters to be copied + function ParamRepo:copy(loc_type, gconf, pids) local copier local target = nerv.ParamRepo(nil, loc_type) |