--- GraphLayer: a layer built from sub-layers (taken from a layer repository)
-- that are wired together port-to-port by a list of connections.
-- Declared through nerv.class as 'nerv.GraphLayer', a subclass of 'nerv.Layer'.
local GraphLayer = nerv.class('nerv.GraphLayer', 'nerv.Layer')
--- Construct a graph layer and build its internal connection graph.
-- @param id identifier of this layer (stored as `self.id`)
-- @param global_conf global configuration (stored as `self.gconf`)
-- @param layer_conf layer configuration; `dim_in` and `dim_out` are copied
--   onto the instance, while `layer_repo` and `connections` are handed to
--   graph_init
function GraphLayer:__init(id, global_conf, layer_conf)
    self.id = id
    self.dim_in, self.dim_out = layer_conf.dim_in, layer_conf.dim_out
    self.gconf = global_conf
    local repo, edges = layer_conf.layer_repo, layer_conf.connections
    self:graph_init(repo, edges)
end
--- Split a connection endpoint of the form `id[port]` into id and port.
-- Ordinary sub-layer ids consist of letters, digits, '_' and '.'.
-- The only other ids accepted are ' ' and '' (GraphLayer:discover maps ''
-- onto ' ', the pseudo-node graph_init registers for the layer's own ports).
-- The whole string must match: leading or trailing garbage is rejected
-- instead of being silently dropped (e.g. "foo bar[1]" is an error, not
-- the id "bar").
-- @tparam string str endpoint such as "affine0[1]"
-- @return the id (string) and the port (number)
local function parse_id(str)
    local id, port = string.match(str, "^([a-zA-Z0-9_.]+)%[([0-9]+)%]$")
    if id == nil then
        -- `.*` rather than `.+` so that the empty id in "[n]" is reachable
        id, port = string.match(str, "^(.*)%[([0-9]+)%]$")
        if not (id == " " or id == "") then
            nerv.error("wrong format of connection id")
        end
    end
    return id, tonumber(port)
end
--- Return the node record for sub-layer `id`, creating it on first use.
-- An empty id is treated as an alias for ' ', the pseudo-node that
-- graph_init registers before processing connections.
-- New ids are fetched from `layer_repo:get_layer(id)` and numbered with
-- consecutive integers starting from 1 (tracked in `self.layer_num`).
-- @param id sub-layer id, as produced by parse_id
-- @param layer_repo repository used to resolve ids not seen before
-- @return table with fields layer, inputs, outputs, dim_in, dim_out, id
function GraphLayer:discover(id, layer_repo)
    local key = (id == '') and ' ' or id
    local existing = self.layers[key]
    if existing ~= nil then
        return existing
    end
    local sublayer = layer_repo:get_layer(key)
    local din, dout = sublayer:get_dim()
    local seq = self.layer_num + 1
    self.layer_num = seq
    local node = {
        layer = sublayer,
        inputs = {},
        outputs = {},
        dim_in = din,
        dim_out = dout,
        id = seq,
    }
    self.layers[key] = node
    return node
end
--- Build and validate the internal connection graph of this layer.
-- The pseudo-node ' ' (numeric id 0) stands for the graph layer's own
-- boundary: its `outputs` are the graph's input ports (dimensions
-- `self.dim_in`) and its `inputs` are the graph's output ports
-- (dimensions `self.dim_out`).
-- Sets `self.layers` (id -> node record), `self.layer_num` and
-- `self.connections` (list of {from_id, from_port, to_id, to_port, time}).
-- Raises via nerv.error on an invalid port, double-attached input,
-- dimension mismatch, input-to-output short circuit, or dangling port.
-- @param layer_repo repository used to resolve sub-layer ids
-- @param connections edges `{from, to, time}` with endpoints as `id[port]`
function GraphLayer:graph_init(layer_repo, connections)
    local layers = {}
    local boundary = {
        inputs = {},
        outputs = {},
        dim_in = self.dim_out,
        dim_out = self.dim_in,
        id = 0,
    }
    layers[' '] = boundary
    self.layers = layers
    self.layer_num = 0
    self.connections = {}
    -- check data dimension between connected ports
    for _, edge in pairs(connections) do
        local from, to, time = edge[1], edge[2], edge[3]
        local id_from, port_from = parse_id(from)
        local id_to, port_to = parse_id(to)
        local ref_from = self:discover(id_from, layer_repo)
        local ref_to = self:discover(id_to, layer_repo)
        -- without these, an out-of-range port on both ends compares
        -- nil ~= nil (false) below and the edge is silently accepted
        if ref_from.dim_out[port_from] == nil then
            nerv.error('invalid output port %s', from)
        end
        if ref_to.dim_in[port_to] == nil then
            nerv.error('invalid input port %s', to)
        end
        if ref_to.inputs[port_to] ~= nil then
            nerv.error('%s has already been attached', to)
        end
        if ref_from.dim_out[port_from] ~= ref_to.dim_in[port_to] then
            nerv.error('mismatching data dimension between %s and %s', from, to)
        end
        if ref_from.id == 0 and ref_to.id == 0 then
            nerv.error('short-circuit connection between %s and %s', from, to)
        end
        ref_from.outputs[port_from] = true
        ref_to.inputs[port_to] = true
        table.insert(self.connections, {ref_from.id, port_from, ref_to.id, port_to, time})
    end
    -- check dangling ports of every sub-layer
    for id, ref in pairs(layers) do
        if id ~= ' ' then
            for i = 1, #ref.dim_in do
                if ref.inputs[i] == nil then
                    nerv.error('dangling input port %d of layer %s', i, id)
                end
            end
            for i = 1, #ref.dim_out do
                if ref.outputs[i] == nil then
                    nerv.error('dangling output port %d of layer %s', i, id)
                end
            end
        end
    end
    -- check dangling ports of the graph layer itself; tostring guards
    -- against a nil self.id, which Lua 5.1's %s would reject
    for i = 1, #self.dim_in do
        if boundary.outputs[i] == nil then
            nerv.error('dangling input port %d of graph layer %s', i, tostring(self.id))
        end
    end
    for i = 1, #self.dim_out do
        if boundary.inputs[i] == nil then
            nerv.error('dangling output port %d of graph layer %s', i, tostring(self.id))
        end
    end
end
--- Set attribute `name` to `value` on this layer and forward the same
-- call to every sub-layer. The boundary pseudo-node ' ' carries no layer
-- object and is skipped.
function GraphLayer:set_attr(name, value)
    self[name] = value
    for key, node in pairs(self.layers) do
        if key ~= ' ' then
            local sublayer = node.layer
            sublayer:set_attr(name, value)
        end
    end
end
--- Return the sub-layer object registered under `id`.
-- Raises via nerv.error when `id` is unknown, or when it is ' ', the
-- boundary pseudo-node, which has no layer object.
function GraphLayer:get_sublayer(id)
    local node = self.layers[id]
    if id == ' ' or node == nil then
        nerv.error('layer with id %s not found', id)
    end
    return node.layer
end
--- Gather the parameters of all sub-layers into one repository.
-- Each sub-layer's get_params() result is collected, skipping the
-- boundary pseudo-node ' ', and the list is passed to nerv.ParamRepo.merge.
-- @return the merged parameter repository
function GraphLayer:get_params()
    local collected = {}
    for key, node in pairs(self.layers) do
        local is_boundary = (key == ' ')
        if not is_boundary then
            table.insert(collected, node.layer:get_params())
        end
    end
    return nerv.ParamRepo.merge(collected)
end