local GraphLayer = nerv.class('nerv.GraphLayer', 'nerv.Layer')

-- Reserved ids that name the graph layer's own boundary in connection
-- specs: "<input>[i]" is the i-th input of the graph (usable only as the
-- source of an edge) and "<output>[i]" is the i-th output of the graph
-- (usable only as the destination). Both ends are tracked by a single
-- pseudo-layer stored under INPUT_ID.
local INPUT_ID = '<input>'
local OUTPUT_ID = '<output>'

--- Create a layer that wires a set of sub-layers together.
-- @param id identifier of this layer
-- @param global_conf global configuration, kept as `self.gconf`
-- @param layer_conf table providing `dim_in`, `dim_out` (per-port dimension
--   lists), `layer_repo` (where sub-layers are looked up by id) and
--   `connections` (edges; only `edge[1]` and `edge[2]` are read here)
function GraphLayer:__init(id, global_conf, layer_conf)
    self.id = id
    self.dim_in = layer_conf.dim_in
    self.dim_out = layer_conf.dim_out
    self.gconf = global_conf
    self:graph_init(layer_conf.layer_repo, layer_conf.connections)
end

--- Split a port spec such as "layer1[2]" into layer id and port number.
-- Ids made of letters, digits, '_' and '.' are accepted directly; any other
-- id is accepted only if it is INPUT_ID or OUTPUT_ID.
-- NOTE(review): the first pattern is unanchored, so e.g. "a-b[1]" parses as
-- id "b"; consider anchoring it if no caller relies on that leniency.
-- @param str port spec of the form "id[port]"
-- @return the layer id (string) and the port (number)
local function parse_id(str)
    local _, id, port
    _, _, id, port = string.find(str, "([a-zA-Z0-9_.]+)%[([0-9]+)%]")
    if id == nil or port == nil then
        -- fall back to arbitrary ids so the reserved "<...>" names can match
        _, _, id, port = string.find(str, "(.+)%[([0-9]+)%]")
        if not (id == INPUT_ID or id == OUTPUT_ID) then
            nerv.error('wrong format of connection id: %s', str)
        end
    end
    return id, tonumber(port)
end

--- Return the bookkeeping record for layer `id`, creating it on first use.
-- OUTPUT_ID is folded onto the INPUT_ID pseudo-layer, whose record is
-- created up front by graph_init (so it is never fetched from layer_repo).
-- @param id layer id as returned by parse_id
-- @param layers table of records collected so far, keyed by id
-- @param layer_repo repository providing `get_layer(id)`
-- @return record with fields layer, inputs, outputs, dim_in, dim_out
local function discover(id, layers, layer_repo)
    if id == OUTPUT_ID then
        id = INPUT_ID
    end
    local ref = layers[id]
    if ref == nil then
        local layer = layer_repo:get_layer(id)
        local dim_in, dim_out = layer:get_dim()
        ref = {
            layer = layer,
            inputs = {},
            outputs = {},
            dim_in = dim_in,
            dim_out = dim_out,
        }
        layers[id] = ref
    end
    return ref
end

--- Raise an error for the first port in 1..#dims not marked in `attached`.
-- @param dims dimension list whose length is the number of ports
-- @param attached set of connected port indices
-- @param fmt nerv.error format taking (port index, layer name)
-- @param name layer name used in the message
local function check_attached(dims, attached, fmt, name)
    for i = 1, #dims do
        if attached[i] == nil then
            nerv.error(fmt, i, name)
        end
    end
end

--- Validate `connections` and register the referenced sub-layers.
-- Every edge must join an existing output port to an existing input port
-- of equal dimension; each input port may be fed at most once; and every
-- port of every sub-layer and of the graph boundary must end up connected.
-- On success each sub-layer is stored in `self.sublayer.layers`.
-- @param layer_repo repository the sub-layers are fetched from
-- @param connections table of edges `{from_spec, to_spec, ...}`
function GraphLayer:graph_init(layer_repo, connections)
    self.connections = connections
    self.sublayer = nerv.LayerRepo({}, nerv.ParamRepo(), self.gconf)

    -- The graph boundary is a pseudo-layer with mirrored ports: the graph's
    -- inputs are its outputs (edge sources) and the graph's outputs are its
    -- inputs (edge destinations).
    local layers = {}
    layers[INPUT_ID] = {
        inputs = {},
        outputs = {},
        dim_in = self.dim_out,
        dim_out = self.dim_in,
    }

    -- check data dimension between connected ports
    for _, edge in pairs(self.connections) do
        local from, to = edge[1], edge[2]
        local id_from, port_from = parse_id(from)
        local id_to, port_to = parse_id(to)
        -- without these checks the folding in discover would silently turn
        -- "<output>[i]" into a graph input and "<input>[i]" into an output
        if id_from == OUTPUT_ID then
            nerv.error('%s cannot be the source of a connection', from)
        end
        if id_to == INPUT_ID then
            nerv.error('%s cannot be the destination of a connection', to)
        end
        local ref_from = discover(id_from, layers, layer_repo)
        local ref_to = discover(id_to, layers, layer_repo)
        local dim_from = ref_from.dim_out[port_from]
        local dim_to = ref_to.dim_in[port_to]
        -- a nonexistent port yields nil; two nils would otherwise compare
        -- equal and slip past the dimension check below
        if dim_from == nil then
            nerv.error('no such output port: %s', from)
        end
        if dim_to == nil then
            nerv.error('no such input port: %s', to)
        end
        if ref_to.inputs[port_to] ~= nil then
            nerv.error('%s has already been attached', to)
        end
        if dim_from ~= dim_to then
            nerv.error('mismatching data dimension between %s and %s', from, to)
        end
        ref_from.outputs[port_from] = true
        ref_to.inputs[port_to] = true
    end

    -- check dangling ports of the sub-layers and register them
    for id, ref in pairs(layers) do
        if id ~= INPUT_ID then
            check_attached(ref.dim_in, ref.inputs,
                           'dangling input port %d of layer %s', id)
            check_attached(ref.dim_out, ref.outputs,
                           'dangling output port %d of layer %s', id)
            self.sublayer.layers[id] = ref.layer
        end
    end

    -- check dangling ports of the graph boundary itself
    local boundary = layers[INPUT_ID]
    check_attached(self.dim_in, boundary.outputs,
                   'dangling port %d of layer %s', INPUT_ID)
    check_attached(self.dim_out, boundary.inputs,
                   'dangling port %d of layer %s', OUTPUT_ID)
end

--- Set attribute `name` to `value` on this layer and on every sub-layer.
function GraphLayer:set_attr(name, value)
    self[name] = value
    for _, layer in pairs(self.sublayer.layers) do
        layer:set_attr(name, value)
    end
end

--- Return the sub-layer registered under `id`.
function GraphLayer:get_sublayer(id)
    return self.sublayer:get_layer(id)
end

--- Collect the parameter repos of all sub-layers and merge them into one.
function GraphLayer:get_params()
    local param_repos = {}
    for _, layer in pairs(self.sublayer.layers) do
        -- plain assignment keeps only the first return value, unlike
        -- table.insert, which would misread a second value as a position
        param_repos[#param_repos + 1] = layer:get_params()
    end
    return nerv.ParamRepo.merge(param_repos)
end