aboutsummaryrefslogblamecommitdiff
path: root/nerv/layer/graph.lua
blob: 5b5d4c724d17257a1c6c8bdca70c2400b3c94371 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11










                                                                                 

                                                              



















































                                                                                        
                                                       
                                                        

                                                       














                                                                     























                                                             
                                            


                            
                              



                                               
                                           





                              
                                





                        





                                                  
                                                       



                               





                              
               
     






                                                        

                                                 

                                                           


                                                            





                                                                                


                                                                               

                                          
                                                                                          











                                                                           
                                                                           

                   















                                                               



                                           



                                    



                                                     



                                



                                                             
       
                                                           
   



                                    
--- Implements a special kind of layers having an internal structure, a
-- directed graph of connected sub-level layers.

--- The class describing the concept of a graph layer having an internal
-- structure, a directed graph of connected sub-level layers. Some of these
-- sub-level layers can again be graph layers, thus, it enables nested and
-- recursive layer declaration. The graph layer can be regarded as a container of
-- its sub-level layers. A layer other than a graph layer is also referred to as
-- a "*primitive layer*".
-- @type nerv.GraphLayer

-- Declare the class through `nerv.class` under the name 'nerv.GraphLayer'; the
-- second name, 'nerv.Layer', is presumably the parent class (TODO confirm
-- against the definition of `nerv.class`).
local GraphLayer = nerv.class('nerv.GraphLayer', 'nerv.Layer')

--- The constructor.
-- @param id the identifier
-- @param global_conf see `self.gconf` of `nerv.Layer.__init`
-- @param layer_conf a table providing settings dedicated to the layer,
-- the following fields should be specified:
--
-- * `layer_repo`: the layer repo that should be used to find the sub-level layers
-- * `connections`: an array of 3-tuples describing the connections of
--   sub-level layers, the structure is as follows:
--
--        {
--            {<from_port1>, <to_port1>, <time_shift1>}, -- tuple 1
--            {<from_port2>, <to_port2>, <time_shift2>}, -- tuple 2
--            {<from_port3>, <to_port3>, <time_shift3>}, -- tuple 3
--            ...
--        }
--   Each tuple stands for a directed edge between two ports. The first two
--   elements in the tuple are called *port specification* which is a string
--   with the following format:
--
--        <layer_id>[<port_idx>]
--   where the `<layer_id>` is a string that identifies the layer in
--   `layer_conf.layer_repo`, and `<port_idx>` is the output or input port index
--   when used in the first or second port specification respectively.
--
--   The third element in the tuple is an integer specifying the time delay of
--   this connection. In most cases, it will simply be zero. But for a
--   recurrent network, a positive value `i` means the output from `<from_port>`
--   will be used as the input to `<to_port>` in the `i`-th computation in the
--   future. Negative values are also allowed to propagate the output to the past.
--
--   Note that there are two possible strings of `<layer_id>` that have special
--   meanings: the strings `"<input>"` and `"<output>"` are placeholders of
--   the input and output ports of the outer graph layer. The input for the graph
--   layer as a whole can be used by establishing connections from
--   `"<input>[i]"`, and vice versa for the output.
--
--   As an example, tuples:
--
--        {
--            {"<input>[1]", "affine0[1]", 0},
--            {"affine0[1]", "sigmoid0[1]", 0},
--            {"sigmoid0[1]", "affine1[1]", 0},
--            {"affine1[1]", "<output>[1]", 0}
--        }
--   specify a graph layer that contains two stacked and fully connected linear
--   transformation sub-level layers.
--
-- * `reversed`: optional, reverse the time shifting of all connections if true
--
-- For other `layer_conf` fields that are shared by all layers, see `nerv.Layer.__init`.

function GraphLayer:__init(id, global_conf, layer_conf)
    nerv.Layer.__init(self, id, global_conf, layer_conf)
    -- The layer repo is read from `layer_repo`; `lrepo` is accepted as an
    -- alternative spelling of the same field so that a configuration using
    -- either name works.
    local layer_repo = layer_conf.layer_repo
    if layer_repo == nil then
        layer_repo = layer_conf.lrepo
    end
    -- Fail early with a readable message instead of hitting a nil-index error
    -- deep inside `graph_init` when a mandatory field is missing.
    if layer_repo == nil then
        nerv.error('layer repo is not specified for graph layer %s',
                   tostring(id))
    end
    if type(layer_conf.connections) ~= 'table' then
        nerv.error('connections of graph layer %s should be a table',
                   tostring(id))
    end
    self.lrepo = layer_repo
    self:graph_init(layer_repo, layer_conf.connections)
