diff options
Diffstat (limited to 'nn')
-rw-r--r-- | nn/layer_dag.lua | 59 |
1 files changed, 38 insertions, 21 deletions
diff --git a/nn/layer_dag.lua b/nn/layer_dag.lua index 1ab18fa..4ee829e 100644 --- a/nn/layer_dag.lua +++ b/nn/layer_dag.lua @@ -44,6 +44,7 @@ function nerv.DAGLayer:__init(id, global_conf, layer_conf) local outputs = {} local dim_in = layer_conf.dim_in local dim_out = layer_conf.dim_out + local parsed_conn = {} for from, to in pairs(layer_conf.connections) do local id_from, port_from = parse_id(from) local id_to, port_to = parse_id(to) @@ -76,32 +77,18 @@ function nerv.DAGLayer:__init(id, global_conf, layer_conf) if output_dim[port_from] ~= input_dim[port_to] then nerv.error("mismatching data dimension between %s and %s", from, to) end - local mid = global_conf.mat_type(global_conf.batch_size, - output_dim[port_from]) - local err_mid = mid:create() - - ref_from.outputs[port_from] = mid - ref_to.inputs[port_to] = mid - - ref_from.err_inputs[port_from] = err_mid - ref_to.err_outputs[port_to] = err_mid + table.insert(parsed_conn, + {{ref_from, port_from}, {ref_to, port_to}}) table.insert(ref_from.next_layers, ref_to) -- add edge ref_to.in_deg = ref_to.in_deg + 1 -- increase the in-degree of the target layer end end - self.layers = layers - self.inputs = inputs - self.outputs = outputs - self.dim_in = dim_in - self.dim_out = dim_out -end -function nerv.DAGLayer:init(id) -- topology sort local queue = {} local l = 1 local r = 1 - for id, ref in pairs(self.layers) do + for id, ref in pairs(layers) do if ref.in_deg == 0 then table.insert(queue, ref) nerv.utils.printf("adding source layer: %s\n", id) @@ -126,20 +113,50 @@ function nerv.DAGLayer:init(id) -- topology sort for i = 1, #queue do nerv.utils.printf("queued layer: %s\n", queue[i].layer.id) end - self.queue = queue - for id, ref in pairs(self.layers) do + + for id, ref in pairs(layers) do -- check wether the graph is connected if ref.visited == false then nerv.utils.printf("warning: layer %s is ignored\n", id) end + end + + self.layers = layers + self.inputs = inputs + self.outputs = outputs + self.dim_in = dim_in + self.dim_out = dim_out + self.parsed_conn = parsed_conn + self.queue = queue + self.gconf = global_conf +end + +function nerv.DAGLayer:init(batch_size) -- topology sort + for i, conn in ipairs(self.parsed_conn) do + local _, output_dim + local ref_from, port_from, ref_to, port_to + ref_from, port_from = unpack(conn[1]) + ref_to, port_to = unpack(conn[2]) + _, output_dim = ref_from.layer:get_dim() + local mid = self.gconf.cumat_type(batch_size, + output_dim[port_from]) + local err_mid = mid:create() + + ref_from.outputs[port_from] = mid + ref_to.inputs[port_to] = mid + + ref_from.err_inputs[port_from] = err_mid + ref_to.err_outputs[port_to] = err_mid + end + for id, ref in pairs(self.layers) do for i = 1, ref.input_len do if ref.inputs[i] == nil then - nerv.error("dangling port %d of layer %s", i, id) + nerv.error("dangling input port %d of layer %s", i, id) end end for i = 1, ref.output_len do if ref.outputs[i] == nil then - nerv.error("dangling port %d of layer %s", i, id) + nerv.error("dangling output port %d of layer %s", i, id) end end -- initialize sub layers |