diff options
Diffstat (limited to 'nerv/nn/layer_dag.lua')
-rw-r--r-- | nerv/nn/layer_dag.lua | 23 |
1 files changed, 21 insertions, 2 deletions
diff --git a/nerv/nn/layer_dag.lua b/nerv/nn/layer_dag.lua index 25297c2..a262a72 100644 --- a/nerv/nn/layer_dag.lua +++ b/nerv/nn/layer_dag.lua @@ -92,7 +92,7 @@ function DAGLayer:__init(id, global_conf, layer_conf) for id, ref in pairs(layers) do if ref.in_deg == 0 then table.insert(queue, ref) - nerv.info("adding source layer: %s", id) + --nerv.info("adding source layer: %s", id) r = r + 1 end end @@ -112,7 +112,7 @@ function DAGLayer:__init(id, global_conf, layer_conf) end end for i = 1, #queue do - nerv.info("enqueued layer: %s %s", queue[i].layer, queue[i].layer.id) + --nerv.info("enqueued layer: %s %s", queue[i].layer, queue[i].layer.id) end for id, ref in pairs(layers) do @@ -225,6 +225,25 @@ function DAGLayer:update(bp_err, input, output) end end +function DAGLayer:gradient(bp_err, input, output) + self:set_err_inputs(bp_err) + self:set_inputs(input) + self:set_outputs(output) + -- print("gradient") + for id, ref in pairs(self.queue) do + -- print(ref.layer.id) + ref.layer:gradient(ref.err_inputs, ref.inputs, ref.outputs) + end +end + +function DAGLayer:update_gradient() + -- print("update gradient") + for id, ref in pairs(self.queue) do + -- print(ref.layer.id) + ref.layer:update_gradient() + end +end + function DAGLayer:propagate(input, output) self:set_inputs(input) self:set_outputs(output) |