about summary refs log tree commit diff
path: root/nerv/nn
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/nn')
-rw-r--r--nerv/nn/layer_dag.lua16
-rw-r--r--nerv/nn/layer_repo.lua30
-rw-r--r--nerv/nn/param_repo.lua59
3 files changed, 77 insertions, 28 deletions
diff --git a/nerv/nn/layer_dag.lua b/nerv/nn/layer_dag.lua
index 6896878..f999752 100644
--- a/nerv/nn/layer_dag.lua
+++ b/nerv/nn/layer_dag.lua
@@ -134,20 +134,16 @@ function DAGLayer:__init(id, global_conf, layer_conf)
end
end
+ nerv.Layer.__init(self, id, global_conf, layer_conf)
self.layers = layers
self.inputs = inputs
self.outputs = outputs
- self.id = id
- self.dim_in = dim_in
- self.dim_out = dim_out
self.parsed_conn = parsed_conn
self.queue = queue
- self.gconf = global_conf
- if self.gconf.use_cpu then
- self.mat_type = self.gconf.mmat_type
- else
- self.mat_type = self.gconf.cumat_type
- end
+end
+
+function DAGLayer:bind_params()
+ -- do nothing (instead of rebinding params for each layer)
end
function DAGLayer:init(batch_size, chunk_size)
@@ -325,7 +321,7 @@ function DAGLayer:get_params()
for id, ref in pairs(self.queue) do
table.insert(param_repos, ref.layer:get_params())
end
- return nerv.ParamRepo.merge(param_repos)
+ return nerv.ParamRepo.merge(param_repos, self.loc_type)
end
DAGLayer.PORT_TYPES = {
diff --git a/nerv/nn/layer_repo.lua b/nerv/nn/layer_repo.lua
index 3d3a79f..acef54a 100644
--- a/nerv/nn/layer_repo.lua
+++ b/nerv/nn/layer_repo.lua
@@ -12,29 +12,29 @@ function LayerRepo:add_layers(layer_spec, param_repo, global_conf)
if layer_type == nil then
nerv.error('layer type `%s` not found', ltype)
end
- for id, spec in pairs(llist) do
- if layers[id] ~= nil then
- nerv.error("a layer with id %s already exists", id)
- end
- nerv.info("create layer: %s", id)
- if type(spec[2]) ~= "table" then
+ for id, lconf in pairs(llist) do
+ if type(lconf) ~= "table" then
nerv.error("layer config table is need")
end
- layer_config = spec[2]
- if type(spec[1]) ~= "table" then
- nerv.error("parameter description table is needed")
- end
- for pname, pid in pairs(spec[1]) do
- layer_config[pname] = param_repo:get_param(pid)
+ if lconf.pr == nil then
+ lconf.pr = param_repo
end
- if layer_config.pr == nil then
- layer_config.pr = param_repo
+ if layers[id] ~= nil then
+ nerv.error("a layer with id %s already exists", id)
end
- layers[id] = layer_type(id, global_conf, layer_config)
+ nerv.info("create layer: %s", id)
+ layers[id] = layer_type(id, global_conf, lconf)
end
end
end
+function LayerRepo:rebind(param_repo)
+ for id, layer in pairs(self.layers) do
+ layer.lconf.pr = param_repo
+ layer:bind_params()
+ end
+end
+
function LayerRepo:get_layer(lid)
local layer = self.layers[lid]
if layer == nil then
diff --git a/nerv/nn/param_repo.lua b/nerv/nn/param_repo.lua
index c124e08..aba7765 100644
--- a/nerv/nn/param_repo.lua
+++ b/nerv/nn/param_repo.lua
@@ -1,8 +1,37 @@
local ParamRepo = nerv.class("nerv.ParamRepo")
-function ParamRepo:__init(plist)
+
+ParamRepo.LOC_TYPES = {
+ ON_DEVICE = {},
+ ON_HOST = {}
+}
+
+function ParamRepo:__init(plist, loc_type)
self.params = {}
+ self.loc_type = loc_type or ParamRepo.LOC_TYPES.ON_HOST
+ local function make_checker(tname)
+ return function (mat)
+ if not nerv.is_type(mat, tname) then
+ nerv.error("unexpected param type in repo specification")
+ end
+ end
+ end
+ self.make_copier = function (mat_type, copy_method)
+ return function (mat)
+ local target = mat_type(mat:nrow(), mat:ncol())
+ mat[copy_method](mat, target)
+ return target
+ end
+ end
+
+ if self.loc_type == ParamRepo.LOC_TYPES.ON_HOST then
+ self.checker = make_checker("nerv.MMatrix")
+ else
+ self.checker = make_checker("nerv.CuMatrix")
+ end
+
if plist ~= nil then
for i, p in ipairs(plist) do
+ p:check(self.checker)
self.params[p.id] = p
end
end
@@ -12,6 +41,7 @@ function ParamRepo:add(pid, p)
if self.params[pid] ~= nil then
nerv.error("duplicate params with the same id: %s", pid)
end
+ p:check(self.checker)
self.params[pid] = p
end
@@ -22,8 +52,8 @@ function ParamRepo:remove(pid, p)
table.remove(self.params, pid)
end
-function ParamRepo.merge(repos)
- local self = nerv.ParamRepo()
+function ParamRepo.merge(repos, loc_type)
+ local self = nerv.ParamRepo(nil, loc_type)
for i, repo in ipairs(repos) do
if not nerv.is_type(repo, "nerv.ParamRepo") then
nerv.error("nerv.ParamRepo objects expected, got %s", repo)
@@ -78,3 +108,26 @@ function ParamRepo:get_param(pid)
end
return p
end
+
+function ParamRepo:copy(loc_type, pids)
+ local copier
+ local target = nerv.ParamRepo(nil, loc_type)
+ if loc_type == nil then
+ loc_type = self.loc_type
+ end
+ if loc_type == ParamRepo.LOC_TYPES.ON_HOST then
+ copier = self.make_copier(gconf.mmat_type, 'copy_toh')
+ else
+ copier = self.make_copier(gconf.cumat_type, 'copy_tod')
+ end
+ if pids == nil then
+ for id, p in pairs(self.params) do
+ target.params[id] = p:copy(copier)
+ end
+ else
+ for i, pid in ipairs(pids) do
+ target.params[pid] = self:get_param(pid):copy(copier)
+ end
+ end
+ return target
+end