1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
|
local GraphLayer = nerv.class('nerv.GraphLayer', 'nerv.Layer')
--- Construct a layer whose computation is described by a graph of sub-layers.
-- @param id identifier of this layer (stored as `self.id`)
-- @param global_conf global configuration (stored as `self.gconf`)
-- @param layer_conf layer configuration; this constructor reads its
--   `dim_in`, `dim_out`, `layer_repo` and `connections` fields
function GraphLayer:__init(id, global_conf, layer_conf)
    local conf = layer_conf
    self.id = id
    self.dim_in, self.dim_out = conf.dim_in, conf.dim_out
    self.gconf = global_conf
    -- graph_init relies on self.dim_in / self.dim_out being set above
    self:graph_init(conf.layer_repo, conf.connections)
end
--- Split a connection endpoint of the form "id[port]" into id and port.
-- Ordinary layer ids may only contain letters, digits, '_' and '.'; the
-- special ids "<input>" and "<output>" are accepted as well. The whole
-- string must match, so stray leading/trailing characters are rejected
-- instead of being silently dropped.
-- @tparam string str connection endpoint, e.g. "affine0[1]" or "<input>[2]"
-- @treturn string layer id
-- @treturn number port index
local function parse_id(str)
    -- Anchored: an unanchored pattern would parse "foo-bar[1]" as id "bar".
    local id, port = string.match(str, "^([a-zA-Z0-9_.]+)%[([0-9]+)%]$")
    if id == nil then
        id, port = string.match(str, "^(.+)%[([0-9]+)%]$")
        if id ~= "<input>" and id ~= "<output>" then
            nerv.error("wrong format of connection id")
        end
    end
    return id, tonumber(port)
end
--- Return the bookkeeping record for a layer id, creating it on first use.
-- "<output>" is folded onto the "<input>" record. New records fetch the
-- layer from `layer_repo`, query its dimensions and get the next numeric id.
-- @param id layer id as produced by parse_id
-- @param layer_repo object providing `get_layer(id)`
-- @return record with fields layer, inputs, outputs, dim_in, dim_out, id
function GraphLayer:discover(id, layer_repo)
    local key = id
    if key == '<output>' then
        key = '<input>'
    end
    local existing = self.layers[key]
    if existing ~= nil then
        return existing
    end
    local layer = layer_repo:get_layer(key)
    local dim_in, dim_out = layer:get_dim()
    self.layer_num = self.layer_num + 1
    local record = {
        layer = layer,
        inputs = {},
        outputs = {},
        dim_in = dim_in,
        dim_out = dim_out,
        id = self.layer_num,
    }
    self.layers[key] = record
    return record
