aboutsummaryrefslogtreecommitdiff
path: root/nerv/nn
diff options
context:
space:
mode:
authorQi Liu <liuq901@163.com>2016-02-29 17:46:09 +0800
committerQi Liu <liuq901@163.com>2016-02-29 17:46:09 +0800
commit77b558898a2a29097d8697a59a7d23cd2a52975f (patch)
tree06bab2379224a6d06bd6b9c60468597e1fbe6e1e /nerv/nn
parent550680eacd00555817df19d2b59a20a92df77c42 (diff)
graph layer complete
Diffstat (limited to 'nerv/nn')
-rw-r--r--nerv/nn/layer_repo.lua14
1 files changed, 6 insertions, 8 deletions
diff --git a/nerv/nn/layer_repo.lua b/nerv/nn/layer_repo.lua
index 3d3a79f..a169b2b 100644
--- a/nerv/nn/layer_repo.lua
+++ b/nerv/nn/layer_repo.lua
@@ -12,20 +12,18 @@ function LayerRepo:add_layers(layer_spec, param_repo, global_conf)
if layer_type == nil then
nerv.error('layer type `%s` not found', ltype)
end
- for id, spec in pairs(llist) do
+ for id, layer_config in pairs(llist) do
if layers[id] ~= nil then
nerv.error("a layer with id %s already exists", id)
end
nerv.info("create layer: %s", id)
- if type(spec[2]) ~= "table" then
+ if type(layer_config) ~= "table" then
nerv.error("layer config table is need")
end
- layer_config = spec[2]
- if type(spec[1]) ~= "table" then
- nerv.error("parameter description table is needed")
- end
- for pname, pid in pairs(spec[1]) do
- layer_config[pname] = param_repo:get_param(pid)
+ if type(layer_config.params) == "table" then
+ for pname, pid in pairs(layer_config.params) do
+ layer_config[pname] = param_repo:get_param(pid)
+ end
end
if layer_config.pr == nil then
layer_config.pr = param_repo