diff options
Diffstat (limited to 'nerv/layer/graph.lua')
-rw-r--r-- | nerv/layer/graph.lua | 156 |
1 files changed, 156 insertions, 0 deletions
diff --git a/nerv/layer/graph.lua b/nerv/layer/graph.lua new file mode 100644 index 0000000..68d5f51 --- /dev/null +++ b/nerv/layer/graph.lua @@ -0,0 +1,156 @@ +local GraphLayer = nerv.class('nerv.GraphLayer', 'nerv.Layer') + +function GraphLayer:__init(id, global_conf, layer_conf) + nerv.Layer.__init(self, id, global_conf, layer_conf) + self:graph_init(layer_conf.layer_repo, layer_conf.connections) +end + +local function parse_id(str) + local id, port, _ + _, _, id, port = string.find(str, "([a-zA-Z0-9_.]+)%[([0-9]+)%]") + if id == nil or port == nil then + _, _, id, port = string.find(str, "(.+)%[([0-9]+)%]") + if not (id == "<input>" or id == "<output>") then + nerv.error("wrong format of connection id") + end + end + port = tonumber(port) + return id, port +end + +function GraphLayer:add_prefix(layers, connections) + local function ap(name) + return self.id .. '.' .. name + end + + for layer_type, sublayers in pairs(layers) do + local tmp = {} + for name, layer_config in pairs(sublayers) do + tmp[ap(name)] = layer_config + end + layers[layer_type] = tmp + end + + for i = 1, #connections do + local from, to = connections[i][1], connections[i][2] + if parse_id(from) ~= '<input>' then + connections[i][1] = ap(from) + end + if parse_id(to) ~= '<output>' then + connections[i][2] = ap(to) + end + end +end + +function GraphLayer:discover(id, layer_repo) + if id == '<output>' then + id = '<input>' + end + local layers = self.layers + local ref = layers[id] + if ref == nil then + local layer = layer_repo:get_layer(id) + local dim_in, dim_out = layer:get_dim() + self.layer_num = self.layer_num + 1 + ref = { + layer = layer, + inputs = {}, + outputs = {}, + dim_in = dim_in, + dim_out = dim_out, + id = self.layer_num, + } + layers[id] = ref + end + return ref +end + +function GraphLayer:graph_init(layer_repo, connections) + local layers = {} + layers['<input>'] = { + inputs = {}, + outputs = {}, + dim_in = self.dim_out, + dim_out = self.dim_in, + id = 0, + } + self.layers = layers + self.layer_num = 0 + self.connections = {} + + -- check data dimension between connected ports + for _, edge in pairs(connections) do + local from, to, time = edge[1], edge[2], edge[3] + local id_from, port_from = parse_id(from) + local id_to, port_to = parse_id(to) + local ref_from = self:discover(id_from, layer_repo) + local ref_to = self:discover(id_to, layer_repo) + if ref_from.outputs[port_from] ~= nil then + nerv.error('%s has already been attached', from) + end + if ref_to.inputs[port_to] ~= nil then + nerv.error('%s has already been attached', to) + end + if ref_from.dim_out[port_from] ~= ref_to.dim_in[port_to] then + nerv.error('mismatching data dimension between %s and %s', from, to) + end + if ref_from.id == 0 and ref_to.id == 0 then + nerv.error('short-circuit connection between <input> and <output>') + end + ref_from.outputs[port_from] = true + ref_to.inputs[port_to] = true + table.insert(self.connections, {ref_from.id, port_from, ref_to.id, port_to, time}) + end + + -- check dangling ports + for id, ref in pairs(layers) do + if id ~= '<input>' then + for i = 1, #ref.dim_in do + if ref.inputs[i] == nil then + nerv.error('dangling input port %d of layer %s', i, id) + end + end + for i = 1, #ref.dim_out do + if ref.outputs[i] == nil then + nerv.error('dangling output port %d of layer %s', i, id) + end + end + end + end + for i = 1, #self.dim_in do + if layers['<input>'].outputs[i] == nil then + nerv.error('dangling port %d of layer <input>', i) + end + end + for i = 1, #self.dim_out do + if layers['<input>'].inputs[i] == nil then + nerv.error('dangling port %d of layer <output>', i) + end + end +end + +function GraphLayer:set_attr(name, value) + self[name] = value + for id, ref in pairs(self.layers) do + if id ~= '<input>' then + ref.layer:set_attr(name, value) + end + end +end + +function GraphLayer:get_sublayer(id) + if self.layers[id] == nil or id == '<input>' then + nerv.error('layer with id %s not found', id) + end + return self.layers[id].layer +end + +function GraphLayer:get_params() + local param_repos = {} + for id, ref in pairs(self.layers) do + if id ~= '<input>' then + table.insert(param_repos, ref.layer:get_params()) + end + end + return nerv.ParamRepo.merge(param_repos, self.loc_type) +end |