end
--- Build the sub-layer graph from a list of connections and validate it.
-- The pseudo-layer "<input>" (numeric id 0) stands for the graph boundary:
-- as a source it provides the GraphLayer's inputs, and as a sink (written
-- "<output>") it receives the GraphLayer's outputs, which is why its
-- dim_in/dim_out are swapped relative to the GraphLayer's own.
-- Populates self.layers, self.layer_num and self.connections; each stored
-- connection is {from_layer_id, from_port, to_layer_id, to_port, time}, with
-- `time` copied verbatim from the third element of the edge.
-- Raises (via nerv.error) on malformed endpoints, "<output>" used as a
-- source or "<input>" used as a sink, out-of-range ports, an input port
-- attached twice, dimension mismatches, direct <input>-><output> links and
-- any port left unconnected.
-- @param layer_repo repository passed to discover() to fetch sub-layers
-- @param connections table of {from, to, time} edges, endpoints as "id[port]"
function GraphLayer:graph_init(layer_repo, connections)
    local layers = {}
    layers['<input>'] = {
        inputs = {},
        outputs = {},
        dim_in = self.dim_out,
        dim_out = self.dim_in,
        id = 0,
    }
    self.layers = layers
    self.layer_num = 0
    self.connections = {}
    -- check direction, port range and data dimension of every connection
    for _, edge in pairs(connections) do
        local from, to, time = edge[1], edge[2], edge[3]
        local id_from, port_from = parse_id(from)
        local id_to, port_to = parse_id(to)
        -- discover() folds "<output>" onto the "<input>" record, so a
        -- reversed boundary endpoint would otherwise be silently flipped.
        if id_from == '<output>' then
            nerv.error('<output> cannot be the source of a connection: %s', from)
        end
        if id_to == '<input>' then
            nerv.error('<input> cannot be the destination of a connection: %s', to)
        end
        local ref_from = self:discover(id_from, layer_repo)
        local ref_to = self:discover(id_to, layer_repo)
        -- Two out-of-range ports would both yield nil dimensions and slip
        -- through the equality test below, so range-check them explicitly.
        if ref_from.dim_out[port_from] == nil then
            nerv.error('output port out of range: %s', from)
        end
        if ref_to.dim_in[port_to] == nil then
            nerv.error('input port out of range: %s', to)
        end
        if ref_to.inputs[port_to] ~= nil then
            nerv.error('%s has already been attached', to)
        end
        if ref_from.dim_out[port_from] ~= ref_to.dim_in[port_to] then
            nerv.error('mismatching data dimension between %s and %s', from, to)
        end
        if ref_from.id == 0 and ref_to.id == 0 then
            nerv.error('short-circuit connection between <input> and <output>')
        end
        -- an output port may feed several inputs; an input port only one
        ref_from.outputs[port_from] = true
        ref_to.inputs[port_to] = true
        table.insert(self.connections, {ref_from.id, port_from, ref_to.id, port_to, time})
    end
    -- check dangling ports of the sub-layers
    for id, ref in pairs(layers) do
        if id ~= '<input>' then
            for i = 1, #ref.dim_in do
                if ref.inputs[i] == nil then
                    nerv.error('dangling input port %d of layer %s', i, id)
                end
            end
            for i = 1, #ref.dim_out do
                if ref.outputs[i] == nil then
                    nerv.error('dangling output port %d of layer %s', i, id)
                end
            end
        end
    end
    -- check dangling ports of the graph boundary itself
    for i = 1, #self.dim_in do
        if layers['<input>'].outputs[i] == nil then
            nerv.error('dangling port %d of layer <input>', i)
        end
    end
    for i = 1, #self.dim_out do
        if layers['<input>'].inputs[i] == nil then
            nerv.error('dangling port %d of layer <output>', i)
        end
    end
end
--- Set an attribute on this layer and forward it to every sub-layer.
-- The "<input>" boundary record carries no layer object and is skipped.
-- @param name attribute name
-- @param value attribute value
function GraphLayer:set_attr(name, value)
    self[name] = value
    for layer_id, record in pairs(self.layers) do
        local is_sublayer = (layer_id ~= '<input>')
        if is_sublayer then
            record.layer:set_attr(name, value)
        end
    end
end
--- Return the sub-layer object registered under `id`.
-- Raises (via nerv.error) when the id is unknown; the "<input>" boundary
-- record is deliberately not exposed.
-- @param id sub-layer id
-- @return the layer object stored for `id`
function GraphLayer:get_sublayer(id)
    local record = self.layers[id]
    if id == '<input>' or record == nil then
        nerv.error('layer with id %s not found', id)
    end
    return record.layer
end
--- Gather the parameter repos of all sub-layers and merge them.
-- @return the result of nerv.ParamRepo.merge over every sub-layer's
--   get_params() result (the "<input>" boundary record is skipped)
function GraphLayer:get_params()
    local collected = {}
    for layer_id, record in pairs(self.layers) do
        local is_sublayer = (layer_id ~= '<input>')
        if is_sublayer then
            table.insert(collected, record.layer:get_params())
        end
    end
    return nerv.ParamRepo.merge(collected)
end
|