From eba6049a82455499c68ee875843b6f44d6164fa5 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 5 Jun 2015 16:56:33 +0800 Subject: add close method for ChunkFile, fix #18 --- nn/layer_dag.lua | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) (limited to 'nn') diff --git a/nn/layer_dag.lua b/nn/layer_dag.lua index 4ee829e..3951bfa 100644 --- a/nn/layer_dag.lua +++ b/nn/layer_dag.lua @@ -210,7 +210,9 @@ function nerv.DAGLayer:update(bp_err, input, output) self:set_err_inputs(bp_err) self:set_inputs(input) self:set_outputs(output) + -- print("update") for id, ref in pairs(self.queue) do + -- print(ref.layer.id) ref.layer:update(ref.err_inputs, ref.inputs, ref.outputs) end end @@ -220,11 +222,7 @@ function nerv.DAGLayer:propagate(input, output) self:set_outputs(output) for i = 1, #self.queue do local ref = self.queue[i] - --[[ - print(ref.inputs[1]) - print(ref.outputs[1]) - print(#ref.inputs, #ref.outputs) - --]] + -- print(ref.layer.id) ref.layer:propagate(ref.inputs, ref.outputs) end end @@ -238,8 +236,5 @@ function nerv.DAGLayer:back_propagate(next_bp_err, bp_err, input, output) local ref = self.queue[i] -- print(ref.layer.id) ref.layer:back_propagate(ref.err_outputs, ref.err_inputs, ref.inputs, ref.outputs) - -- if #ref.err_outputs > 0 then - -- print(ref.err_outputs[1]) - -- end end end -- cgit v1.2.3 From 37af4bed9c3680fdb9db569605f15013e9b6b64d Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 5 Jun 2015 17:53:05 +0800 Subject: add get_params to all layers --- nn/layer_dag.lua | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) (limited to 'nn') diff --git a/nn/layer_dag.lua b/nn/layer_dag.lua index 3951bfa..2dda7c9 100644 --- a/nn/layer_dag.lua +++ b/nn/layer_dag.lua @@ -38,7 +38,7 @@ local function discover(id, layers, layer_repo) return ref end -function nerv.DAGLayer:__init(id, global_conf, layer_conf) +function DAGLayer:__init(id, global_conf, layer_conf) local layers = {} local inputs = {} local outputs = {} @@ -131,7 +131,7 @@ function nerv.DAGLayer:__init(id, global_conf, layer_conf) self.gconf = global_conf end -function nerv.DAGLayer:init(batch_size) -- topology sort +function DAGLayer:init(batch_size) -- topology sort for i, conn in ipairs(self.parsed_conn) do local _, output_dim local ref_from, port_from, ref_to, port_to @@ -174,7 +174,7 @@ function nerv.DAGLayer:init(batch_size) -- topology sort end end -function nerv.DAGLayer:set_inputs(input) +function DAGLayer:set_inputs(input) for i = 1, #self.dim_in do local layer = self.inputs[i][1] local port = self.inputs[i][2] @@ -182,7 +182,7 @@ function nerv.DAGLayer:set_inputs(input) end end -function nerv.DAGLayer:set_outputs(output) +function DAGLayer:set_outputs(output) for i = 1, #self.dim_out do local layer = self.outputs[i][1] local port = self.outputs[i][2] @@ -190,7 +190,7 @@ function nerv.DAGLayer:set_outputs(output) end end -function nerv.DAGLayer:set_err_inputs(bp_err) +function DAGLayer:set_err_inputs(bp_err) for i = 1, #self.dim_out do local layer = self.outputs[i][1] local port = self.outputs[i][2] @@ -198,7 +198,7 @@ function nerv.DAGLayer:set_err_inputs(bp_err) end end -function nerv.DAGLayer:set_err_outputs(next_bp_err) +function DAGLayer:set_err_outputs(next_bp_err) for i = 1, #self.dim_in do local layer = self.inputs[i][1] local port = self.inputs[i][2] @@ -206,7 +206,7 @@ function nerv.DAGLayer:set_err_outputs(next_bp_err) end end -function nerv.DAGLayer:update(bp_err, input, output) +function DAGLayer:update(bp_err, input, output) self:set_err_inputs(bp_err) self:set_inputs(input) self:set_outputs(output) @@ -217,7 +217,7 @@ function nerv.DAGLayer:update(bp_err, input, output) end end -function nerv.DAGLayer:propagate(input, output) +function DAGLayer:propagate(input, output) self:set_inputs(input) self:set_outputs(output) for i = 1, #self.queue do @@ -227,7 +227,7 @@ function nerv.DAGLayer:propagate(input, output) end end -function nerv.DAGLayer:back_propagate(next_bp_err, bp_err, input, output) +function DAGLayer:back_propagate(next_bp_err, bp_err, input, output) self:set_err_outputs(next_bp_err) self:set_err_inputs(bp_err) self:set_inputs(input) @@ -238,3 +238,13 @@ function nerv.DAGLayer:back_propagate(next_bp_err, bp_err, input, output) ref.layer:back_propagate(ref.err_outputs, ref.err_inputs, ref.inputs, ref.outputs) end end + +function DAGLayer:get_params() + local res = {} + for id, ref in pairs(self.queue) do + for i, p in ipairs(ref.layer:get_params()) do + table.insert(res, p) + end + end + return res +end -- cgit v1.2.3