aboutsummaryrefslogtreecommitdiff
path: root/nerv/nn
diff options
context:
space:
mode:
authorQi Liu <liuq901@163.com>2016-02-29 22:05:43 +0800
committerQi Liu <liuq901@163.com>2016-02-29 22:05:43 +0800
commit1a424bf9233f9b1c67ef135f1a3892b7986c5564 (patch)
tree90d470ebb69425934a2f9673f4f4a3c28a177bd2 /nerv/nn
parent77b558898a2a29097d8697a59a7d23cd2a52975f (diff)
add network & fix graph_layer
Diffstat (limited to 'nerv/nn')
-rw-r--r--nerv/nn/init.lua1
-rw-r--r--nerv/nn/network.lua68
2 files changed, 69 insertions, 0 deletions
diff --git a/nerv/nn/init.lua b/nerv/nn/init.lua
index cbaf52b..c32ea09 100644
--- a/nerv/nn/init.lua
+++ b/nerv/nn/init.lua
@@ -1,3 +1,4 @@
nerv.include('layer_repo.lua')
nerv.include('param_repo.lua')
nerv.include('layer_dag.lua')
+nerv.include('network.lua')
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua
new file mode 100644
index 0000000..6cee08b
--- /dev/null
+++ b/nerv/nn/network.lua
@@ -0,0 +1,68 @@
+local network = nerv.class('nerv.Network')
+
+function network:__init(graph)
+ self.layers = {}
+ self.socket = self:compile(graph)
+ for i = 1, #self.layers do
+ print(self.layers[i].layer.id)
+ local _, dim_out = self.layers[i].layer:get_dim()
+ for j = 1, #dim_out do
+ for k = 1, #self.layers[i].connections[j] do
+ local connections = self.layers[i].connections[j][k]
+ print(i, connections[1], connections[2], connections[3])
+ end
+ end
+ end
+end
+
+function network:compile(layer)
+ local socket = {inputs = {}, outputs = {}}
+ if not nerv.is_type(layer, 'nerv.GraphLayer') then
+ table.insert(self.layers, {layer = layer, connections = {}})
+ local id = #self.layers
+ local dim_in, dim_out = layer:get_dim()
+ for i = 1, #dim_in do
+ socket.inputs[i] = {{id, i, 0}}
+ end
+ for i = 1, #dim_out do
+ socket.outputs[i] = {id, i, 0}
+ self.layers[id].connections[i] = {}
+ end
+ else
+ local sublayer_socket = {}
+ for id, sublayer in pairs(layer.layers) do
+ if id ~= '<input>' then
+ sublayer_socket[sublayer.id] = self:compile(sublayer.layer)
+ end
+ end
+ local dim_in, _ = layer:get_dim()
+ for i = 1, #dim_in do
+ socket.inputs[i] = {}
+ end
+ for _, edge in pairs(layer.connections) do
+ -- id = 0 means <input> or <output>
+ local id_from, port_from = edge[1], edge[2]
+ local id_to, port_to = edge[3], edge[4]
+ local time = edge[5]
+ if id_from == 0 then
+ for _, input in pairs(sublayer_socket[id_to].inputs[port_to]) do
+ local id, port, t = input[1], input[2], input[3] + time
+ table.insert(socket.inputs[port_from], {id, port, t})
+ end
+ else
+ local output = sublayer_socket[id_from].outputs[port_from]
+ local id, port, t = output[1], output[2], output[3] + time
+ if id_to == 0 then
+ socket.outputs[port_to] = {id, port, t}
+ else
+ local connections = self.layers[id].connections[port]
+ for _, input in pairs(sublayer_socket[id_to].inputs[port_to]) do
+ local id1, port1, t1 = input[1], input[2], input[3]
+ table.insert(connections, {id1, port1, t + t1})
+ end
+ end
+ end
+ end
+ end
+ return socket
+end