diff options
-rw-r--r-- | nerv/Makefile | 2 | ||||
-rw-r--r-- | nerv/layer/graph.lua | 46 | ||||
-rw-r--r-- | nerv/layer/rnn.lua | 4 | ||||
-rw-r--r-- | nerv/main.lua | 4 | ||||
-rw-r--r-- | nerv/nn/init.lua | 1 | ||||
-rw-r--r-- | nerv/nn/network.lua | 68 |
6 files changed, 105 insertions, 20 deletions
diff --git a/nerv/Makefile b/nerv/Makefile index ba97579..c9c3e42 100644 --- a/nerv/Makefile +++ b/nerv/Makefile @@ -35,7 +35,7 @@ LUA_LIBS := matrix/init.lua io/init.lua init.lua \ layer/window.lua layer/bias.lua layer/combiner.lua layer/mse.lua \ layer/elem_mul.lua layer/lstm.lua layer/lstm_gate.lua layer/dropout.lua layer/gru.lua \ layer/graph.lua layer/rnn.lua \ - nn/init.lua nn/layer_repo.lua nn/param_repo.lua nn/layer_dag.lua \ + nn/init.lua nn/layer_repo.lua nn/param_repo.lua nn/layer_dag.lua nn/network.lua \ io/sgd_buffer.lua \ tnn/init.lua tnn/sutil.lua tnn/tnn.lua diff --git a/nerv/layer/graph.lua b/nerv/layer/graph.lua index 83cf810..36a9672 100644 --- a/nerv/layer/graph.lua +++ b/nerv/layer/graph.lua @@ -21,20 +21,23 @@ local function parse_id(str) return id, port end -local function discover(id, layers, layer_repo) +function GraphLayer:discover(id, layer_repo) if id == '<output>' then id = '<input>' end + local layers = self.layers local ref = layers[id] if ref == nil then local layer = layer_repo:get_layer(id) local dim_in, dim_out = layer:get_dim() + self.layer_num = self.layer_num + 1 ref = { layer = layer, inputs = {}, outputs = {}, dim_in = dim_in, dim_out = dim_out, + id = self.layer_num, } layers[id] = ref end @@ -42,32 +45,37 @@ local function discover(id, layers, layer_repo) end function GraphLayer:graph_init(layer_repo, connections) - self.connections = connections - self.sublayer = nerv.LayerRepo({}, nerv.ParamRepo(), self.gconf) - - -- check data dimension between connected ports local layers = {} layers['<input>'] = { inputs = {}, outputs = {}, dim_in = self.dim_out, dim_out = self.dim_in, + id = 0, } - for _, edge in pairs(self.connections) do - local from = edge[1] - local to = edge[2] + self.layers = layers + self.layer_num = 0 + self.connections = {} + + -- check data dimension between connected ports + for _, edge in pairs(connections) do + local from, to, time = edge[1], edge[2], edge[3] local id_from, port_from = parse_id(from) local id_to, port_to = parse_id(to) - local ref_from = discover(id_from, layers, layer_repo) - local ref_to = discover(id_to, layers, layer_repo) + local ref_from = self:discover(id_from, layer_repo) + local ref_to = self:discover(id_to, layer_repo) if ref_to.inputs[port_to] ~= nil then nerv.error('%s has already been attached', to) end if ref_from.dim_out[port_from] ~= ref_to.dim_in[port_to] then nerv.error('mismatching data dimension between %s and %s', from, to) end + if ref_from.id == 0 and ref_to.id == 0 then + nerv.error('short-circuit connection between <input> and <output>') + end ref_from.outputs[port_from] = true ref_to.inputs[port_to] = true + table.insert(self.connections, {ref_from.id, port_from, ref_to.id, port_to, time}) end -- check dangling ports @@ -83,7 +91,6 @@ function GraphLayer:graph_init(layer_repo, connections) nerv.error('dangling output port %d os layer %s', i, id) end end - self.sublayer.layers[id] = ref.layer end end for i = 1, #self.dim_in do @@ -100,19 +107,26 @@ end function GraphLayer:set_attr(name, value) self[name] = value - for id, layer in pairs(self.sublayer.layers) do - layer:set_attr(name, value) + for id, ref in pairs(self.layers) do + if id ~= '<input>' then + ref.layer:set_attr(name, value) + end end end function GraphLayer:get_sublayer(id) - return self.sublayer:get_layer(id) + if self.layers[id] == nil or id == '<input>' then + nerv.error('layer with id %s not found', id) + end + return self.layers[id].layer end function GraphLayer:get_params() local param_repos = {} - for id, layer in pairs(self.sublayer.layers) do - table.insert(param_repos, layer:get_params()) + for id, ref in pairs(self.layers) do + if id ~= '<input>' then + table.insert(param_repos, ref.layer:get_params()) + end end return nerv.ParamRepo.merge(param_repos) end diff --git a/nerv/layer/rnn.lua b/nerv/layer/rnn.lua index a93530f..8816891 100644 --- a/nerv/layer/rnn.lua +++ b/nerv/layer/rnn.lua @@ -29,8 +29,8 @@ function RNNLayer:__init(id, global_conf, layer_conf) local connections = { {'<input>[1]', 'main[1]', 0}, {'main[1]', 'sigmoid[1]', 0}, - {'sigmoid[1]', 'main[2]', 0}, - {'sigmoid[1]', '<output>[1]', 1}, + {'sigmoid[1]', 'main[2]', 1}, + {'sigmoid[1]', '<output>[1]', 0}, } self:graph_init(layer_repo, connections) diff --git a/nerv/main.lua b/nerv/main.lua index 85e291c..0633e87 100644 --- a/nerv/main.lua +++ b/nerv/main.lua @@ -28,4 +28,6 @@ local connections = { {'output[1]', '<output>[1]', 0}, } -local network = nerv.GraphLayer('network', global_conf, {dim_in = {20}, dim_out = {79}, layer_repo = layer_repo, connections = connections}) +local graph = nerv.GraphLayer('network', global_conf, {dim_in = {20}, dim_out = {79}, layer_repo = layer_repo, connections = connections}) + +local network = nerv.Network(graph) diff --git a/nerv/nn/init.lua b/nerv/nn/init.lua index cbaf52b..c32ea09 100644 --- a/nerv/nn/init.lua +++ b/nerv/nn/init.lua @@ -1,3 +1,4 @@ nerv.include('layer_repo.lua') nerv.include('param_repo.lua') nerv.include('layer_dag.lua') +nerv.include('network.lua') diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua new file mode 100644 index 0000000..6cee08b --- /dev/null +++ b/nerv/nn/network.lua @@ -0,0 +1,68 @@ +local network = nerv.class('nerv.Network') + +function network:__init(graph) + self.layers = {} + self.socket = self:compile(graph) + for i = 1, #self.layers do + print(self.layers[i].layer.id) + local _, dim_out = self.layers[i].layer:get_dim() + for j = 1, #dim_out do + for k = 1, #self.layers[i].connections[j] do + local connections = self.layers[i].connections[j][k] + print(i, connections[1], connections[2], connections[3]) + end + end + end +end + +function network:compile(layer) + local socket = {inputs = {}, outputs = {}} + if not nerv.is_type(layer, 'nerv.GraphLayer') then + table.insert(self.layers, {layer = layer, connections = {}}) + local id = #self.layers + local dim_in, dim_out = layer:get_dim() + for i = 1, #dim_in do + socket.inputs[i] = {{id, i, 0}} + end + for i = 1, #dim_out do + socket.outputs[i] = {id, i, 0} + self.layers[id].connections[i] = {} + end + else + local sublayer_socket = {} + for id, sublayer in pairs(layer.layers) do + if id ~= '<input>' then + sublayer_socket[sublayer.id] = self:compile(sublayer.layer) + end + end + local dim_in, _ = layer:get_dim() + for i = 1, #dim_in do + socket.inputs[i] = {} + end + for _, edge in pairs(layer.connections) do + -- id = 0 means <input> or <output> + local id_from, port_from = edge[1], edge[2] + local id_to, port_to = edge[3], edge[4] + local time = edge[5] + if id_from == 0 then + for _, input in pairs(sublayer_socket[id_to].inputs[port_to]) do + local id, port, t = input[1], input[2], input[3] + time + table.insert(socket.inputs[port_from], {id, port, t}) + end + else + local output = sublayer_socket[id_from].outputs[port_from] + local id, port, t = output[1], output[2], output[3] + time + if id_to == 0 then + socket.outputs[port_to] = {id, port, t} + else + local connections = self.layers[id].connections[port] + for _, input in pairs(sublayer_socket[id_to].inputs[port_to]) do + local id1, port1, t1 = input[1], input[2], input[3] + table.insert(connections, {id1, port1, t + t1}) + end + end + end + end + end + return socket +end |