path: root/nerv/nn/network.lua
Diffstat (limited to 'nerv/nn/network.lua')
 nerv/nn/network.lua | 32 +++++++++++++++++---------------
 1 file changed, 17 insertions(+), 15 deletions(-)
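
Summary of the fixes (a short note with a minimal sketch follows each affected
hunk below):
- set_input, set_output and set_err_input looped over `#self.chunk_size`, but
  chunk_size is a number, not a table.
- mini_batch_init copied into the legacy buffer without checking that the
  source chunk exists and that the output port is connected.
- propagate, back_propagate and update called the setters on the `network`
  class table instead of the instance (`self`).
- back_propagate flushed border gradients over the layer's input ports, while
  err_input has one entry per output port.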
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua
index 3cf052b..0bbcc59 100644
--- a/nerv/nn/network.lua
+++ b/nerv/nn/network.lua
@@ -320,7 +320,7 @@ function network:make_initial_store()
 end
 
 function network:set_input(input)
-    for t = 1, #self.chunk_size do
+    for t = 1, self.chunk_size do
         for i = 1, #self.dim_in do
             local edge = self.socket.inputs[i]
             local id, port, time = edge[1], edge[2], edge[3]
@@ -332,7 +332,7 @@ function network:set_input(input)
 end
 
 function network:set_output(output)
-    for t = 1, #self.chunk_size do
+    for t = 1, self.chunk_size do
         for i = 1, #self.dim_out do
             local edge = self.socket.outputs[i]
             local id, port, time = edge[1], edge[2], edge[3]
@@ -344,7 +344,7 @@ function network:set_output(output)
 end
 
 function network:set_err_input(err_input)
-    for t = 1, #self.chunk_size do
+    for t = 1, self.chunk_size do
         for i = 1, #self.dim_out do
             local edge = self.socket.outputs[i]
             local id, port, time = edge[1], edge[2], edge[3]
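
Note on the three hunks above: `self.chunk_size` is a plain number (the number
of time steps per mini-batch, as the loop bodies suggest), and in Lua the
length operator `#` only applies to strings and tables, so `#self.chunk_size`
raises a runtime error. A minimal sketch with an illustrative value:

    local chunk_size = 4
    -- print(#chunk_size)       -- error: attempt to get length of a number value
    for t = 1, chunk_size do    -- fixed form: iterate over every time step
        print(t)
    end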
@@ -391,7 +391,9 @@ function network:mini_batch_init(information)
         for i = 1, #self.layers do
             local _, dim_out = self.layers[i]:get_dim()
             for j = 1, #dim_out do
-                self.legacy[t][i][j]:copy_from(self.output[t + self.chunk_size][i][j])
+                if t + self.chunk_size >= 1 and self.output_conn[i][j][1] ~= 0 then
+                    self.legacy[t][i][j]:copy_from(self.output[t + self.chunk_size][i][j])
+                end
                 for k = 1, #self.info.new_seq do
                     local batch = self.info.new_seq[k]
                     self.legacy[t][i][j][batch - 1]:fill(self.nn_act_default)
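
Note: the unguarded copy read `self.output[t + self.chunk_size]`, but this
loop fills `legacy` for non-positive t, so the source index can fall below 1,
and an unconnected output port has no matrix to copy from. A hedged sketch of
the boundary case, assuming (as the surrounding code suggests) that
`self.output[t]` exists only for t in [1, chunk_size]:

    -- With chunk_size = 2 and history reaching back 3 steps, t + chunk_size
    -- dips below 1, where self.output has no entry.
    local chunk_size = 2
    for t = -3, 0 do
        local src = t + chunk_size          -- -1, 0, 1, 2
        if src >= 1 then
            print(("copy output[%d] into legacy[%d]"):format(src, t))
        end                                 -- src < 1: keep the default value
    end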
@@ -414,8 +416,8 @@ function network:mini_batch_init(information)
 end
 
 function network:propagate(input, output)
-    network:set_input(input)
-    network:set_output(output)
+    self:set_input(input)
+    self:set_output(output)
     for i = 1, #self.queue do
         local t, id = self.queue[i].chunk, self.queue[i].id
         if t <= self.max_length then
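
Note: `network:set_input(input)` resolves to `network.set_input(network,
input)`, so `self` inside the setter is the class table rather than the
current network instance, and instance fields like `self.socket` would be
looked up on the class. The same fix recurs in back_propagate and update
below. A self-contained sketch of the colon-call semantics, using a generic
metatable class since nerv builds its classes elsewhere:

    local network = {}
    network.__index = network
    function network:set_input(input) self.input = input end

    local net = setmetatable({}, network)
    network:set_input(42)       -- bug: stores on the class, shared by all nets
    net:set_input(42)           -- fix: stores on this instance
    print(rawget(net, "input")) -- 42 only after the fixed call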
@@ -433,18 +435,18 @@ function network:propagate(input, output)
 end
 
 function network:back_propagate(bp_err, next_bp_err, input, output)
-    network:set_input(input)
-    network:set_output(output)
-    network:set_err_input(bp_err)
-    network:set_err_output(next_bp_err)
+    self:set_input(input)
+    self:set_output(output)
+    self:set_err_input(bp_err)
+    self:set_err_output(next_bp_err)
     for i = #self.queue, 1, -1 do
         local t, id = self.queue[i].chunk, self.queue[i].id
         if t <= self.max_length then
            -- flush border gradient
            for j = 1, #self.border[t] do
                local batch = self.border[t][j]
-               local dim_in, _ = self.layers[id]:get_dim()
-               for k = 1, #dim_in do
+               local _, dim_out = self.layers[id]:get_dim()
+               for k = 1, #dim_out do
                    self.err_input[t][id][k][batch - 1]:fill(0)
                end
            end
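
Note: `err_input[t][id]` holds one error matrix per *output* port of layer
`id` (the gradient arriving at each output), so the flush must walk `dim_out`;
iterating over `dim_in` zeroes the wrong number of ports whenever a layer's
input and output port counts differ. A hedged sketch with made-up port counts:

    -- Hypothetical layer: one input port, two output ports.
    local dim_in, dim_out = {429}, {300, 300}
    local err_input = { {0.5}, {0.5} }  -- one entry per *output* port
    for k = 1, #dim_out do              -- #dim_in would leave port 2 stale
        err_input[k][1] = 0             -- stands in for matrix:fill(0)
    end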
@@ -460,9 +462,9 @@ function network:back_propagate(bp_err, next_bp_err, input, output)
 end
 
 function network:update(bp_err, input, output)
-    network:set_input(input)
-    network:set_output(output)
-    network:set_err_input(bp_err)
+    self:set_input(input)
+    self:set_output(output)
+    self:set_err_input(bp_err)
     for i = 1, #self.queue do
         local t, id = self.queue[i].chunk, self.queue[i].id
         if t <= self.max_length then