aboutsummaryrefslogtreecommitdiff
path: root/nerv/layer/graph.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/layer/graph.lua')
-rw-r--r--nerv/layer/graph.lua46
1 files changed, 30 insertions, 16 deletions
diff --git a/nerv/layer/graph.lua b/nerv/layer/graph.lua
index 83cf810..36a9672 100644
--- a/nerv/layer/graph.lua
+++ b/nerv/layer/graph.lua
@@ -21,20 +21,23 @@ local function parse_id(str)
return id, port
end
-local function discover(id, layers, layer_repo)
+function GraphLayer:discover(id, layer_repo)
if id == '<output>' then
id = '<input>'
end
+ local layers = self.layers
local ref = layers[id]
if ref == nil then
local layer = layer_repo:get_layer(id)
local dim_in, dim_out = layer:get_dim()
+ self.layer_num = self.layer_num + 1
ref = {
layer = layer,
inputs = {},
outputs = {},
dim_in = dim_in,
dim_out = dim_out,
+ id = self.layer_num,
}
layers[id] = ref
end
@@ -42,32 +45,37 @@ local function discover(id, layers, layer_repo)
end
function GraphLayer:graph_init(layer_repo, connections)
- self.connections = connections
- self.sublayer = nerv.LayerRepo({}, nerv.ParamRepo(), self.gconf)
-
- -- check data dimension between connected ports
local layers = {}
layers['<input>'] = {
inputs = {},
outputs = {},
dim_in = self.dim_out,
dim_out = self.dim_in,
+ id = 0,
}
- for _, edge in pairs(self.connections) do
- local from = edge[1]
- local to = edge[2]
+ self.layers = layers
+ self.layer_num = 0
+ self.connections = {}
+
+ -- check data dimension between connected ports
+ for _, edge in pairs(connections) do
+ local from, to, time = edge[1], edge[2], edge[3]
local id_from, port_from = parse_id(from)
local id_to, port_to = parse_id(to)
- local ref_from = discover(id_from, layers, layer_repo)
- local ref_to = discover(id_to, layers, layer_repo)
+ local ref_from = self:discover(id_from, layer_repo)
+ local ref_to = self:discover(id_to, layer_repo)
if ref_to.inputs[port_to] ~= nil then
nerv.error('%s has already been attached', to)
end
if ref_from.dim_out[port_from] ~= ref_to.dim_in[port_to] then
nerv.error('mismatching data dimension between %s and %s', from, to)
end
+ if ref_from.id == 0 and ref_to.id == 0 then
+ nerv.error('short-circuit connection between <input> and <output>')
+ end
ref_from.outputs[port_from] = true
ref_to.inputs[port_to] = true
+ table.insert(self.connections, {ref_from.id, port_from, ref_to.id, port_to, time})
end
-- check dangling ports
@@ -83,7 +91,6 @@ function GraphLayer:graph_init(layer_repo, connections)
nerv.error('dangling output port %d os layer %s', i, id)
end
end
- self.sublayer.layers[id] = ref.layer
end
end
for i = 1, #self.dim_in do
@@ -100,19 +107,26 @@ end
function GraphLayer:set_attr(name, value)
self[name] = value
- for id, layer in pairs(self.sublayer.layers) do
- layer:set_attr(name, value)
+ for id, ref in pairs(self.layers) do
+ if id ~= '<input>' then
+ ref.layer:set_attr(name, value)
+ end
end
end
function GraphLayer:get_sublayer(id)
- return self.sublayer:get_layer(id)
+ if self.layers[id] == nil or id == '<input>' then
+ nerv.error('layer with id %s not found', id)
+ end
+ return self.layers[id].layer
end
function GraphLayer:get_params()
local param_repos = {}
- for id, layer in pairs(self.sublayer.layers) do
- table.insert(param_repos, layer:get_params())
+ for id, ref in pairs(self.layers) do
+ if id ~= '<input>' then
+ table.insert(param_repos, ref.layer:get_params())
+ end
end
return nerv.ParamRepo.merge(param_repos)
end