--- A layer composed of sub-layers wired together by a list of connections.
-- Each connection is `{from, to, time}`, where `from` and `to` are port ids
-- of the form `name[port]`. An empty name (`[port]`) refers to this graph
-- layer's own boundary ports.
local GraphLayer = nerv.class('nerv.GraphLayer', 'nerv.Layer')

--- Construct the layer and build its connection graph.
-- @param id layer id
-- @param global_conf global configuration, forwarded to `nerv.Layer.__init`
-- @param layer_conf layer configuration; `layer_conf.layer_repo` and
--   `layer_conf.connections` are passed to `graph_init`
function GraphLayer:__init(id, global_conf, layer_conf)
    nerv.Layer.__init(self, id, global_conf, layer_conf)
    self:graph_init(layer_conf.layer_repo, layer_conf.connections)
end

--- Split a port id of the form `name[port]` into its name and port number.
-- `name` may contain letters, digits, `_` and `.`, or be empty. The empty
-- name denotes the graph layer's own boundary ports. The whole string must
-- match; anything else raises an error.
-- NOTE(review): the empty string is the only boundary name accepted here.
-- If connection specs use a named boundary (e.g. separate names for the
-- input and output sides), extend this pattern. Confirm against callers.
-- @tparam string str port id
-- @treturn string layer id ('' for the graph boundary)
-- @treturn number port index
local function parse_id(str)
    -- anchored so that e.g. "a-b[1]" is rejected instead of parsing as "b"
    local id, port = string.match(str, '^([a-zA-Z0-9_.]*)%[([0-9]+)%]$')
    if id == nil then
        nerv.error("wrong format of connection id")
    end
    return id, tonumber(port)
end

--- Namespace every sub-layer name and connection endpoint under this
-- layer's id, in place.
-- Each key of `layers[layer_type]` and each non-boundary endpoint string
-- in `connections` becomes `self.id .. '.' .. original`. Boundary endpoints
-- (empty name) are left untouched.
-- @param layers table mapping layer type to a table of `name -> config`
-- @param connections array of `{from, to, time}` connection specs
function GraphLayer:add_prefix(layers, connections)
    local function ap(name)
        return self.id .. '.' .. name
    end

    for layer_type, sublayers in pairs(layers) do
        local renamed = {}
        for name, layer_config in pairs(sublayers) do
            renamed[ap(name)] = layer_config
        end
        -- replacing the value of an existing key during pairs() is allowed
        layers[layer_type] = renamed
    end

    for _, conn in ipairs(connections) do
        if parse_id(conn[1]) ~= '' then
            conn[1] = ap(conn[1])
        end
        if parse_id(conn[2]) ~= '' then
            conn[2] = ap(conn[2])
        end
    end
end

--- Return the bookkeeping record for sub-layer `id`, creating it on first use.
-- A new record holds the layer fetched from `layer_repo`, its `dim_in` and
-- `dim_out` as returned by `get_dim()`, empty `inputs`/`outputs` attachment
-- sets, and the next sequential numeric `id`. The graph boundary ('') is
-- pre-registered by `graph_init` with numeric id 0, so it is never fetched
-- from the repo.
-- @tparam string id sub-layer id
-- @param layer_repo repository providing `get_layer(id)`
-- @return the record stored in `self.layers[id]`
function GraphLayer:discover(id, layer_repo)
    local layers = self.layers
    local ref = layers[id]
    if ref == nil then
        local layer = layer_repo:get_layer(id)
        local dim_in, dim_out = layer:get_dim()
        self.layer_num = self.layer_num + 1
        ref = {
            layer = layer,
            inputs = {},
            outputs = {},
            dim_in = dim_in,
            dim_out = dim_out,
            id = self.layer_num,
        }
        layers[id] = ref
    end
    return ref
end

--- Negate the `time` field (index 3) of every connection, in place.
-- @param connections array of `{from, to, time}` connection specs
local function reverse(connections)
    for i = 1, #connections do
        connections[i][3] = -connections[i][3]
    end
end

