--- Implements a concept that stores a collection of parameter groups.

--- The class for storing a collection of parameter groups (`nerv.Param`),
-- indexed by each parameter's `id`.
local ParamRepo = nerv.class("nerv.ParamRepo")

--- The location constants for `loc_type`.
-- The values are distinct empty tables, so they only compare equal to
-- themselves (identity tokens).
-- @field ON_DEVICE the storage is on device (GPU RAM)
-- @field ON_HOST the storage is on host (main RAM)
ParamRepo.LOC_TYPES = {
    ON_DEVICE = {},
    ON_HOST = {}
}

--- The constructor.
-- @param plist an array of parameters that will be initially in the collection;
--        ids must be unique (same rule as `ParamRepo:add`)
-- @param loc_type the type of storage location, see `nerv.ParamRepo.LOC_TYPES`;
--        defaults to `ON_HOST`
function ParamRepo:__init(plist, loc_type)
    self.params = {}
    self.loc_type = loc_type or ParamRepo.LOC_TYPES.ON_HOST

    -- Build the validator handed to `Param:check`: it raises an error for
    -- any matrix whose type is not `tname`.
    local function make_checker(tname)
        return function (mat)
            if not nerv.is_type(mat, tname) then
                nerv.error("unexpected param type in repo specification")
            end
        end
    end

    -- Build a function that allocates a `mat_type` matrix with the same
    -- shape as its argument and fills it by calling the argument's
    -- `copy_method` (e.g. 'copy_toh' / 'copy_tod'); used by `ParamRepo:copy`.
    self.make_copier = function (mat_type, copy_method)
        return function (mat)
            local target = mat_type(mat:nrow(), mat:ncol())
            mat[copy_method](mat, target)
            return target
        end
    end

    -- Host repos hold `nerv.MMatrix`; any other location holds `nerv.CuMatrix`.
    if self.loc_type == ParamRepo.LOC_TYPES.ON_HOST then
        self.checker = make_checker("nerv.MMatrix")
    else
        self.checker = make_checker("nerv.CuMatrix")
    end

    if plist ~= nil then
        for _, p in ipairs(plist) do
            -- Go through `add` so duplicate ids are rejected here as well,
            -- instead of silently overwriting an earlier entry.
            self:add(p)
        end
    end
end

--- Add a parameter to the collection.
-- Raises an error if a parameter with the same id is already present, or if
-- the parameter's storage does not match this repo's `loc_type`.
-- @param p the parameter to be added
function ParamRepo:add(p)
    if self.params[p.id] ~= nil then
        nerv.error("duplicate params with the same id: %s", p.id)
    end
    p:check(self.checker)
    self.params[p.id] = p
end

--- Remove a parameter from the collection.
-- Raises an error if no parameter with the given id exists.
-- @param pid the id of the parameter to be removed
function ParamRepo:remove(pid)
    if self.params[pid] == nil then
        nerv.error("param %s does not exist", pid)
    end
    self.params[pid] = nil
end

--- Merge two or more parameter collections.
-- Parameters are added by reference (not copied); duplicate ids across the
-- input repos raise an error.
-- @param repos an array of parameter repos to be merged
-- @param loc_type the type of storage location, see `nerv.ParamRepo.LOC_TYPES`
-- @return the merged parameter collection (repo)
function ParamRepo.merge(repos, loc_type)
    -- TODO: remove redundant `loc_type` and check the consistency of `loc_type`
    -- from different merging param repos.
    local self = nerv.ParamRepo(nil, loc_type)
    for _, repo in ipairs(repos) do
        if not nerv.is_type(repo, "nerv.ParamRepo") then
            -- tostring(): presumably nerv.error formats via string.format,
            -- whose `%s` rejects table arguments on Lua 5.1.
            nerv.error("nerv.ParamRepo objects expected, got %s", tostring(repo))
        end
        for _, p in pairs(repo.params) do
            self:add(p)
        end
    end
    return self
end

