diff options
author | Determinant <[email protected]> | 2015-06-20 20:00:25 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2015-06-20 20:00:25 +0800 |
commit | f3f4e74eb4dbb8829e5ee136ba4b0c0a7938b551 (patch) | |
tree | 8beb12182020267ce32904d646ad0c736c27dcd2 /nn/layer_dag.lua | |
parent | 2ab9610a4fff798c1668cdc041515256fa813865 (diff) |
change concept of ParamRepo; provide generalized param update; code clean-up; #25 #26 #27 #29
Diffstat (limited to 'nn/layer_dag.lua')
-rw-r--r-- | nn/layer_dag.lua | 23 |
1 files changed, 11 insertions, 12 deletions
diff --git a/nn/layer_dag.lua b/nn/layer_dag.lua index 2dda7c9..8e30216 100644 --- a/nn/layer_dag.lua +++ b/nn/layer_dag.lua @@ -85,13 +85,14 @@ function DAGLayer:__init(id, global_conf, layer_conf) end end + -- topology sort local queue = {} local l = 1 local r = 1 for id, ref in pairs(layers) do if ref.in_deg == 0 then table.insert(queue, ref) - nerv.utils.printf("adding source layer: %s\n", id) + nerv.info("adding source layer: %s", id) r = r + 1 end end @@ -111,13 +112,13 @@ function DAGLayer:__init(id, global_conf, layer_conf) end end for i = 1, #queue do - nerv.utils.printf("queued layer: %s\n", queue[i].layer.id) + nerv.info("enqueued layer: %s", queue[i].layer.id) end for id, ref in pairs(layers) do -- check wether the graph is connected if ref.visited == false then - nerv.utils.printf("warning: layer %s is ignored\n", id) + nerv.warning("layer %s is ignored", id) end end @@ -131,7 +132,7 @@ function DAGLayer:__init(id, global_conf, layer_conf) self.gconf = global_conf end -function DAGLayer:init(batch_size) -- topology sort +function DAGLayer:init(batch_size) for i, conn in ipairs(self.parsed_conn) do local _, output_dim local ref_from, port_from, ref_to, port_to @@ -160,7 +161,7 @@ function DAGLayer:init(batch_size) -- topology sort end end -- initialize sub layers - ref.layer:init() + ref.layer:init(batch_size) end for i = 1, #self.dim_in do if self.inputs[i] == nil then @@ -227,7 +228,7 @@ function DAGLayer:propagate(input, output) end end -function DAGLayer:back_propagate(next_bp_err, bp_err, input, output) +function DAGLayer:back_propagate(bp_err, next_bp_err, input, output) self:set_err_outputs(next_bp_err) self:set_err_inputs(bp_err) self:set_inputs(input) @@ -235,16 +236,14 @@ function DAGLayer:back_propagate(next_bp_err, bp_err, input, output) for i = #self.queue, 1, -1 do local ref = self.queue[i] -- print(ref.layer.id) - ref.layer:back_propagate(ref.err_outputs, ref.err_inputs, ref.inputs, ref.outputs) + ref.layer:back_propagate(ref.err_inputs, ref.err_outputs, ref.inputs, ref.outputs) end end function DAGLayer:get_params() - local res = {} + local param_repos = {} for id, ref in pairs(self.queue) do - for i, p in ipairs(ref.layer:get_params()) do - table.insert(res, p) - end + table.insert(param_repos, ref.layer:get_params()) end - return res + return nerv.ParamRepo.merge(param_repos) end |