--- Raise an error if any port of a sub-layer or of the graph boundary has
-- no connection attached.
-- @param graph the GraphLayer whose `layers` table has been populated
local function check_dangling(graph)
    for id, ref in pairs(graph.layers) do
        if id ~= '' then
            for i = 1, #ref.dim_in do
                if ref.inputs[i] == nil then
                    nerv.error('dangling input port %d of layer %s', i, id)
                end
            end
            for i = 1, #ref.dim_out do
                if ref.outputs[i] == nil then
                    nerv.error('dangling output port %d of layer %s', i, id)
                end
            end
        end
    end
    -- graph inputs leave the boundary node through its outputs, and graph
    -- outputs enter it through its inputs
    local boundary = graph.layers['']
    for i = 1, #graph.dim_in do
        if boundary.outputs[i] == nil then
            nerv.error('dangling input port %d of graph layer %s', i, graph.id)
        end
    end
    for i = 1, #graph.dim_out do
        if boundary.inputs[i] == nil then
            nerv.error('dangling output port %d of graph layer %s', i, graph.id)
        end
    end
end

--- Validate `connections` and build the internal graph.
-- Populates:
--   * `self.layers`: a map from sub-layer id to its record (see `discover`).
--     The entry '' is the graph boundary, whose `dim_in`/`dim_out` are this
--     layer's `dim_out`/`dim_in`.
--   * `self.layer_num`: the number of distinct sub-layers.
--   * `self.connections`: an array of
--     `{from_id, from_port, to_id, to_port, time}` with numeric layer ids.
-- When `self.lconf.reversed` is truthy, the `time` field of every connection
-- is first negated in place.
-- Raises (via `nerv.error`) on malformed ids, out-of-range ports, ports
-- attached twice, mismatched dimensions, boundary-to-boundary connections
-- and dangling ports.
-- @param layer_repo repository providing `get_layer(id)`
-- @param connections array of `{from, to, time}` connection specs
function GraphLayer:graph_init(layer_repo, connections)
    if self.lconf.reversed then
        reverse(connections)
    end
    self.layers = {
        [''] = {
            inputs = {},
            outputs = {},
            dim_in = self.dim_out,
            dim_out = self.dim_in,
            id = 0,
        },
    }
    self.layer_num = 0
    self.connections = {}

    -- ipairs keeps numeric layer ids and connection order deterministic
    for _, edge in ipairs(connections) do
        local from, to, time = edge[1], edge[2], edge[3]
        local id_from, port_from = parse_id(from)
        local id_to, port_to = parse_id(to)
        local ref_from = self:discover(id_from, layer_repo)
        local ref_to = self:discover(id_to, layer_repo)
        -- without these checks, two out-of-range ports compare nil == nil
        -- and slip past the dimension check below
        if ref_from.dim_out[port_from] == nil then
            nerv.error('%s is not a valid output port', from)
        end
        if ref_to.dim_in[port_to] == nil then
            nerv.error('%s is not a valid input port', to)
        end
        if ref_from.outputs[port_from] ~= nil then
            nerv.error('%s has already been attached', from)
        end
        if ref_to.inputs[port_to] ~= nil then
            nerv.error('%s has already been attached', to)
        end
        if ref_from.dim_out[port_from] ~= ref_to.dim_in[port_to] then
            nerv.error('mismatching data dimension between %s and %s', from, to)
        end
        if ref_from.id == 0 and ref_to.id == 0 then
            nerv.error('short-circuit connection between %s and %s', from, to)
        end
        ref_from.outputs[port_from] = true
        ref_to.inputs[port_to] = true
        table.insert(self.connections,
                     {ref_from.id, port_from, ref_to.id, port_to, time})
    end

    check_dangling(self)
end

--- Set attribute `name` to `value` on this layer and on every sub-layer.
function GraphLayer:set_attr(name, value)
    self[name] = value
    for id, ref in pairs(self.layers) do
        if id ~= '' then
            ref.layer:set_attr(name, value)
        end
    end
end

--- Return the sub-layer object registered under `id`.
-- Raises if `id` is unknown or is the boundary id ''.
function GraphLayer:get_sublayer(id)
    if self.layers[id] == nil or id == '' then
        nerv.error('layer with id %s not found', id)
    end
    return self.layers[id].layer
end

--- Collect the parameters of all sub-layers into a single merged repo.
-- @return the result of `nerv.ParamRepo.merge` over every sub-layer's
--   `get_params()`, using `self.loc_type`
function GraphLayer:get_params()
    local param_repos = {}
    for id, ref in pairs(self.layers) do
        if id ~= '' then
            table.insert(param_repos, ref.layer:get_params())
        end
    end
    return nerv.ParamRepo.merge(param_repos, self.loc_type)
end