--- Import parameters from NERV chunk files.
-- Every chunk read must be a `nerv.Param`; each is added via `ParamRepo:add`,
-- so duplicate ids raise an error.
-- @param param_files an array of filenames of the files to be loaded from
-- @param gconf a table describing the computation state and providing
-- with some global settings
-- @param pids optional, the identifiers of the parameters to be imported:
--        either an array of ids (as for `export`/`copy`) or a set-style
--        table keyed by id; when nil, every chunk is imported
function ParamRepo:import(param_files, gconf, pids)
    if type(param_files) ~= "table" then
        nerv.error("param file table is needed")
    end

    -- Normalize `pids` into a lookup set. The original implementation looked
    -- ids up as keys (`pids[cid]`) although the documented form is an array,
    -- so accept both shapes.
    local wanted
    if pids ~= nil then
        wanted = {}
        for k, v in pairs(pids) do
            wanted[k] = true
            if type(k) == "number" then
                wanted[v] = true
            end
        end
    end

    for i = 1, #param_files do
        local pf = nerv.ChunkFile(param_files[i], "r")
        for cid in pairs(pf.metadata) do
            if wanted == nil or wanted[cid] then
                local p = pf:read_chunk(cid, gconf)
                if not nerv.is_type(p, "nerv.Param") then
                    nerv.error("param chunk is expected")
                end
                self:add(p)
            end
        end
        -- release the file handle, mirroring `export`
        pf:close()
    end
end

--- Export the parameter collection to a NERV chunk file.
-- @param param_file the output filename
-- @param pids optional, an array of identifiers of the parameters to be
--        exported; when nil, every parameter is exported. Unknown ids raise
--        an error (via `get_param`).
function ParamRepo:export(param_file, pids)
    local cf = nerv.ChunkFile(param_file, "w")
    if pids == nil then
        for _, p in pairs(self.params) do
            cf:write_chunk(p)
        end
    else
        for _, pid in ipairs(pids) do
            cf:write_chunk(self:get_param(pid))
        end
    end
    cf:close()
end

--- Test whether the collection has a parameter.
-- @param pid the identifier to be tested
-- @return true if a parameter with the identifier exists
function ParamRepo:has_param(pid)
    return self.params[pid] ~= nil
end

--- Retrieve the parameter by the identifier.
-- Raises an error if no parameter with that id exists.
-- @param pid the identifier of the parameter to be retrieved
-- @return the retrieved parameter
function ParamRepo:get_param(pid)
    local p = self.params[pid]
    if p == nil then
        nerv.error("param with id %s not found", pid)
    end
    return p
end

--- Create a copy of the current collection.
-- @param loc_type the storage location of the new copy; defaults to this
--        repo's `loc_type`
-- @param gconf a table describing the computation state and providing
-- with some global settings (`mmat_type` is read for host copies,
-- `cumat_type` for device copies)
-- @param pids optional, an array of identifiers of the parameters to be copied;
--        when nil, every parameter is copied
-- @return a new `nerv.ParamRepo` holding the copied parameters
function ParamRepo:copy(loc_type, gconf, pids)
    -- Resolve the default location *before* constructing the target, so the
    -- new repo's `loc_type` (and checker) agrees with the matrices produced
    -- by the copier. Previously a nil `loc_type` gave the target the ON_HOST
    -- default while the copier followed `self.loc_type`.
    if loc_type == nil then
        loc_type = self.loc_type
    end
    local target = nerv.ParamRepo(nil, loc_type)

    local copier
    if loc_type == ParamRepo.LOC_TYPES.ON_HOST then
        copier = self.make_copier(gconf.mmat_type, 'copy_toh')
    else
        copier = self.make_copier(gconf.cumat_type, 'copy_tod')
    end

    if pids == nil then
        for id, p in pairs(self.params) do
            target.params[id] = p:copy(copier)
        end
    else
        for _, pid in ipairs(pids) do
            target.params[pid] = self:get_param(pid):copy(copier)
        end
    end
    return target
end