diff options
Diffstat (limited to 'nerv/layer/graph.lua')
-rw-r--r-- | nerv/layer/graph.lua | 46 |
1 files changed, 30 insertions, 16 deletions
diff --git a/nerv/layer/graph.lua b/nerv/layer/graph.lua index 83cf810..36a9672 100644 --- a/nerv/layer/graph.lua +++ b/nerv/layer/graph.lua @@ -21,20 +21,23 @@ local function parse_id(str) return id, port end -local function discover(id, layers, layer_repo) +function GraphLayer:discover(id, layer_repo) if id == '<output>' then id = '<input>' end + local layers = self.layers local ref = layers[id] if ref == nil then local layer = layer_repo:get_layer(id) local dim_in, dim_out = layer:get_dim() + self.layer_num = self.layer_num + 1 ref = { layer = layer, inputs = {}, outputs = {}, dim_in = dim_in, dim_out = dim_out, + id = self.layer_num, } layers[id] = ref end @@ -42,32 +45,37 @@ local function discover(id, layers, layer_repo) end function GraphLayer:graph_init(layer_repo, connections) - self.connections = connections - self.sublayer = nerv.LayerRepo({}, nerv.ParamRepo(), self.gconf) - - -- check data dimension between connected ports local layers = {} layers['<input>'] = { inputs = {}, outputs = {}, dim_in = self.dim_out, dim_out = self.dim_in, + id = 0, } - for _, edge in pairs(self.connections) do - local from = edge[1] - local to = edge[2] + self.layers = layers + self.layer_num = 0 + self.connections = {} + + -- check data dimension between connected ports + for _, edge in pairs(connections) do + local from, to, time = edge[1], edge[2], edge[3] local id_from, port_from = parse_id(from) local id_to, port_to = parse_id(to) - local ref_from = discover(id_from, layers, layer_repo) - local ref_to = discover(id_to, layers, layer_repo) + local ref_from = self:discover(id_from, layer_repo) + local ref_to = self:discover(id_to, layer_repo) if ref_to.inputs[port_to] ~= nil then nerv.error('%s has already been attached', to) end if ref_from.dim_out[port_from] ~= ref_to.dim_in[port_to] then nerv.error('mismatching data dimension between %s and %s', from, to) end + if ref_from.id == 0 and ref_to.id == 0 then + nerv.error('short-circuit connection between <input> and <output>') + end ref_from.outputs[port_from] = true ref_to.inputs[port_to] = true + table.insert(self.connections, {ref_from.id, port_from, ref_to.id, port_to, time}) end -- check dangling ports @@ -83,7 +91,6 @@ function GraphLayer:graph_init(layer_repo, connections) nerv.error('dangling output port %d os layer %s', i, id) end end - self.sublayer.layers[id] = ref.layer end end for i = 1, #self.dim_in do @@ -100,19 +107,26 @@ end function GraphLayer:set_attr(name, value) self[name] = value - for id, layer in pairs(self.sublayer.layers) do - layer:set_attr(name, value) + for id, ref in pairs(self.layers) do + if id ~= '<input>' then + ref.layer:set_attr(name, value) + end end end function GraphLayer:get_sublayer(id) - return self.sublayer:get_layer(id) + if self.layers[id] == nil or id == '<input>' then + nerv.error('layer with id %s not found', id) + end + return self.layers[id].layer end function GraphLayer:get_params() local param_repos = {} - for id, layer in pairs(self.sublayer.layers) do - table.insert(param_repos, layer:get_params()) + for id, ref in pairs(self.layers) do + if id ~= '<input>' then + table.insert(param_repos, ref.layer:get_params()) + end end return nerv.ParamRepo.merge(param_repos) end |