aboutsummaryrefslogtreecommitdiff
path: root/nerv/layer/graph.lua
diff options
context:
space:
mode:
authorQi Liu <liuq901@163.com>2016-03-03 19:42:15 +0800
committerQi Liu <liuq901@163.com>2016-03-03 19:42:15 +0800
commit8374e8fbc545633b6adf5c4090af8997a65778d2 (patch)
tree9d959337628686b2b9ece9016a92ea55d40c0d31 /nerv/layer/graph.lua
parentc682dfee8686c43aed8628633412c9b4d2bd708b (diff)
update add_prefix for graph layer
Diffstat (limited to 'nerv/layer/graph.lua')
-rw-r--r--nerv/layer/graph.lua24
1 files changed, 24 insertions, 0 deletions
diff --git a/nerv/layer/graph.lua b/nerv/layer/graph.lua
index d72d849..1406eff 100644
--- a/nerv/layer/graph.lua
+++ b/nerv/layer/graph.lua
@@ -21,6 +21,30 @@ local function parse_id(str)
return id, port
end
+function GraphLayer:add_prefix(layers, connections)
+ local function ap(name)
+ return self.id .. '.' .. name
+ end
+
+ for layer_type, sublayers in pairs(layers) do
+ local tmp = {}
+ for name, layer_config in pairs(sublayers) do
+ tmp[ap(name)] = layer_config
+ end
+ layers[layer_type] = tmp
+ end
+
+ for i = 1, #connections do
+ local from, to = connections[i][1], connections[i][2]
+ if parse_id(from) ~= '<input>' then
+ connections[i][1] = ap(from)
+ end
+ if parse_id(to) ~= '<output>' then
+ connections[i][2] = ap(to)
+ end
+ end
+end
+
function GraphLayer:discover(id, layer_repo)
if id == '<output>' then
id = '<input>'