aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/Makefile2
-rw-r--r--nerv/layer/graph.lua46
-rw-r--r--nerv/layer/rnn.lua4
-rw-r--r--nerv/main.lua4
-rw-r--r--nerv/nn/init.lua1
-rw-r--r--nerv/nn/network.lua68
6 files changed, 105 insertions, 20 deletions
diff --git a/nerv/Makefile b/nerv/Makefile
index ba97579..c9c3e42 100644
--- a/nerv/Makefile
+++ b/nerv/Makefile
@@ -35,7 +35,7 @@ LUA_LIBS := matrix/init.lua io/init.lua init.lua \
layer/window.lua layer/bias.lua layer/combiner.lua layer/mse.lua \
layer/elem_mul.lua layer/lstm.lua layer/lstm_gate.lua layer/dropout.lua layer/gru.lua \
layer/graph.lua layer/rnn.lua \
- nn/init.lua nn/layer_repo.lua nn/param_repo.lua nn/layer_dag.lua \
+ nn/init.lua nn/layer_repo.lua nn/param_repo.lua nn/layer_dag.lua nn/network.lua \
io/sgd_buffer.lua \
tnn/init.lua tnn/sutil.lua tnn/tnn.lua
diff --git a/nerv/layer/graph.lua b/nerv/layer/graph.lua
index 83cf810..36a9672 100644
--- a/nerv/layer/graph.lua
+++ b/nerv/layer/graph.lua
@@ -21,20 +21,23 @@ local function parse_id(str)
return id, port
end
-local function discover(id, layers, layer_repo)
+function GraphLayer:discover(id, layer_repo)
if id == '<output>' then
id = '<input>'
end
+ local layers = self.layers
local ref = layers[id]
if ref == nil then
local layer = layer_repo:get_layer(id)
local dim_in, dim_out = layer:get_dim()
+ self.layer_num = self.layer_num + 1
ref = {
layer = layer,
inputs = {},
outputs = {},
dim_in = dim_in,
dim_out = dim_out,
+ id = self.layer_num,
}
layers[id] = ref
end
@@ -42,32 +45,37 @@ local function discover(id, layers, layer_repo)
end
function GraphLayer:graph_init(layer_repo, connections)
- self.connections = connections
- self.sublayer = nerv.LayerRepo({}, nerv.ParamRepo(), self.gconf)
-
- -- check data dimension between connected ports
local layers = {}
layers['<input>'] = {
inputs = {},
outputs = {},
dim_in = self.dim_out,
dim_out = self.dim_in,
+ id = 0,
}
- for _, edge in pairs(self.connections) do
- local from = edge[1]
- local to = edge[2]
+ self.layers = layers
+ self.layer_num = 0
+ self.connections = {}
+
+ -- check data dimension between connected ports
+ for _, edge in pairs(connections) do
+ local from, to, time = edge[1], edge[2], edge[3]
local id_from, port_from = parse_id(from)
local id_to, port_to = parse_id(to)
- local ref_from = discover(id_from, layers, layer_repo)
- local ref_to = discover(id_to, layers, layer_repo)
+ local ref_from = self:discover(id_from, layer_repo)
+ local ref_to = self:discover(id_to, layer_repo)
if ref_to.inputs[port_to] ~= nil then
nerv.error('%s has already been attached', to)
end
if ref_from.dim_out[port_from] ~= ref_to.dim_in[port_to] then
nerv.error('mismatching data dimension between %s and %s', from, to)
end
+ if ref_from.id == 0 and ref_to.id == 0 then
+ nerv.error('short-circuit connection between <input> and <output>')
+ end
ref_from.outputs[port_from] = true
ref_to.inputs[port_to] = true
+ table.insert(self.connections, {ref_from.id, port_from, ref_to.id, port_to, time})
end
-- check dangling ports
@@ -83,7 +91,6 @@ function GraphLayer:graph_init(layer_repo, connections)
nerv.error('dangling output port %d os layer %s', i, id)
end
end
- self.sublayer.layers[id] = ref.layer
end
end
for i = 1, #self.dim_in do
@@ -100,19 +107,26 @@ end
function GraphLayer:set_attr(name, value)
self[name] = value
- for id, layer in pairs(self.sublayer.layers) do
- layer:set_attr(name, value)
+ for id, ref in pairs(self.layers) do
+ if id ~= '<input>' then
+ ref.layer:set_attr(name, value)
+ end
end
end
function GraphLayer:get_sublayer(id)
- return self.sublayer:get_layer(id)
+ if self.layers[id] == nil or id == '<input>' then
+ nerv.error('layer with id %s not found', id)
+ end
+ return self.layers[id].layer
end
function GraphLayer:get_params()
local param_repos = {}
- for id, layer in pairs(self.sublayer.layers) do
- table.insert(param_repos, layer:get_params())
+ for id, ref in pairs(self.layers) do
+ if id ~= '<input>' then
+ table.insert(param_repos, ref.layer:get_params())
+ end
end
return nerv.ParamRepo.merge(param_repos)
end
diff --git a/nerv/layer/rnn.lua b/nerv/layer/rnn.lua
index a93530f..8816891 100644
--- a/nerv/layer/rnn.lua
+++ b/nerv/layer/rnn.lua
@@ -29,8 +29,8 @@ function RNNLayer:__init(id, global_conf, layer_conf)
local connections = {
{'<input>[1]', 'main[1]', 0},
{'main[1]', 'sigmoid[1]', 0},
- {'sigmoid[1]', 'main[2]', 0},
- {'sigmoid[1]', '<output>[1]', 1},
+ {'sigmoid[1]', 'main[2]', 1},
+ {'sigmoid[1]', '<output>[1]', 0},
}
self:graph_init(layer_repo, connections)
diff --git a/nerv/main.lua b/nerv/main.lua
index 85e291c..0633e87 100644
--- a/nerv/main.lua
+++ b/nerv/main.lua
@@ -28,4 +28,6 @@ local connections = {
{'output[1]', '<output>[1]', 0},
}
-local network = nerv.GraphLayer('network', global_conf, {dim_in = {20}, dim_out = {79}, layer_repo = layer_repo, connections = connections})
+local graph = nerv.GraphLayer('network', global_conf, {dim_in = {20}, dim_out = {79}, layer_repo = layer_repo, connections = connections})
+
+local network = nerv.Network(graph)
diff --git a/nerv/nn/init.lua b/nerv/nn/init.lua
index cbaf52b..c32ea09 100644
--- a/nerv/nn/init.lua
+++ b/nerv/nn/init.lua
@@ -1,3 +1,4 @@
nerv.include('layer_repo.lua')
nerv.include('param_repo.lua')
nerv.include('layer_dag.lua')
+nerv.include('network.lua')
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua
new file mode 100644
index 0000000..6cee08b
--- /dev/null
+++ b/nerv/nn/network.lua
@@ -0,0 +1,68 @@
+local network = nerv.class('nerv.Network')
+
+function network:__init(graph)
+ self.layers = {}
+ self.socket = self:compile(graph)
+ for i = 1, #self.layers do
+ print(self.layers[i].layer.id)
+ local _, dim_out = self.layers[i].layer:get_dim()
+ for j = 1, #dim_out do
+ for k = 1, #self.layers[i].connections[j] do
+ local connections = self.layers[i].connections[j][k]
+ print(i, connections[1], connections[2], connections[3])
+ end
+ end
+ end
+end
+
+function network:compile(layer)
+ local socket = {inputs = {}, outputs = {}}
+ if not nerv.is_type(layer, 'nerv.GraphLayer') then
+ table.insert(self.layers, {layer = layer, connections = {}})
+ local id = #self.layers
+ local dim_in, dim_out = layer:get_dim()
+ for i = 1, #dim_in do
+ socket.inputs[i] = {{id, i, 0}}
+ end
+ for i = 1, #dim_out do
+ socket.outputs[i] = {id, i, 0}
+ self.layers[id].connections[i] = {}
+ end
+ else
+ local sublayer_socket = {}
+ for id, sublayer in pairs(layer.layers) do
+ if id ~= '<input>' then
+ sublayer_socket[sublayer.id] = self:compile(sublayer.layer)
+ end
+ end
+ local dim_in, _ = layer:get_dim()
+ for i = 1, #dim_in do
+ socket.inputs[i] = {}
+ end
+ for _, edge in pairs(layer.connections) do
+ -- id = 0 means <input> or <output>
+ local id_from, port_from = edge[1], edge[2]
+ local id_to, port_to = edge[3], edge[4]
+ local time = edge[5]
+ if id_from == 0 then
+ for _, input in pairs(sublayer_socket[id_to].inputs[port_to]) do
+ local id, port, t = input[1], input[2], input[3] + time
+ table.insert(socket.inputs[port_from], {id, port, t})
+ end
+ else
+ local output = sublayer_socket[id_from].outputs[port_from]
+ local id, port, t = output[1], output[2], output[3] + time
+ if id_to == 0 then
+ socket.outputs[port_to] = {id, port, t}
+ else
+ local connections = self.layers[id].connections[port]
+ for _, input in pairs(sublayer_socket[id_to].inputs[port_to]) do
+ local id1, port1, t1 = input[1], input[2], input[3]
+ table.insert(connections, {id1, port1, t + t1})
+ end
+ end
+ end
+ end
+ end
+ return socket
+end