--- Repository of named parameters, keyed by param id.
local ParamRepo = nerv.class("nerv.ParamRepo")

--- Location sentinels for a repo's matrices.
-- Each value is a distinct empty table, so members are told apart purely by
-- identity (compare with `==`).
ParamRepo.LOC_TYPES = {
    ON_HOST = {},
    ON_DEVICE = {},
}
--- Construct a parameter repository.
-- @param plist optional array of params; each one is passed to `p:check`
--   together with this repo's checker and stored under `p.id`
-- @param loc_type one of ParamRepo.LOC_TYPES; defaults to ON_HOST
function ParamRepo:__init(plist, loc_type)
    self.params = {}
    self.loc_type = loc_type or ParamRepo.LOC_TYPES.ON_HOST

    -- Factory used by `copy`: returns a function that allocates a new
    -- matrix of `mat_type` with the source's shape and fills it through the
    -- source's method named `copy_method`.
    self.make_copier = function (mat_type, copy_method)
        return function (src)
            local dst = mat_type(src:nrow(), src:ncol())
            src[copy_method](src, dst)
            return dst
        end
    end

    -- A host repo expects MMatrix-backed params; any other location type is
    -- treated as device storage and expects CuMatrix.
    local expected_type
    if self.loc_type == ParamRepo.LOC_TYPES.ON_HOST then
        expected_type = "nerv.MMatrix"
    else
        expected_type = "nerv.CuMatrix"
    end
    self.checker = function (mat)
        if not nerv.is_type(mat, expected_type) then
            nerv.error("unexpected param type in repo specification")
        end
    end

    if plist == nil then
        return
    end
    for _, param in ipairs(plist) do
        param:check(self.checker)
        self.params[param.id] = param
    end
end
--- Register param `p` under id `pid`.
-- Raises (via nerv.error) if `pid` is already taken; otherwise `p` is passed
-- to `p:check` with this repo's checker before being stored.
-- @param pid id to store the param under
-- @param p the param object
function ParamRepo:add(pid, p)
    local existing = self.params[pid]
    if existing ~= nil then
        nerv.error("duplicate params with the same id: %s", pid)
    end
    p:check(self.checker)
    self.params[pid] = p
end
--- Remove the param registered under `pid`.
-- Raises (via nerv.error) if no param with that id exists.
-- @param pid id of the param to remove
-- @param p unused; kept so existing call sites remain valid
function ParamRepo:remove(pid, p)
    if self.params[pid] == nil then
        nerv.error("param %s does not exist", pid)
    end
    -- `params` is a hash keyed by id, not a sequence. table.remove would
    -- interpret `pid` as an array position: it raises for string ids and
    -- shifts or removes the wrong entries for numeric ones. Deleting the key
    -- directly is the correct way to drop a hash entry.
    self.params[pid] = nil
end
--- Build a new repo containing the params of every repo in `repos`.
-- Params are inserted with `add`, so an id present in more than one input
-- raises an error, and each param is checked against the new repo's checker.
-- @param repos array of nerv.ParamRepo objects (iterated with ipairs)
-- @param loc_type location type for the resulting repo
-- @return the merged nerv.ParamRepo
function ParamRepo.merge(repos, loc_type)
    local merged = nerv.ParamRepo(nil, loc_type)

    -- Validate one source repo, then copy all of its entries into `merged`.
    local function absorb(src)
        if not nerv.is_type(src, "nerv.ParamRepo") then
            nerv.error("nerv.ParamRepo objects expected, got %s", src)
        end
        for id, param in pairs(src.params) do
            merged:add(id, param)
        end
    end

    for _, src in ipairs(repos) do
        absorb(src)
    end
    return merged
end
--- Load param chunks from one or more chunk files into this repo.
-- Each loaded chunk must be a nerv.Param and is inserted with `add`, so ids
-- that are already present raise an error.
-- @param param_files array of chunk file paths (iterated 1..#param_files)
-- @param gconf global config forwarded to `read_chunk`
-- @param pids optional filter table; when given, only chunks whose id is a
--   key of `pids` (with a non-nil value) are loaded
function ParamRepo:import(param_files, gconf, pids)
    if type(param_files) ~= "table" then
        nerv.error("param file table is need")
    end
    for i = 1, #param_files do
        local pf = nerv.ChunkFile(param_files[i], "r")
        for cid in pairs(pf.metadata) do
            if pids == nil or pids[cid] ~= nil then
                local p = pf:read_chunk(cid, gconf)
                if not nerv.is_type(p, "nerv.Param") then
                    nerv.error("param chunk is expected")
                end
                self:add(cid, p)
            end
        end
        -- Release the handle once this file has been consumed (as `export`
        -- does), instead of leaving it open until garbage collection.
        pf:close()
    end
end
--- Write params of this repo to a chunk file.
-- @param param_file path of the chunk file to create (opened in "w" mode)
-- @param pids optional array of ids to export, in that order; when nil all
--   params are written (in `pairs` order). Unknown ids raise via get_param.
function ParamRepo:export(param_file, pids)
    -- `local` is required: without it this handle leaked into a global `cf`.
    local cf = nerv.ChunkFile(param_file, "w")
    if pids == nil then
        for _, p in pairs(self.params) do
            cf:write_chunk(p)
        end
    else
        for _, pid in ipairs(pids) do
            cf:write_chunk(self:get_param(pid))
        end
    end
    cf:close()
end
--- Report whether a param is registered under `pid`.
-- @param pid param id to look up
-- @return true if an entry exists, false otherwise
function ParamRepo:has_param(pid)
    local entry = self.params[pid]
    return not (entry == nil)
end
--- Fetch the param registered under `pid`.
-- Calls nerv.error when no such param exists.
-- @param pid param id to look up
-- @return the stored param
function ParamRepo:get_param(pid)
    local found = self.params[pid]
    if found ~= nil then
        return found
    end
    nerv.error("param with id %s not found", pid)
    return nil
end
--- Create a new repo holding copies of (some of) this repo's params.
-- @param loc_type destination location (ParamRepo.LOC_TYPES); defaults to
--   this repo's own loc_type
-- @param gconf config supplying `mmat_type` (host) or `cumat_type` (device),
--   used to allocate the destination matrices
-- @param pids optional array of ids to copy; when nil every param is copied.
--   Unknown ids raise via get_param.
-- @return the new nerv.ParamRepo
function ParamRepo:copy(loc_type, gconf, pids)
    if loc_type == nil then
        loc_type = self.loc_type
    end
    -- Build the target only after defaulting loc_type. Previously it was
    -- created from the raw argument, so copy(nil, ...) on a device repo
    -- produced an ON_HOST-tagged repo holding device-side copies.
    local target = nerv.ParamRepo(nil, loc_type)

    local copier
    if loc_type == ParamRepo.LOC_TYPES.ON_HOST then
        copier = self.make_copier(gconf.mmat_type, 'copy_toh')
    else
        copier = self.make_copier(gconf.cumat_type, 'copy_tod')
    end

    -- Assigned directly (not via `add`), so no duplicate or type checks run.
    -- NOTE(review): assumes Param:copy applies `copier` to its matrices.
    if pids == nil then
        for id, p in pairs(self.params) do
            target.params[id] = p:copy(copier)
        end
    else
        for _, pid in ipairs(pids) do
            target.params[pid] = self:get_param(pid):copy(copier)
        end
    end
    return target
end