1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
|
-- GraphLayer: a layer whose behaviour is described by sub-layers taken from a
-- layer repository plus a list of connections between their ports.
-- Registered through nerv.class with 'nerv.Layer' as the second argument
-- (presumably the parent class -- confirm against nerv.class).
local GraphLayer = nerv.class('nerv.GraphLayer', 'nerv.Layer')

--- Construct the graph layer.
-- Runs the nerv.Layer initializer, stores the layer repository found in the
-- configuration and then builds the graph from the configured connections.
-- @param id identifier handed on to nerv.Layer.__init
-- @param global_conf global configuration handed on to nerv.Layer.__init
-- @param layer_conf layer configuration; its `layer_repo` and `connections`
--        fields are read here
function GraphLayer:__init(id, global_conf, layer_conf)
    nerv.Layer.__init(self, id, global_conf, layer_conf)
    local repo = layer_conf.layer_repo
    self.lrepo = repo
    self:graph_init(repo, layer_conf.connections)
end
--- Split a connection endpoint of the form "id[port]" into its two parts.
-- Regular ids are made of letters, digits, '_' and '.'. When that form does
-- not match, a looser pattern is tried and only the pseudo ids "<input>" and
-- "<output>" are accepted from it.
-- @param str endpoint string, e.g. "affine0[1]" or "<input>[2]"
-- @return the id (string) and the port converted with tonumber
-- Reports malformed endpoints through nerv.error.
local function parse_id(str)
    local id, port = string.match(str, "([a-zA-Z0-9_.]+)%[([0-9]+)%]")
    if id == nil or port == nil then
        -- fall back to the permissive form; it is only legal for pseudo ids
        id, port = string.match(str, "(.+)%[([0-9]+)%]")
        if id ~= "<input>" and id ~= "<output>" then
            nerv.error("wrong format of connection id")
        end
    end
    local port_number = tonumber(port)
    return id, port_number
end
--- Qualify sub-layer names and connection endpoints with this layer's id.
-- Both arguments are modified in place: every per-type table in `layers` is
-- replaced by a table keyed with "<self.id>.<name>", and every connection
-- endpoint is rewritten the same way, except a source whose id is "<input>"
-- and a destination whose id is "<output>".
-- @param layers table mapping a layer type to a table of name -> layer config
-- @param connections array of connections; elements 1 and 2 are the endpoints
function GraphLayer:add_prefix(layers, connections)
    local function qualify(name)
        return self.id .. '.' .. name
    end

    -- rebuild each per-type table under the qualified names
    for kind, group in pairs(layers) do
        local renamed = {}
        for name, conf in pairs(group) do
            renamed[qualify(name)] = conf
        end
        layers[kind] = renamed
    end

    -- rewrite the endpoints, leaving the graph's own pseudo ports alone
    for idx = 1, #connections do
        local edge = connections[idx]
        local src, dst = edge[1], edge[2]
        if parse_id(src) ~= '<input>' then
            edge[1] = qualify(src)
        end
        if parse_id(dst) ~= '<output>' then
            edge[2] = qualify(dst)
        end
    end
end
--- Return the bookkeeping record for a layer id, creating it on first use.
-- "<output>" is looked up under the "<input>" key, so both pseudo ids share
-- one record. For an id not seen before, the layer is fetched from
-- `layer_repo`, its dimensions are queried with get_dim, and the record gets
-- the next value of self.layer_num as its numeric id.
-- @param id layer id as it appears in a connection endpoint
-- @param layer_repo repository queried with get_layer for unknown ids
-- @return the record table (layer, inputs, outputs, dim_in, dim_out, id)
function GraphLayer:discover(id, layer_repo)
    local key = id
    if key == '<output>' then
        key = '<input>'
    end

    local known = self.layers
    local ref = known[key]
    if ref ~= nil then
        return ref
    end

    local layer = layer_repo:get_layer(key)
    local dim_in, dim_out = layer:get_dim()
    self.layer_num = self.layer_num + 1
    ref = {
        layer = layer,
        inputs = {},
        outputs = {},
        dim_in = dim_in,
        dim_out = dim_out,
        id = self.layer_num,
    }
    known[key] = ref
    return ref
end
--- Flip the sign of the third element of every connection, in place.
-- @param connections array of connections; element 3 of each is multiplied by -1
local function reverse(connections)
    for idx = 1, #connections do
        local edge = connections[idx]
        edge[3] = edge[3] * -1
    end
