diff options
Diffstat (limited to 'nerv/nn/layer_repo.lua')
-rw-r--r-- | nerv/nn/layer_repo.lua | 30 |
1 files changed, 15 insertions, 15 deletions
diff --git a/nerv/nn/layer_repo.lua b/nerv/nn/layer_repo.lua index 3d3a79f..acef54a 100644 --- a/nerv/nn/layer_repo.lua +++ b/nerv/nn/layer_repo.lua @@ -12,29 +12,29 @@ function LayerRepo:add_layers(layer_spec, param_repo, global_conf) if layer_type == nil then nerv.error('layer type `%s` not found', ltype) end - for id, spec in pairs(llist) do - if layers[id] ~= nil then - nerv.error("a layer with id %s already exists", id) - end - nerv.info("create layer: %s", id) - if type(spec[2]) ~= "table" then + for id, lconf in pairs(llist) do + if type(lconf) ~= "table" then nerv.error("layer config table is need") end - layer_config = spec[2] - if type(spec[1]) ~= "table" then - nerv.error("parameter description table is needed") - end - for pname, pid in pairs(spec[1]) do - layer_config[pname] = param_repo:get_param(pid) + if lconf.pr == nil then + lconf.pr = param_repo end - if layer_config.pr == nil then - layer_config.pr = param_repo + if layers[id] ~= nil then + nerv.error("a layer with id %s already exists", id) end - layers[id] = layer_type(id, global_conf, layer_config) + nerv.info("create layer: %s", id) + layers[id] = layer_type(id, global_conf, lconf) end end end +function LayerRepo:rebind(param_repo) + for id, layer in pairs(self.layers) do + layer.lconf.pr = param_repo + layer:bind_params() + end +end + function LayerRepo:get_layer(lid) local layer = self.layers[lid] if layer == nil then |