aboutsummaryrefslogtreecommitdiff
path: root/nerv/nn/network.lua
blob: 6cee08b5face7940236828da193ffa2bc888c078 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
-- Class table for `nerv.Network`. An instance compiles a (possibly nested)
-- layer graph into a flat list of leaf layers plus explicit port-to-port
-- connections between them (see network:compile below).
local network = nerv.class('nerv.Network')

--- Build a network from `graph` by flattening it into leaf layers.
--
-- Stores the compiled leaf list in `self.layers` and the root socket in
-- `self.socket`, then dumps the result to stdout: each leaf layer's id on
-- its own line, followed by one line per outgoing connection of the form
-- `source_index  dst_index  dst_port  time`.
--
-- @param graph the root layer to compile (a leaf layer or a nerv.GraphLayer)
function network:__init(graph)
    self.layers = {}
    self.socket = self:compile(graph)

    local leaf_count = #self.layers
    for idx = 1, leaf_count do
        local entry = self.layers[idx]
        print(entry.layer.id)

        -- only the output dimensions matter here: connections are keyed by
        -- output port
        local _, outputs = entry.layer:get_dim()
        for out_port = 1, #outputs do
            local targets = entry.connections[out_port]
            for n = 1, #targets do
                local dst = targets[n]
                print(idx, dst[1], dst[2], dst[3])
            end
        end
    end
end

--- Compile `layer` into the flat leaf list `self.layers` and return its socket.
--
-- A leaf layer (anything that is not a `nerv.GraphLayer`) is appended to
-- `self.layers` as `{layer = layer, connections = {}}`; its index in that
-- list is its id. `connections[port]` starts empty and is filled by the
-- enclosing graphs with `{dst_id, dst_port, time}` triples naming every leaf
-- input that this output port feeds.
--
-- A graph layer is compiled by recursively compiling each entry of
-- `layer.layers` except '<input>', then resolving every edge of
-- `layer.connections` (`{id_from, port_from, id_to, port_to, time}`, where
-- id 0 means the graph's own <input>/<output>) down to leaf endpoints,
-- adding up the edges' time offsets along the way.
--
-- @param layer a leaf layer or a `nerv.GraphLayer`
-- @return socket `{inputs = ..., outputs = ...}`:
--   `inputs[port]` is a list of `{id, port, time}` leaf endpoints that this
--   input port fans out to; `outputs[port]` is the single `{id, port, time}`
--   leaf endpoint driving this output port (nil for a graph output port that
--   no edge drives).
-- @raise error (string) when a graph edge names an unknown sublayer, a
--   missing sublayer input port, a sublayer output port with no driver, or
--   connects <input> directly to <output> (a socket cannot express that).
function network:compile(layer)
    local socket = {inputs = {}, outputs = {}}
    if not nerv.is_type(layer, 'nerv.GraphLayer') then
        table.insert(self.layers, {layer = layer, connections = {}})
        local id = #self.layers
        local dim_in, dim_out = layer:get_dim()
        for i = 1, #dim_in do
            -- a leaf input port reaches exactly itself, with no time offset
            socket.inputs[i] = {{id, i, 0}}
        end
        for i = 1, #dim_out do
            socket.outputs[i] = {id, i, 0}
            self.layers[id].connections[i] = {}
        end
    else
        local sublayer_socket = {}
        for id, sublayer in pairs(layer.layers) do
            if id ~= '<input>' then
                sublayer_socket[sublayer.id] = self:compile(sublayer.layer)
            end
        end

        -- The helpers below raise a descriptive error in exactly the cases
        -- where a plain lookup would hit a nil and fail with an opaque
        -- "attempt to index a nil value".
        local function fail(fmt, ...)
            error(string.format('graph layer %s: ' .. fmt,
                                tostring(layer.id), ...))
        end
        local function sub_socket(id)
            local s = sublayer_socket[id]
            if s == nil then
                fail('edge references unknown sublayer %s', tostring(id))
            end
            return s
        end
        local function sub_inputs(id, port)
            local endpoints = sub_socket(id).inputs[port]
            if endpoints == nil then
                fail('sublayer %s has no input port %s',
                     tostring(id), tostring(port))
            end
            return endpoints
        end
        local function sub_output(id, port)
            local endpoint = sub_socket(id).outputs[port]
            if endpoint == nil then
                fail('sublayer %s has no (driven) output port %s',
                     tostring(id), tostring(port))
            end
            return endpoint
        end

        local dim_in = layer:get_dim()
        for i = 1, #dim_in do
            socket.inputs[i] = {}
        end
        for _, edge in pairs(layer.connections) do
            -- id = 0 means <input> (as source) or <output> (as destination)
            local id_from, port_from = edge[1], edge[2]
            local id_to, port_to = edge[3], edge[4]
            local time = edge[5]
            if id_from == 0 and id_to == 0 then
                fail('direct edge from <input> port %s to <output> port %s '
                     .. 'is not supported',
                     tostring(port_from), tostring(port_to))
            elseif id_from == 0 then
                -- graph input -> sublayer: the graph's input port now fans
                -- out to every leaf endpoint behind the sublayer's input port
                for _, input in pairs(sub_inputs(id_to, port_to)) do
                    local id, port, t = input[1], input[2], input[3] + time
                    table.insert(socket.inputs[port_from], {id, port, t})
                end
            else
                local output = sub_output(id_from, port_from)
                local id, port, t = output[1], output[2], output[3] + time
                if id_to == 0 then
                    -- sublayer -> graph output: expose the driving leaf
                    socket.outputs[port_to] = {id, port, t}
                else
                    -- sublayer -> sublayer: record leaf-to-leaf edges on the
                    -- driving leaf's output port
                    local connections = self.layers[id].connections[port]
                    for _, input in pairs(sub_inputs(id_to, port_to)) do
                        local id1, port1, t1 = input[1], input[2], input[3]
                        table.insert(connections, {id1, port1, t + t1})
                    end
                end
            end
        end
    end
    return socket
end