aboutsummaryrefslogtreecommitdiff
path: root/nn/layer_dag.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nn/layer_dag.lua')
-rw-r--r--nn/layer_dag.lua59
1 files changed, 38 insertions, 21 deletions
diff --git a/nn/layer_dag.lua b/nn/layer_dag.lua
index 1ab18fa..4ee829e 100644
--- a/nn/layer_dag.lua
+++ b/nn/layer_dag.lua
@@ -44,6 +44,7 @@ function nerv.DAGLayer:__init(id, global_conf, layer_conf)
local outputs = {}
local dim_in = layer_conf.dim_in
local dim_out = layer_conf.dim_out
+ local parsed_conn = {}
for from, to in pairs(layer_conf.connections) do
local id_from, port_from = parse_id(from)
local id_to, port_to = parse_id(to)
@@ -76,32 +77,18 @@ function nerv.DAGLayer:__init(id, global_conf, layer_conf)
if output_dim[port_from] ~= input_dim[port_to] then
nerv.error("mismatching data dimension between %s and %s", from, to)
end
- local mid = global_conf.mat_type(global_conf.batch_size,
- output_dim[port_from])
- local err_mid = mid:create()
-
- ref_from.outputs[port_from] = mid
- ref_to.inputs[port_to] = mid
-
- ref_from.err_inputs[port_from] = err_mid
- ref_to.err_outputs[port_to] = err_mid
+ table.insert(parsed_conn,
+ {{ref_from, port_from}, {ref_to, port_to}})
table.insert(ref_from.next_layers, ref_to) -- add edge
ref_to.in_deg = ref_to.in_deg + 1 -- increase the in-degree of the target layer
end
end
- self.layers = layers
- self.inputs = inputs
- self.outputs = outputs
- self.dim_in = dim_in
- self.dim_out = dim_out
-end
-function nerv.DAGLayer:init(id) -- topology sort
local queue = {}
local l = 1
local r = 1
- for id, ref in pairs(self.layers) do
+ for id, ref in pairs(layers) do
if ref.in_deg == 0 then
table.insert(queue, ref)
nerv.utils.printf("adding source layer: %s\n", id)
@@ -126,20 +113,50 @@ function nerv.DAGLayer:init(id) -- topology sort
for i = 1, #queue do
nerv.utils.printf("queued layer: %s\n", queue[i].layer.id)
end
- self.queue = queue
- for id, ref in pairs(self.layers) do
+
+ for id, ref in pairs(layers) do
-- check wether the graph is connected
if ref.visited == false then
nerv.utils.printf("warning: layer %s is ignored\n", id)
end
+ end
+
+ self.layers = layers
+ self.inputs = inputs
+ self.outputs = outputs
+ self.dim_in = dim_in
+ self.dim_out = dim_out
+ self.parsed_conn = parsed_conn
+ self.queue = queue
+ self.gconf = global_conf
+end
+
+function nerv.DAGLayer:init(batch_size) -- topology sort
+ for i, conn in ipairs(self.parsed_conn) do
+ local _, output_dim
+ local ref_from, port_from, ref_to, port_to
+ ref_from, port_from = unpack(conn[1])
+ ref_to, port_to = unpack(conn[2])
+ _, output_dim = ref_from.layer:get_dim()
+ local mid = self.gconf.cumat_type(batch_size,
+ output_dim[port_from])
+ local err_mid = mid:create()
+
+ ref_from.outputs[port_from] = mid
+ ref_to.inputs[port_to] = mid
+
+ ref_from.err_inputs[port_from] = err_mid
+ ref_to.err_outputs[port_to] = err_mid
+ end
+ for id, ref in pairs(self.layers) do
for i = 1, ref.input_len do
if ref.inputs[i] == nil then
- nerv.error("dangling port %d of layer %s", i, id)
+ nerv.error("dangling input port %d of layer %s", i, id)
end
end
for i = 1, ref.output_len do
if ref.outputs[i] == nil then
- nerv.error("dangling port %d of layer %s", i, id)
+ nerv.error("dangling output port %d of layer %s", i, id)
end
end
-- initialize sub layers