summaryrefslogtreecommitdiff
path: root/nerv/layer/graph.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/layer/graph.lua')
-rw-r--r--nerv/layer/graph.lua118
1 files changed, 118 insertions, 0 deletions
diff --git a/nerv/layer/graph.lua b/nerv/layer/graph.lua
new file mode 100644
index 0000000..83cf810
--- /dev/null
+++ b/nerv/layer/graph.lua
@@ -0,0 +1,118 @@
+local GraphLayer = nerv.class('nerv.GraphLayer', 'nerv.Layer')
+
+function GraphLayer:__init(id, global_conf, layer_conf)
+ self.id = id
+ self.dim_in = layer_conf.dim_in
+ self.dim_out = layer_conf.dim_out
+ self.gconf = global_conf
+ self:graph_init(layer_conf.layer_repo, layer_conf.connections)
+end
+
+local function parse_id(str)
+ local id, port, _
+ _, _, id, port = string.find(str, "([a-zA-Z0-9_.]+)%[([0-9]+)%]")
+ if id == nil or port == nil then
+ _, _, id, port = string.find(str, "(.+)%[([0-9]+)%]")
+ if not (id == "<input>" or id == "<output>") then
+ nerv.error("wrong format of connection id")
+ end
+ end
+ port = tonumber(port)
+ return id, port
+end
+
+local function discover(id, layers, layer_repo)
+ if id == '<output>' then
+ id = '<input>'
+ end
+ local ref = layers[id]
+ if ref == nil then
+ local layer = layer_repo:get_layer(id)
+ local dim_in, dim_out = layer:get_dim()
+ ref = {
+ layer = layer,
+ inputs = {},
+ outputs = {},
+ dim_in = dim_in,
+ dim_out = dim_out,
+ }
+ layers[id] = ref
+ end
+ return ref
+end
+
+function GraphLayer:graph_init(layer_repo, connections)
+ self.connections = connections
+ self.sublayer = nerv.LayerRepo({}, nerv.ParamRepo(), self.gconf)
+
+ -- check data dimension between connected ports
+ local layers = {}
+ layers['<input>'] = {
+ inputs = {},
+ outputs = {},
+ dim_in = self.dim_out,
+ dim_out = self.dim_in,
+ }
+ for _, edge in pairs(self.connections) do
+ local from = edge[1]
+ local to = edge[2]
+ local id_from, port_from = parse_id(from)
+ local id_to, port_to = parse_id(to)
+ local ref_from = discover(id_from, layers, layer_repo)
+ local ref_to = discover(id_to, layers, layer_repo)
+ if ref_to.inputs[port_to] ~= nil then
+ nerv.error('%s has already been attached', to)
+ end
+ if ref_from.dim_out[port_from] ~= ref_to.dim_in[port_to] then
+ nerv.error('mismatching data dimension between %s and %s', from, to)
+ end
+ ref_from.outputs[port_from] = true
+ ref_to.inputs[port_to] = true
+ end
+
+ -- check dangling ports
+ for id, ref in pairs(layers) do
+ if id ~= '<input>' then
+ for i = 1, #ref.dim_in do
+ if ref.inputs[i] == nil then
+ nerv.error('dangling input port %d of layer %s', i, id)
+ end
+ end
+ for i = 1, #ref.dim_out do
+ if ref.outputs[i] == nil then
+ nerv.error('dangling output port %d os layer %s', i, id)
+ end
+ end
+ self.sublayer.layers[id] = ref.layer
+ end
+ end
+ for i = 1, #self.dim_in do
+ if layers['<input>'].outputs[i] == nil then
+ nerv.error('dangling port %d of layer <input>', i)
+ end
+ end
+ for i = 1, #self.dim_out do
+ if layers['<input>'].inputs[i] == nil then
+ nerv.error('dangling port %d of layer <output>', i)
+ end
+ end
+end
+
+function GraphLayer:set_attr(name, value)
+ self[name] = value
+ for id, layer in pairs(self.sublayer.layers) do
+ layer:set_attr(name, value)
+ end
+end
+
+function GraphLayer:get_sublayer(id)
+ return self.sublayer:get_layer(id)
+end
+
+function GraphLayer:get_params()
+ local param_repos = {}
+ for id, layer in pairs(self.sublayer.layers) do
+ table.insert(param_repos, layer:get_params())
+ end
+ return nerv.ParamRepo.merge(param_repos)
+end