aboutsummaryrefslogtreecommitdiff
path: root/nn/layer_dag.lua
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2015-06-20 20:00:25 +0800
committerDeterminant <[email protected]>2015-06-20 20:00:25 +0800
commitf3f4e74eb4dbb8829e5ee136ba4b0c0a7938b551 (patch)
tree8beb12182020267ce32904d646ad0c736c27dcd2 /nn/layer_dag.lua
parent2ab9610a4fff798c1668cdc041515256fa813865 (diff)
change concept of ParamRepo; provide generalized param update; code clean-up; #25 #26 #27 #29
Diffstat (limited to 'nn/layer_dag.lua')
-rw-r--r--nn/layer_dag.lua23
1 files changed, 11 insertions, 12 deletions
diff --git a/nn/layer_dag.lua b/nn/layer_dag.lua
index 2dda7c9..8e30216 100644
--- a/nn/layer_dag.lua
+++ b/nn/layer_dag.lua
@@ -85,13 +85,14 @@ function DAGLayer:__init(id, global_conf, layer_conf)
end
end
+ -- topology sort
local queue = {}
local l = 1
local r = 1
for id, ref in pairs(layers) do
if ref.in_deg == 0 then
table.insert(queue, ref)
- nerv.utils.printf("adding source layer: %s\n", id)
+ nerv.info("adding source layer: %s", id)
r = r + 1
end
end
@@ -111,13 +112,13 @@ function DAGLayer:__init(id, global_conf, layer_conf)
end
end
for i = 1, #queue do
- nerv.utils.printf("queued layer: %s\n", queue[i].layer.id)
+ nerv.info("enqueued layer: %s", queue[i].layer.id)
end
for id, ref in pairs(layers) do
-- check wether the graph is connected
if ref.visited == false then
- nerv.utils.printf("warning: layer %s is ignored\n", id)
+ nerv.warning("layer %s is ignored", id)
end
end
@@ -131,7 +132,7 @@ function DAGLayer:__init(id, global_conf, layer_conf)
self.gconf = global_conf
end
-function DAGLayer:init(batch_size) -- topology sort
+function DAGLayer:init(batch_size)
for i, conn in ipairs(self.parsed_conn) do
local _, output_dim
local ref_from, port_from, ref_to, port_to
@@ -160,7 +161,7 @@ function DAGLayer:init(batch_size) -- topology sort
end
end
-- initialize sub layers
- ref.layer:init()
+ ref.layer:init(batch_size)
end
for i = 1, #self.dim_in do
if self.inputs[i] == nil then
@@ -227,7 +228,7 @@ function DAGLayer:propagate(input, output)
end
end
-function DAGLayer:back_propagate(next_bp_err, bp_err, input, output)
+function DAGLayer:back_propagate(bp_err, next_bp_err, input, output)
self:set_err_outputs(next_bp_err)
self:set_err_inputs(bp_err)
self:set_inputs(input)
@@ -235,16 +236,14 @@ function DAGLayer:back_propagate(next_bp_err, bp_err, input, output)
for i = #self.queue, 1, -1 do
local ref = self.queue[i]
-- print(ref.layer.id)
- ref.layer:back_propagate(ref.err_outputs, ref.err_inputs, ref.inputs, ref.outputs)
+ ref.layer:back_propagate(ref.err_inputs, ref.err_outputs, ref.inputs, ref.outputs)
end
end
function DAGLayer:get_params()
- local res = {}
+ local param_repos = {}
for id, ref in pairs(self.queue) do
- for i, p in ipairs(ref.layer:get_params()) do
- table.insert(res, p)
- end
+ table.insert(param_repos, ref.layer:get_params())
end
- return res
+ return nerv.ParamRepo.merge(param_repos)
end