summaryrefslogtreecommitdiff
path: root/nn/layer_dag.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nn/layer_dag.lua')
-rw-r--r--nn/layer_dag.lua28
1 files changed, 19 insertions, 9 deletions
diff --git a/nn/layer_dag.lua b/nn/layer_dag.lua
index 3951bfa..2dda7c9 100644
--- a/nn/layer_dag.lua
+++ b/nn/layer_dag.lua
@@ -38,7 +38,7 @@ local function discover(id, layers, layer_repo)
return ref
end
-function nerv.DAGLayer:__init(id, global_conf, layer_conf)
+function DAGLayer:__init(id, global_conf, layer_conf)
local layers = {}
local inputs = {}
local outputs = {}
@@ -131,7 +131,7 @@ function nerv.DAGLayer:__init(id, global_conf, layer_conf)
self.gconf = global_conf
end
-function nerv.DAGLayer:init(batch_size) -- topology sort
+function DAGLayer:init(batch_size) -- topology sort
for i, conn in ipairs(self.parsed_conn) do
local _, output_dim
local ref_from, port_from, ref_to, port_to
@@ -174,7 +174,7 @@ function nerv.DAGLayer:init(batch_size) -- topology sort
end
end
-function nerv.DAGLayer:set_inputs(input)
+function DAGLayer:set_inputs(input)
for i = 1, #self.dim_in do
local layer = self.inputs[i][1]
local port = self.inputs[i][2]
@@ -182,7 +182,7 @@ function nerv.DAGLayer:set_inputs(input)
end
end
-function nerv.DAGLayer:set_outputs(output)
+function DAGLayer:set_outputs(output)
for i = 1, #self.dim_out do
local layer = self.outputs[i][1]
local port = self.outputs[i][2]
@@ -190,7 +190,7 @@ function nerv.DAGLayer:set_outputs(output)
end
end
-function nerv.DAGLayer:set_err_inputs(bp_err)
+function DAGLayer:set_err_inputs(bp_err)
for i = 1, #self.dim_out do
local layer = self.outputs[i][1]
local port = self.outputs[i][2]
@@ -198,7 +198,7 @@ function nerv.DAGLayer:set_err_inputs(bp_err)
end
end
-function nerv.DAGLayer:set_err_outputs(next_bp_err)
+function DAGLayer:set_err_outputs(next_bp_err)
for i = 1, #self.dim_in do
local layer = self.inputs[i][1]
local port = self.inputs[i][2]
@@ -206,7 +206,7 @@ function nerv.DAGLayer:set_err_outputs(next_bp_err)
end
end
-function nerv.DAGLayer:update(bp_err, input, output)
+function DAGLayer:update(bp_err, input, output)
self:set_err_inputs(bp_err)
self:set_inputs(input)
self:set_outputs(output)
@@ -217,7 +217,7 @@ function nerv.DAGLayer:update(bp_err, input, output)
end
end
-function nerv.DAGLayer:propagate(input, output)
+function DAGLayer:propagate(input, output)
self:set_inputs(input)
self:set_outputs(output)
for i = 1, #self.queue do
@@ -227,7 +227,7 @@ function nerv.DAGLayer:propagate(input, output)
end
end
-function nerv.DAGLayer:back_propagate(next_bp_err, bp_err, input, output)
+function DAGLayer:back_propagate(next_bp_err, bp_err, input, output)
self:set_err_outputs(next_bp_err)
self:set_err_inputs(bp_err)
self:set_inputs(input)
@@ -238,3 +238,13 @@ function nerv.DAGLayer:back_propagate(next_bp_err, bp_err, input, output)
ref.layer:back_propagate(ref.err_outputs, ref.err_inputs, ref.inputs, ref.outputs)
end
end
+
+function DAGLayer:get_params()
+ local res = {}
+ for id, ref in pairs(self.queue) do
+ for i, p in ipairs(ref.layer:get_params()) do
+ table.insert(res, p)
+ end
+ end
+ return res
+end