aboutsummaryrefslogtreecommitdiff
path: root/nerv/layer/graph.lua
blob: 83cf8103e9a62b408a6093493d0f4047b422b1da (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
local GraphLayer = nerv.class('nerv.GraphLayer', 'nerv.Layer')

--- Construct a graph layer.
-- Copies the port dimensions out of `layer_conf`, keeps the global
-- configuration, then validates and wires the sub-layer graph through
-- `graph_init`.
-- @param id identifier of this layer
-- @param global_conf global configuration, stored as `self.gconf`
-- @param layer_conf table providing `dim_in`, `dim_out`, `layer_repo`
--   and `connections`
function GraphLayer:__init(id, global_conf, layer_conf)
    local conf = layer_conf
    self.id = id
    self.dim_in = conf.dim_in
    self.dim_out = conf.dim_out
    self.gconf = global_conf
    local repo, edges = conf.layer_repo, conf.connections
    self:graph_init(repo, edges)
end

--- Split a connection endpoint such as "affine0[1]" into layer id and port.
-- The id must consist only of letters, digits, '_' or '.', except for the
-- two pseudo-layers "<input>" and "<output>". The WHOLE string must match:
-- the patterns are anchored so that e.g. "my-layer[1]" or "x[1]junk" are
-- rejected instead of being silently truncated to a different layer id.
-- Raises via nerv.error on malformed input.
-- @tparam string str endpoint of the form "id[port]"
-- @treturn string layer id
-- @treturn number port index as written (not range-checked here)
local function parse_id(str)
    local id, port = string.match(str, "^([a-zA-Z0-9_.]+)%[([0-9]+)%]$")
    if id == nil then
        id, port = string.match(str, "^(<input>)%[([0-9]+)%]$")
    end
    if id == nil then
        id, port = string.match(str, "^(<output>)%[([0-9]+)%]$")
    end
    if id == nil then
        nerv.error("wrong format of connection id: %s", tostring(str))
    end
    return id, tonumber(port)
end

--- Fetch (creating on first use) the bookkeeping entry for layer `id`.
-- '<output>' is folded onto the '<input>' entry, so both pseudo-layers share
-- one record. For any other id not yet in `layers`, the layer is pulled from
-- `layer_repo`, its dimensions are read with `get_dim`, and a fresh entry
-- with empty `inputs`/`outputs` tables is cached in `layers`.
-- @param id layer id as returned by parse_id
-- @param layers table of entries keyed by layer id (mutated)
-- @param layer_repo repository providing `get_layer`
-- @return the entry table for `id`
local function discover(id, layers, layer_repo)
    local key = (id == '<output>') and '<input>' or id
    local existing = layers[key]
    if existing ~= nil then
        return existing
    end
    local layer = layer_repo:get_layer(key)
    local dim_in, dim_out = layer:get_dim()
    local entry = {
        layer = layer,
        inputs = {},
        outputs = {},
        dim_in = dim_in,
        dim_out = dim_out,
    }
    layers[key] = entry
    return entry
end

--- Validate the connection graph and register the sub-layers it uses.
-- Each edge in `connections` holds a source endpoint in edge[1] and a
-- destination endpoint in edge[2], both in "id[port]" form (see parse_id).
-- The pseudo-layers <input>/<output> stand for this layer's own ports and
-- share a single bookkeeping entry whose dimensions are swapped: an
-- <input> port acts as a source, an <output> port as a sink.
--
-- Raises (via nerv.error) when an endpoint names a port that does not
-- exist, an input port is fed more than once, connected ports disagree on
-- dimension, or any port of a used sub-layer or of this layer is left
-- unconnected. Every sub-layer touched by a connection is stored in
-- `self.sublayer`.
-- @param layer_repo repository the sub-layers are looked up in
-- @param connections table of edges, stored as `self.connections`
function GraphLayer:graph_init(layer_repo, connections)
    self.connections = connections
    self.sublayer = nerv.LayerRepo({}, nerv.ParamRepo(), self.gconf)

    -- bookkeeping for every layer touched by a connection; the '<input>'
    -- entry also serves '<output>' (see discover), hence the swapped dims
    local layers = {}
    layers['<input>'] = {
        inputs = {},
        outputs = {},
        dim_in = self.dim_out,
        dim_out = self.dim_in,
    }

    -- check data dimension between connected ports
    for _, edge in pairs(self.connections) do
        local from = edge[1]
        local to = edge[2]
        local id_from, port_from = parse_id(from)
        local id_to, port_to = parse_id(to)
        local ref_from = discover(id_from, layers, layer_repo)
        local ref_to = discover(id_to, layers, layer_repo)
        local dim_from = ref_from.dim_out[port_from]
        local dim_to = ref_to.dim_in[port_to]
        -- two out-of-range ports would both yield nil and slip through the
        -- equality test below, so reject non-existent ports explicitly
        if dim_from == nil then
            nerv.error('invalid source port %s', from)
        end
        if dim_to == nil then
            nerv.error('invalid destination port %s', to)
        end
        if ref_to.inputs[port_to] ~= nil then
            nerv.error('%s has already been attached', to)
        end
        if dim_from ~= dim_to then
            nerv.error('mismatching data dimension between %s and %s', from, to)
        end
        ref_from.outputs[port_from] = true
        ref_to.inputs[port_to] = true
    end

    -- check dangling ports of every used sub-layer and register it
    for id, ref in pairs(layers) do
        if id ~= '<input>' then
            for i = 1, #ref.dim_in do
                if ref.inputs[i] == nil then
                    nerv.error('dangling input port %d of layer %s', i, id)
                end
            end
            for i = 1, #ref.dim_out do
                if ref.outputs[i] == nil then
                    nerv.error('dangling output port %d of layer %s', i, id)
                end
            end
            self.sublayer.layers[id] = ref.layer
        end
    end
    -- every port of this graph layer itself must be connected as well
    for i = 1, #self.dim_in do
        if layers['<input>'].outputs[i] == nil then
            nerv.error('dangling port %d of layer <input>', i)
        end
    end
    for i = 1, #self.dim_out do
        if layers['<input>'].inputs[i] == nil then
            nerv.error('dangling port %d of layer <output>', i)
        end
    end
end

--- Assign `value` to attribute `name` on this layer, then forward the same
-- assignment to every sub-layer registered in `self.sublayer`.
-- @param name attribute name
-- @param value attribute value
function GraphLayer:set_attr(name, value)
    self[name] = value
    local children = self.sublayer.layers
    for _, child in pairs(children) do
        child:set_attr(name, value)
    end
end

--- Look up a sub-layer by id in this graph's own layer repository.
-- @param id sub-layer id
-- @return whatever `self.sublayer:get_layer(id)` returns
function GraphLayer:get_sublayer(id)
    local repo = self.sublayer
    return repo:get_layer(id)
end

--- Gather the parameter repositories of all sub-layers and merge them.
-- @return the result of `nerv.ParamRepo.merge` over the collected repos
function GraphLayer:get_params()
    local collected = {}
    for _, child in pairs(self.sublayer.layers) do
        table.insert(collected, child:get_params())
    end
    local merge = nerv.ParamRepo.merge
    return merge(collected)
end