--- Repository of parameter objects, indexed by each parameter's `id`.
-- Parameters are validated against the repository's location type
-- (host or device) when they are inserted through `__init` or `add`.
local ParamRepo = nerv.class("nerv.ParamRepo")

--- Location types a repository can hold its parameters in.
-- Each value is a distinct empty table, so they are compared by identity.
ParamRepo.LOC_TYPES = {
    ON_DEVICE = {},
    ON_HOST = {},
}

--- Construct a repository.
-- @param plist optional list of parameter objects. Each one is validated
--   with `p:check(self.checker)` and stored under `p.id`.
-- @param loc_type one of `ParamRepo.LOC_TYPES`; defaults to `ON_HOST`
function ParamRepo:__init(plist, loc_type)
    self.params = {}
    self.loc_type = loc_type or ParamRepo.LOC_TYPES.ON_HOST

    -- Build a validator that raises unless `mat` is of type `tname`.
    local function make_checker(tname)
        return function(mat)
            if not nerv.is_type(mat, tname) then
                nerv.error("unexpected param type in repo specification")
            end
        end
    end

    -- Build a function that allocates a `mat_type` matrix with the same
    -- nrow/ncol as its argument, then fills it by calling
    -- `mat[copy_method](mat, target)`. Kept as an instance field
    -- because `copy` (and possibly external callers) read it from `self`.
    self.make_copier = function(mat_type, copy_method)
        return function(mat)
            local target = mat_type(mat:nrow(), mat:ncol())
            mat[copy_method](mat, target)
            return target
        end
    end

    if self.loc_type == ParamRepo.LOC_TYPES.ON_HOST then
        self.checker = make_checker("nerv.MMatrix")
    else
        self.checker = make_checker("nerv.CuMatrix")
    end

    if plist ~= nil then
        -- NOTE: unlike `add`, duplicate ids are not rejected here; a later
        -- entry with the same id silently replaces an earlier one.
        for _, p in ipairs(plist) do
            p:check(self.checker)
            self.params[p.id] = p
        end
    end
end

--- Add a parameter to the repository.
-- @param p parameter object exposing `id` and `check(checker)`
-- @raise if a parameter with the same id is already present, or if
--   `p:check` rejects it for this repository's location type
function ParamRepo:add(p)
    if self.params[p.id] ~= nil then
        nerv.error("duplicate params with the same id: %s", p.id)
    end
    p:check(self.checker)
    self.params[p.id] = p
end

--- Remove the parameter with the given id.
-- @param pid parameter id
-- @raise if no parameter with that id exists
function ParamRepo:remove(pid)
    if self.params[pid] == nil then
        nerv.error("param %s does not exist", pid)
    end
    self.params[pid] = nil
end

--- Merge several repositories into a new one (a static function).
-- Parameter objects are shared, not copied. Duplicate ids across the
-- inputs raise through `add`.
-- @param repos list of `nerv.ParamRepo` objects
-- @param loc_type location type of the merged repository (see `__init`)
-- @return a new `nerv.ParamRepo`
function ParamRepo.merge(repos, loc_type)
    local merged = nerv.ParamRepo(nil, loc_type)
    for _, repo in ipairs(repos) do
        if not nerv.is_type(repo, "nerv.ParamRepo") then
            -- tostring(): `%s` in Lua 5.1 rejects non-string arguments,
            -- which would mask this error with a format error.
            nerv.error("nerv.ParamRepo objects expected, got %s",
                tostring(repo))
        end
        for _, p in pairs(repo.params) do
            merged:add(p)
        end
    end
    return merged
end

--- Load parameters from chunk files and add them to this repository.
-- @param param_files list of file paths, each opened with
--   `nerv.ChunkFile(path, "r")`
-- @param gconf configuration passed through to `read_chunk`
-- @param pids optional set-like table; when given, only chunks whose id
--   is a key of `pids` are read (note: a set, not a list as in `export`)
-- @raise if `param_files` is not a table, or a chunk is not a `nerv.Param`
function ParamRepo:import(param_files, gconf, pids)
    if type(param_files) ~= "table" then
        nerv.error("param file table is needed")
    end
    for i = 1, #param_files do
        local pf = nerv.ChunkFile(param_files[i], "r")
        for cid, _ in pairs(pf.metadata) do
            if pids == nil or pids[cid] ~= nil then
                local p = pf:read_chunk(cid, gconf)
                if not nerv.is_type(p, "nerv.Param") then
                    nerv.error("param chunk is expected")
                end
                self:add(p)
            end
        end
        -- Release the handle once this file's chunks have been read.
        -- NOTE(review): assumes read_chunk fully consumes chunk data
        -- before returning; confirm against nerv.ChunkFile.
        pf:close()
    end
end

--- Write parameters to a chunk file.
-- @param param_file output path, opened with `nerv.ChunkFile(path, "w")`
-- @param pids optional list of ids to write; all parameters when nil
-- @raise if any id in `pids` is unknown. The check runs before the file
--   is opened, so no partial file is left behind.
function ParamRepo:export(param_file, pids)
    local to_write = {}
    if pids == nil then
        for _, p in pairs(self.params) do
            to_write[#to_write + 1] = p
        end
    else
        for i, pid in ipairs(pids) do
            to_write[i] = self:get_param(pid)
        end
    end
    local cf = nerv.ChunkFile(param_file, "w")
    for _, p in ipairs(to_write) do
        cf:write_chunk(p)
    end
    cf:close()
end

--- Check whether a parameter with the given id is present.
-- @param pid parameter id
-- @return boolean
function ParamRepo:has_param(pid)
    return self.params[pid] ~= nil
end

--- Fetch the parameter with the given id.
-- @param pid parameter id
-- @return the stored parameter object
-- @raise if no parameter with that id exists
function ParamRepo:get_param(pid)
    local p = self.params[pid]
    if p == nil then
        nerv.error("param with id %s not found", pid)
    end
    return p
end

--- Copy parameters into a new repository at the requested location.
-- @param loc_type target location type; defaults to this repo's loc_type
-- @param gconf configuration providing `mmat_type` (host target) or
--   `cumat_type` (device target) used to allocate the copies
-- @param pids optional list of ids to copy; all parameters when nil
-- @return a new `nerv.ParamRepo` holding the copies
-- @raise if any id in `pids` is unknown
function ParamRepo:copy(loc_type, gconf, pids)
    -- Resolve the default before constructing the target, so the new
    -- repository's loc_type matches the matrices actually placed in it.
    if loc_type == nil then
        loc_type = self.loc_type
    end
    local target = nerv.ParamRepo(nil, loc_type)

    local copier
    if loc_type == ParamRepo.LOC_TYPES.ON_HOST then
        copier = self.make_copier(gconf.mmat_type, 'copy_toh')
    else
        copier = self.make_copier(gconf.cumat_type, 'copy_tod')
    end

    if pids == nil then
        for id, p in pairs(self.params) do
            target.params[id] = p:copy(copier)
        end
    else
        for _, pid in ipairs(pids) do
            target.params[pid] = self:get_param(pid):copy(copier)
        end
    end
    return target
end