From 1a424bf9233f9b1c67ef135f1a3892b7986c5564 Mon Sep 17 00:00:00 2001 From: Qi Liu Date: Mon, 29 Feb 2016 22:05:43 +0800 Subject: add network & fix graph_layer --- nerv/nn/network.lua | 68 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 nerv/nn/network.lua (limited to 'nerv/nn/network.lua') diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua new file mode 100644 index 0000000..6cee08b --- /dev/null +++ b/nerv/nn/network.lua @@ -0,0 +1,68 @@ +local network = nerv.class('nerv.Network') + +function network:__init(graph) + self.layers = {} + self.socket = self:compile(graph) + for i = 1, #self.layers do + print(self.layers[i].layer.id) + local _, dim_out = self.layers[i].layer:get_dim() + for j = 1, #dim_out do + for k = 1, #self.layers[i].connections[j] do + local connections = self.layers[i].connections[j][k] + print(i, connections[1], connections[2], connections[3]) + end + end + end +end + +function network:compile(layer) + local socket = {inputs = {}, outputs = {}} + if not nerv.is_type(layer, 'nerv.GraphLayer') then + table.insert(self.layers, {layer = layer, connections = {}}) + local id = #self.layers + local dim_in, dim_out = layer:get_dim() + for i = 1, #dim_in do + socket.inputs[i] = {{id, i, 0}} + end + for i = 1, #dim_out do + socket.outputs[i] = {id, i, 0} + self.layers[id].connections[i] = {} + end + else + local sublayer_socket = {} + for id, sublayer in pairs(layer.layers) do + if id ~= '' then + sublayer_socket[sublayer.id] = self:compile(sublayer.layer) + end + end + local dim_in, _ = layer:get_dim() + for i = 1, #dim_in do + socket.inputs[i] = {} + end + for _, edge in pairs(layer.connections) do + -- id = 0 means or + local id_from, port_from = edge[1], edge[2] + local id_to, port_to = edge[3], edge[4] + local time = edge[5] + if id_from == 0 then + for _, input in pairs(sublayer_socket[id_to].inputs[port_to]) do + local id, port, t = input[1], input[2], input[3] + time + table.insert(socket.inputs[port_from], {id, port, t}) + end + else + local output = sublayer_socket[id_from].outputs[port_from] + local id, port, t = output[1], output[2], output[3] + time + if id_to == 0 then + socket.outputs[port_to] = {id, port, t} + else + local connections = self.layers[id].connections[port] + for _, input in pairs(sublayer_socket[id_to].inputs[port_to]) do + local id1, port1, t1 = input[1], input[2], input[3] + table.insert(connections, {id1, port1, t + t1}) + end + end + end + end + end + return socket +end -- cgit v1.2.3