aboutsummaryrefslogtreecommitdiff
path: root/nerv/nn/param_repo.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/nn/param_repo.lua')
-rw-r--r--nerv/nn/param_repo.lua63
1 files changed, 57 insertions, 6 deletions
diff --git a/nerv/nn/param_repo.lua b/nerv/nn/param_repo.lua
index 932ed2a..a9eb0bd 100644
--- a/nerv/nn/param_repo.lua
+++ b/nerv/nn/param_repo.lua
@@ -1,10 +1,22 @@
+--- Implements a concept that stores a collection of parameter groups.
+
+--- The class for stroing a collection of parameter groups (`nerv.Param`).
+
local ParamRepo = nerv.class("nerv.ParamRepo")
+--- The location constants for `loc_type`.
+-- @field ON_DEVICE the storage is on device (GPU RAM)
+-- @field ON_HOST the storage is on host (main RAM)
+
ParamRepo.LOC_TYPES = {
ON_DEVICE = {},
ON_HOST = {}
}
+--- The constructor.
+-- @param plist an array of parameters that will be initially in the collection
+-- @param loc_type the type of storage location, see `nerv.ParamRepo.LOC_TYPES`
+
function ParamRepo:__init(plist, loc_type)
self.params = {}
self.loc_type = loc_type or ParamRepo.LOC_TYPES.ON_HOST
@@ -37,6 +49,9 @@ function ParamRepo:__init(plist, loc_type)
end
end
+--- Add a parameter to the collection.
+-- @param p the parameter to be added
+
function ParamRepo:add(p)
if self.params[p.id] ~= nil then
nerv.error("duplicate params with the same id: %s", p.id)
@@ -45,6 +60,9 @@ function ParamRepo:add(p)
self.params[p.id] = p
end
+--- Remove a parameter from the collection.
+-- @param pid the id of the parameter to be removed
+
function ParamRepo:remove(pid)
if self.params[pid] == nil then
nerv.error("param %s does not exit", pid)
@@ -52,7 +70,16 @@ function ParamRepo:remove(pid)
self.params[pid] = nil
end
+--- Merge two or more parameter collecitons.
+-- @param repos an array of parameter repos to be merged
+-- @param loc_type the type of storage location, see `nerv.ParamRepo.LOC_TYPES`
+-- @return the merged parameter collection (repo)
+
function ParamRepo.merge(repos, loc_type)
+
+-- TODO: remove redundant `loc_type` and check the consistency of `loc_type`
+-- from different merging param repos.
+
local self = nerv.ParamRepo(nil, loc_type)
for i, repo in ipairs(repos) do
if not nerv.is_type(repo, "nerv.ParamRepo") then
@@ -65,6 +92,12 @@ function ParamRepo.merge(repos, loc_type)
return self
end
+--- Import parameters from a NERV chunk file.
+-- @param param_files an array of filenames of the files to be loaded from
+-- @param gconf a table describing the computation state and providing
+-- with some global settings
+-- @param pids optional, an array of identifiers of the parameters to be imported
+
function ParamRepo:import(param_files, gconf, pids)
if type(param_files) ~= "table" then
nerv.error("param file table is need")
@@ -83,24 +116,36 @@ function ParamRepo:import(param_files, gconf, pids)
end
end
+--- Export the parameter collection to a NERV chunk file.
+-- @param param_file the output filename
+-- @param pids optional, the identifiers of the parameters to be exported
+
function ParamRepo:export(param_file, pids)
cf = nerv.ChunkFile(param_file, "w")
-if pids == nil then
- for id, p in pairs(self.params) do
- cf:write_chunk(p)
- end
-else
- for i, pid in ipairs(pids) do
+ if pids == nil then
+ for id, p in pairs(self.params) do
+ cf:write_chunk(p)
+ end
+ else
+ for i, pid in ipairs(pids) do
cf:write_chunk(self:get_param(pid))
end
end
cf:close()
end
+--- Test whether the collection has a parameter.
+-- @param pid the identifier to be tested
+-- @return true if a parameter with the identifier exists
+
function ParamRepo:has_param(pid)
return self.params[pid] ~= nil
end
+--- Retrieve the parameter by the identifier.
+-- @param pid the identifier of the parameter to be retrieved
+-- @return the retrieved parameter
+
function ParamRepo:get_param(pid)
local p = self.params[pid]
if p == nil then
@@ -109,6 +154,12 @@ function ParamRepo:get_param(pid)
return p
end
+--- Create a copy of the current collection.
+-- @param loc_type the storage location of the new copy
+-- @param gconf a table describing the computation state and providing
+-- with some global settings
+-- @param pids optional, an array of identifiers of the parameters to be copied
+
function ParamRepo:copy(loc_type, gconf, pids)
local copier
local target = nerv.ParamRepo(nil, loc_type)