aboutsummaryrefslogtreecommitdiff
path: root/nerv/nn/layer_dag.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/nn/layer_dag.lua')
-rw-r--r--nerv/nn/layer_dag.lua23
1 files changed, 21 insertions, 2 deletions
diff --git a/nerv/nn/layer_dag.lua b/nerv/nn/layer_dag.lua
index 25297c2..a262a72 100644
--- a/nerv/nn/layer_dag.lua
+++ b/nerv/nn/layer_dag.lua
@@ -92,7 +92,7 @@ function DAGLayer:__init(id, global_conf, layer_conf)
for id, ref in pairs(layers) do
if ref.in_deg == 0 then
table.insert(queue, ref)
- nerv.info("adding source layer: %s", id)
+ --nerv.info("adding source layer: %s", id)
r = r + 1
end
end
@@ -112,7 +112,7 @@ function DAGLayer:__init(id, global_conf, layer_conf)
end
end
for i = 1, #queue do
- nerv.info("enqueued layer: %s %s", queue[i].layer, queue[i].layer.id)
+ --nerv.info("enqueued layer: %s %s", queue[i].layer, queue[i].layer.id)
end
for id, ref in pairs(layers) do
@@ -225,6 +225,25 @@ function DAGLayer:update(bp_err, input, output)
end
end
+function DAGLayer:gradient(bp_err, input, output)
+ self:set_err_inputs(bp_err)
+ self:set_inputs(input)
+ self:set_outputs(output)
+ -- print("gradient")
+ for id, ref in pairs(self.queue) do
+ -- print(ref.layer.id)
+ ref.layer:gradient(ref.err_inputs, ref.inputs, ref.outputs)
+ end
+end
+
+function DAGLayer:update_gradient()
+ -- print("update gradient")
+ for id, ref in pairs(self.queue) do
+ -- print(ref.layer.id)
+ ref.layer:update_gradient()
+ end
+end
+
function DAGLayer:propagate(input, output)
self:set_inputs(input)
self:set_outputs(output)