end

--- Parse a port specification of the form `<layer_id>[<port_idx>]`.
-- A regular layer id may only consist of letters, digits, `_` and `.`; the
-- only other accepted ids are the placeholders `<input>` and `<output>`.
-- The whole string has to match: both patterns are anchored, so text before
-- or after the specification (e.g. `"my-layer[1]"` or `"a[1]junk"`) is
-- rejected instead of being silently reduced to a shorter id.
-- @param str the port specification string
-- @return the layer id (string) and the port index (number)
local function parse_id(str)
    local id, port = string.match(str, "^([a-zA-Z0-9_.]+)%[([0-9]+)%]$")
    if id == nil or port == nil then
        -- not a plain id: the only remaining valid forms are the placeholders
        id, port = string.match(str, "^(.+)%[([0-9]+)%]$")
        if not (id == "<input>" or id == "<output>") then
            nerv.error("wrong format of connection id")
        end
    end
    port = tonumber(port)
    return id, port
end

--- Qualify sub-layer names with the id of this graph layer.
-- Every name in `layers` and every port specification in `connections`
-- (except those referring to `<input>` as a source or `<output>` as a
-- destination) is rewritten to `<self.id>.<name>`. Both tables are modified.
-- @param layers a table mapping a layer type to a table of
-- `name -> layer config`
-- @param connections an array of connection tuples whose first two elements
-- are port specification strings
function GraphLayer:add_prefix(layers, connections)
    local function prefixed(name)
        return self.id .. '.' .. name
    end

    -- rebuild each per-type table with prefixed keys
    for layer_type, sublayers in pairs(layers) do
        local renamed = {}
        for name, layer_config in pairs(sublayers) do
            renamed[prefixed(name)] = layer_config
        end
        layers[layer_type] = renamed
    end

    -- rewrite both endpoints of each edge, leaving the placeholders untouched
    for i = 1, #connections do
        local edge = connections[i]
        local from, to = edge[1], edge[2]
        local from_id = parse_id(from)
        if from_id ~= '<input>' then
            edge[1] = prefixed(from)
        end
        local to_id = parse_id(to)
        if to_id ~= '<output>' then
            edge[2] = prefixed(to)
        end
    end
end

--- Look up the bookkeeping record of a layer, creating it on first use.
-- The id `<output>` is mapped onto the `<input>` record. For an id that has
-- no record yet, the layer is fetched from `layer_repo`, `self.layer_num` is
-- incremented and used as the numeric id of the new record.
-- @param id the layer id taken from a port specification
-- @param layer_repo the repo queried through `get_layer` for unknown ids
-- @return the record with fields `layer`, `inputs`, `outputs`, `dim_in`,
-- `dim_out` and `id` (the placeholder record has no `layer` field)
function GraphLayer:discover(id, layer_repo)
    local key = id
    if key == '<output>' then
        key = '<input>'
    end

    local known = self.layers
    local existing = known[key]
    if existing ~= nil then
        return existing
    end

    local sublayer = layer_repo:get_layer(key)
    local dim_in, dim_out = sublayer:get_dim()
    local index = self.layer_num + 1
    self.layer_num = index

    local created = {
        layer = sublayer,
        inputs = {},
        outputs = {},
        dim_in = dim_in,
        dim_out = dim_out,
        id = index,
    }
    known[key] = created
    return created
end

--- Negate the time shift (third element) of every connection tuple.
-- The tuples are modified in place; nothing is returned.
-- @param connections an array of `{from, to, time_shift}` tuples
local function reverse(connections)
    for idx = 1, #connections do
        local edge = connections[idx]
        edge[3] = -edge[3]
    end
end