end
--- Build and validate the internal graph from a list of connections.
-- Each connection is {from, to, time} with endpoints of the form "id[port]".
-- On return, self.layers maps every referenced id to its bookkeeping record
-- and self.connections holds {from_id, from_port, to_id, to_port, time}
-- tuples using the numeric record ids.
-- Errors (through nerv.error) are raised for: a port attached twice,
-- mismatching dimensions, an edge that names ports which do not exist on
-- either side, a direct <input> -> <output> edge, and any port left dangling.
-- @param layer_repo repository used to resolve layer ids
-- @param connections array of connections; modified in place (third element
--        negated) when self.lconf.reversed is set
function GraphLayer:graph_init(layer_repo, connections)
    if self.lconf.reversed then
        reverse(connections)
    end
    local layers = {}
    -- One record stands for both <input> and <output> (discover maps
    -- <output> onto it). Dimensions are swapped so that an edge leaving
    -- <input> is checked against self.dim_in and an edge entering <output>
    -- against self.dim_out.
    layers['<input>'] = {
        inputs = {},
        outputs = {},
        dim_in = self.dim_out,
        dim_out = self.dim_in,
        id = 0,
    }
    self.layers = layers
    self.layer_num = 0
    self.connections = {}
    -- check data dimension between connected ports
    for _, edge in pairs(connections) do
        local from, to, time = edge[1], edge[2], edge[3]
        local id_from, port_from = parse_id(from)
        local id_to, port_to = parse_id(to)
        local ref_from = self:discover(id_from, layer_repo)
        local ref_to = self:discover(id_to, layer_repo)
        if ref_from.outputs[port_from] ~= nil then
            nerv.error('%s has already been attached', from)
        end
        if ref_to.inputs[port_to] ~= nil then
            nerv.error('%s has already been attached', to)
        end
        local dim = ref_from.dim_out[port_from]
        if dim ~= ref_to.dim_in[port_to] then
            nerv.error('mismatching data dimension between %s and %s', from, to)
        end
        -- Both lookups yield nil when both ports are out of range; nil equals
        -- nil, so the check above would let such an edge through unnoticed.
        if dim == nil then
            nerv.error('connection from %s to %s refers to non-existent ports', from, to)
        end
        if ref_from.id == 0 and ref_to.id == 0 then
            nerv.error('short-circuit connection between <input> and <output>')
        end
        ref_from.outputs[port_from] = true
        ref_to.inputs[port_to] = true
        table.insert(self.connections, {ref_from.id, port_from, ref_to.id, port_to, time})
    end
    -- check dangling ports
    for id, ref in pairs(layers) do
        if id ~= '<input>' then
            for i = 1, #ref.dim_in do
                if ref.inputs[i] == nil then
                    nerv.error('dangling input port %d of layer %s', i, id)
                end
            end
            for i = 1, #ref.dim_out do
                if ref.outputs[i] == nil then
                    nerv.error('dangling output port %d of layer %s', i, id)
                end
            end
        end
    end
    -- the graph's own ports live on the shared '<input>' record: its outputs
    -- are the graph inputs, its inputs are the graph outputs
    for i = 1, #self.dim_in do
        if layers['<input>'].outputs[i] == nil then
            nerv.error('dangling port %d of layer <input>', i)
        end
    end
    for i = 1, #self.dim_out do
        if layers['<input>'].inputs[i] == nil then
            nerv.error('dangling port %d of layer <output>', i)
        end
    end
end
--- Set an attribute on this layer and pass it on to every sub-layer.
-- The "<input>" record is skipped.
-- @param name attribute name
-- @param value attribute value
function GraphLayer:set_attr(name, value)
    self[name] = value
    for key, entry in pairs(self.layers) do
        local is_pseudo = (key == '<input>')
        if not is_pseudo then
            entry.layer:set_attr(name, value)
        end
    end
end
--- Look up a sub-layer by id.
-- Unknown ids and the pseudo id "<input>" are reported through nerv.error.
-- @param id id of the sub-layer as used in the connections
-- @return the layer object stored for that id
function GraphLayer:get_sublayer(id)
    local entry = self.layers[id]
    if id == '<input>' or entry == nil then
        nerv.error('layer with id %s not found', id)
    end
    return entry.layer
end
--- Collect the parameters of all sub-layers.
-- Calls get_params on every sub-layer (the "<input>" record is skipped) and
-- hands the collected results to nerv.ParamRepo.merge with self.loc_type.
-- @return whatever nerv.ParamRepo.merge returns
function GraphLayer:get_params()
    local collected = {}
    for key, entry in pairs(self.layers) do
        local is_pseudo = (key == '<input>')
        if not is_pseudo then
            table.insert(collected, entry.layer:get_params())
        end
    end
    return nerv.ParamRepo.merge(collected, self.loc_type)
end
--- Bind parameters by calling rebind on the stored layer repository with
-- self.lconf.pr as its argument.
function GraphLayer:bind_params()
    local repo = self.lrepo
    repo:rebind(self.lconf.pr)
end
|