From f3f4e74eb4dbb8829e5ee136ba4b0c0a7938b551 Mon Sep 17 00:00:00 2001 From: Determinant Date: Sat, 20 Jun 2015 20:00:25 +0800 Subject: change concept of ParamRepo; provide generalized param update; code clean-up; #25 #26 #27 #29 --- nn/layer_dag.lua | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) (limited to 'nn/layer_dag.lua') diff --git a/nn/layer_dag.lua b/nn/layer_dag.lua index 2dda7c9..8e30216 100644 --- a/nn/layer_dag.lua +++ b/nn/layer_dag.lua @@ -85,13 +85,14 @@ function DAGLayer:__init(id, global_conf, layer_conf) end end + -- topology sort local queue = {} local l = 1 local r = 1 for id, ref in pairs(layers) do if ref.in_deg == 0 then table.insert(queue, ref) - nerv.utils.printf("adding source layer: %s\n", id) + nerv.info("adding source layer: %s", id) r = r + 1 end end @@ -111,13 +112,13 @@ function DAGLayer:__init(id, global_conf, layer_conf) end end for i = 1, #queue do - nerv.utils.printf("queued layer: %s\n", queue[i].layer.id) + nerv.info("enqueued layer: %s", queue[i].layer.id) end for id, ref in pairs(layers) do -- check wether the graph is connected if ref.visited == false then - nerv.utils.printf("warning: layer %s is ignored\n", id) + nerv.warning("layer %s is ignored", id) end end @@ -131,7 +132,7 @@ function DAGLayer:__init(id, global_conf, layer_conf) self.gconf = global_conf end -function DAGLayer:init(batch_size) -- topology sort +function DAGLayer:init(batch_size) for i, conn in ipairs(self.parsed_conn) do local _, output_dim local ref_from, port_from, ref_to, port_to @@ -160,7 +161,7 @@ function DAGLayer:init(batch_size) -- topology sort end end -- initialize sub layers - ref.layer:init() + ref.layer:init(batch_size) end for i = 1, #self.dim_in do if self.inputs[i] == nil then @@ -227,7 +228,7 @@ function DAGLayer:propagate(input, output) end end -function DAGLayer:back_propagate(next_bp_err, bp_err, input, output) +function DAGLayer:back_propagate(bp_err, next_bp_err, input, output) self:set_err_outputs(next_bp_err) self:set_err_inputs(bp_err) self:set_inputs(input) @@ -235,16 +236,14 @@ function DAGLayer:back_propagate(next_bp_err, bp_err, input, output) for i = #self.queue, 1, -1 do local ref = self.queue[i] -- print(ref.layer.id) - ref.layer:back_propagate(ref.err_outputs, ref.err_inputs, ref.inputs, ref.outputs) + ref.layer:back_propagate(ref.err_inputs, ref.err_outputs, ref.inputs, ref.outputs) end end function DAGLayer:get_params() - local res = {} + local param_repos = {} for id, ref in pairs(self.queue) do - for i, p in ipairs(ref.layer:get_params()) do - table.insert(res, p) - end + table.insert(param_repos, ref.layer:get_params()) end - return res + return nerv.ParamRepo.merge(param_repos) end -- cgit v1.2.3