aboutsummaryrefslogtreecommitdiff
path: root/nerv/nn/param_repo.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/nn/param_repo.lua')
-rw-r--r--nerv/nn/param_repo.lua59
1 files changed, 56 insertions, 3 deletions
diff --git a/nerv/nn/param_repo.lua b/nerv/nn/param_repo.lua
index c124e08..aba7765 100644
--- a/nerv/nn/param_repo.lua
+++ b/nerv/nn/param_repo.lua
@@ -1,8 +1,37 @@
local ParamRepo = nerv.class("nerv.ParamRepo")
-function ParamRepo:__init(plist)
+
+ParamRepo.LOC_TYPES = {
+ ON_DEVICE = {},
+ ON_HOST = {}
+}
+
+function ParamRepo:__init(plist, loc_type)
self.params = {}
+ self.loc_type = loc_type or ParamRepo.LOC_TYPES.ON_HOST
+ local function make_checker(tname)
+ return function (mat)
+ if not nerv.is_type(mat, tname) then
+ nerv.error("unexpected param type in repo specification")
+ end
+ end
+ end
+ self.make_copier = function (mat_type, copy_method)
+ return function (mat)
+ local target = mat_type(mat:nrow(), mat:ncol())
+ mat[copy_method](mat, target)
+ return target
+ end
+ end
+
+ if self.loc_type == ParamRepo.LOC_TYPES.ON_HOST then
+ self.checker = make_checker("nerv.MMatrix")
+ else
+ self.checker = make_checker("nerv.CuMatrix")
+ end
+
if plist ~= nil then
for i, p in ipairs(plist) do
+ p:check(self.checker)
self.params[p.id] = p
end
end
@@ -12,6 +41,7 @@ function ParamRepo:add(pid, p)
if self.params[pid] ~= nil then
nerv.error("duplicate params with the same id: %s", pid)
end
+ p:check(self.checker)
self.params[pid] = p
end
@@ -22,8 +52,8 @@ function ParamRepo:remove(pid, p)
table.remove(self.params, pid)
end
-function ParamRepo.merge(repos)
- local self = nerv.ParamRepo()
+function ParamRepo.merge(repos, loc_type)
+ local self = nerv.ParamRepo(nil, loc_type)
for i, repo in ipairs(repos) do
if not nerv.is_type(repo, "nerv.ParamRepo") then
nerv.error("nerv.ParamRepo objects expected, got %s", repo)
@@ -78,3 +108,26 @@ function ParamRepo:get_param(pid)
end
return p
end
+
+function ParamRepo:copy(loc_type, pids)
+ local copier
+ local target = nerv.ParamRepo(nil, loc_type)
+ if loc_type == nil then
+ loc_type = self.loc_type
+ end
+ if loc_type == ParamRepo.LOC_TYPES.ON_HOST then
+ copier = self.make_copier(gconf.mmat_type, 'copy_toh')
+ else
+ copier = self.make_copier(gconf.cumat_type, 'copy_tod')
+ end
+ if pids == nil then
+ for id, p in pairs(self.params) do
+ target.params[id] = p:copy(copier)
+ end
+ else
+ for i, pid in ipairs(pids) do
+ target.params[pid] = self:get_param(pid):copy(copier)
+ end
+ end
+ return target
+end