--- Build and validate the internal graph from a list of connections.
-- Fills `self.layers` (id -> bookkeeping record), `self.layer_num` and
-- `self.connections`, an array of
-- `{from_layer_index, from_port, to_layer_index, to_port, time_shift}`.
-- Raises an error (through `nerv.error`) when a port does not exist, is
-- attached twice, has a mismatching dimension, or is left dangling, and when
-- `<input>` is wired directly to `<output>`.
-- @param layer_repo the repo used to resolve sub-level layers
-- @param connections an array of `{from, to, time_shift}` tuples; when
-- `self.lconf.reversed` is set, the time shifts are negated in place, i.e.
-- the table passed by the caller is modified
function GraphLayer:graph_init(layer_repo, connections)
    if self.lconf.reversed then
        reverse(connections)
    end

    -- A single placeholder record (index 0) stands for both `<input>` and
    -- `<output>`: seen from inside the graph, the outer inputs act as the
    -- placeholder's outputs and the outer outputs as its inputs, hence the
    -- swapped dimensions.
    local layers = {}
    layers['<input>'] = {
        inputs = {},
        outputs = {},
        dim_in = self.dim_out,
        dim_out = self.dim_in,
        id = 0,
    }
    self.layers = layers
    self.layer_num = 0
    self.connections = {}

    -- check data dimension between connected ports
    for _, edge in pairs(connections) do
        local from, to, time = edge[1], edge[2], edge[3]
        local id_from, port_from = parse_id(from)
        local id_to, port_to = parse_id(to)
        local ref_from = self:discover(id_from, layer_repo)
        local ref_to = self:discover(id_to, layer_repo)
        -- Reject port indices outside the declared dimensions. Without this,
        -- an edge whose both ports are out of range compares nil with nil,
        -- passes the dimension check below and is never seen by the
        -- dangling-port check.
        if ref_from.dim_out[port_from] == nil then
            nerv.error('%s refers to a nonexistent port', from)
        end
        if ref_to.dim_in[port_to] == nil then
            nerv.error('%s refers to a nonexistent port', to)
        end
        if ref_from.outputs[port_from] ~= nil then
            nerv.error('%s has already been attached', from)
        end
        if ref_to.inputs[port_to] ~= nil then
            nerv.error('%s has already been attached', to)
        end
        if ref_from.dim_out[port_from] ~= ref_to.dim_in[port_to] then
            nerv.error('mismatching data dimension between %s and %s', from, to)
        end
        if ref_from.id == 0 and ref_to.id == 0 then
            nerv.error('short-circuit connection between <input> and <output>')
        end
        ref_from.outputs[port_from] = true
        ref_to.inputs[port_to] = true
        table.insert(self.connections, {ref_from.id, port_from, ref_to.id, port_to, time})
    end

    -- check dangling ports
    for id, ref in pairs(layers) do
        if id ~= '<input>' then
            for i = 1, #ref.dim_in do
                if ref.inputs[i] == nil then
                    nerv.error('dangling input port %d of layer %s', i, id)
                end
            end
            for i = 1, #ref.dim_out do
                if ref.outputs[i] == nil then
                   nerv.error('dangling output port %d of layer %s', i, id)
                end
            end
        end
    end
    -- the placeholder is checked against the outer dimensions of this layer
    for i = 1, #self.dim_in do
        if layers['<input>'].outputs[i] == nil then
            nerv.error('dangling port %d of layer <input>', i)
        end
    end
    for i = 1, #self.dim_out do
        if layers['<input>'].inputs[i] == nil then
            nerv.error('dangling port %d of layer <output>', i)
        end
    end
end

--- Set an attribute on this layer and forward it to every sub-level layer.
-- The `<input>` placeholder record is skipped.
-- @param name the attribute name
-- @param value the value to assign
function GraphLayer:set_attr(name, value)
    self[name] = value
    for layer_id, record in pairs(self.layers) do
        local is_placeholder = (layer_id == '<input>')
        if not is_placeholder then
            local sublayer = record.layer
            sublayer:set_attr(name, value)
        end
    end
end

--- Fetch a sub-level layer by its id.
-- Reports an error through `nerv.error` when the id is unknown or is the
-- `<input>` placeholder.
-- @param id the id of the sub-level layer
-- @return the layer object stored for `id`
function GraphLayer:get_sublayer(id)
    local record = self.layers[id]
    local missing = (record == nil) or (id == '<input>')
    if missing then
        nerv.error('layer with id %s not found', id)
    end
    return self.layers[id].layer
end

--- Collect the parameters of all sub-level layers.
-- Calls `get_params` on every sub-level layer (the `<input>` placeholder is
-- skipped) and hands the collected results, together with `self.loc_type`,
-- to `nerv.ParamRepo.merge`.
-- @return whatever `nerv.ParamRepo.merge` returns
function GraphLayer:get_params()
    local collected = {}
    for layer_id, record in pairs(self.layers) do
        if layer_id ~= '<input>' then
            local sublayer = record.layer
            table.insert(collected, sublayer:get_params())
        end
    end
    local loc_type = self.loc_type
    return nerv.ParamRepo.merge(collected, loc_type)
end

--- Rebind the layer repo of this graph layer.
-- Calls `rebind` on `self.lrepo` with `self.lconf.pr` as the only argument.
function GraphLayer:bind_params()
    local repo = self.lrepo
    repo:rebind(self.lconf.pr)
end