summaryrefslogtreecommitdiff
path: root/nerv/nn/layer_repo.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/nn/layer_repo.lua')
-rw-r--r--nerv/nn/layer_repo.lua28
1 files changed, 15 insertions, 13 deletions
diff --git a/nerv/nn/layer_repo.lua b/nerv/nn/layer_repo.lua
index a169b2b..acef54a 100644
--- a/nerv/nn/layer_repo.lua
+++ b/nerv/nn/layer_repo.lua
@@ -12,27 +12,29 @@ function LayerRepo:add_layers(layer_spec, param_repo, global_conf)
if layer_type == nil then
nerv.error('layer type `%s` not found', ltype)
end
- for id, layer_config in pairs(llist) do
- if layers[id] ~= nil then
- nerv.error("a layer with id %s already exists", id)
- end
- nerv.info("create layer: %s", id)
- if type(layer_config) ~= "table" then
+ for id, lconf in pairs(llist) do
+ if type(lconf) ~= "table" then
nerv.error("layer config table is need")
end
- if type(layer_config.params) == "table" then
- for pname, pid in pairs(layer_config.params) do
- layer_config[pname] = param_repo:get_param(pid)
- end
+ if lconf.pr == nil then
+ lconf.pr = param_repo
end
- if layer_config.pr == nil then
- layer_config.pr = param_repo
+ if layers[id] ~= nil then
+ nerv.error("a layer with id %s already exists", id)
end
- layers[id] = layer_type(id, global_conf, layer_config)
+ nerv.info("create layer: %s", id)
+ layers[id] = layer_type(id, global_conf, lconf)
end
end
end
+function LayerRepo:rebind(param_repo)
+ for id, layer in pairs(self.layers) do
+ layer.lconf.pr = param_repo
+ layer:bind_params()
+ end
+end
+
function LayerRepo:get_layer(lid)
local layer = self.layers[lid]
if layer == nil then