diff options
Diffstat (limited to 'nerv/nn/param_repo.lua')
-rw-r--r-- | nerv/nn/param_repo.lua | 59 |
1 files changed, 56 insertions, 3 deletions
diff --git a/nerv/nn/param_repo.lua b/nerv/nn/param_repo.lua index c124e08..aba7765 100644 --- a/nerv/nn/param_repo.lua +++ b/nerv/nn/param_repo.lua @@ -1,8 +1,37 @@ local ParamRepo = nerv.class("nerv.ParamRepo") -function ParamRepo:__init(plist) + +ParamRepo.LOC_TYPES = { + ON_DEVICE = {}, + ON_HOST = {} +} + +function ParamRepo:__init(plist, loc_type) self.params = {} + self.loc_type = loc_type or ParamRepo.LOC_TYPES.ON_HOST + local function make_checker(tname) + return function (mat) + if not nerv.is_type(mat, tname) then + nerv.error("unexpected param type in repo specification") + end + end + end + self.make_copier = function (mat_type, copy_method) + return function (mat) + local target = mat_type(mat:nrow(), mat:ncol()) + mat[copy_method](mat, target) + return target + end + end + + if self.loc_type == ParamRepo.LOC_TYPES.ON_HOST then + self.checker = make_checker("nerv.MMatrix") + else + self.checker = make_checker("nerv.CuMatrix") + end + if plist ~= nil then for i, p in ipairs(plist) do + p:check(self.checker) self.params[p.id] = p end end @@ -12,6 +41,7 @@ function ParamRepo:add(pid, p) if self.params[pid] ~= nil then nerv.error("duplicate params with the same id: %s", pid) end + p:check(self.checker) self.params[pid] = p end @@ -22,8 +52,8 @@ function ParamRepo:remove(pid, p) table.remove(self.params, pid) end -function ParamRepo.merge(repos) - local self = nerv.ParamRepo() +function ParamRepo.merge(repos, loc_type) + local self = nerv.ParamRepo(nil, loc_type) for i, repo in ipairs(repos) do if not nerv.is_type(repo, "nerv.ParamRepo") then nerv.error("nerv.ParamRepo objects expected, got %s", repo) @@ -78,3 +108,26 @@ function ParamRepo:get_param(pid) end return p end + +function ParamRepo:copy(loc_type, pids) + local copier + local target = nerv.ParamRepo(nil, loc_type) + if loc_type == nil then + loc_type = self.loc_type + end + if loc_type == ParamRepo.LOC_TYPES.ON_HOST then + copier = self.make_copier(gconf.mmat_type, 'copy_toh') + else + copier = self.make_copier(gconf.cumat_type, 'copy_tod') + end + if pids == nil then + for id, p in pairs(self.params) do + target.params[id] = p:copy(copier) + end + else + for i, pid in ipairs(pids) do + target.params[pid] = self:get_param(pid):copy(copier) + end + end + return target +end |