aboutsummaryrefslogtreecommitdiff
path: root/nerv/nn/network.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/nn/network.lua')
-rw-r--r-- nerv/nn/network.lua 60
1 file changed, 41 insertions, 19 deletions
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua
index 7e2af4e..bb03be4 100644
--- a/nerv/nn/network.lua
+++ b/nerv/nn/network.lua
@@ -398,10 +398,12 @@ function network:make_initial_store()
end
for d = 1, self.delay do
for t = 1 - d, 0 do
- for i = 1, #self.output_edge[d] do
- local edge = self.output_edge[d][i]
- local id, port = edge[1], edge[2]
- self.legacy[t][id][port] = memory[t][id][port]
+ if t + self.chunk_size >= 1 then
+ for i = 1, #self.output_edge[d] do
+ local edge = self.output_edge[d][i]
+ local id, port = edge[1], edge[2]
+ self.legacy[t][id][port] = memory[t][id][port]
+ end
end
end
end
@@ -486,17 +488,19 @@ function network:mini_batch_init(info)
self.gconf.mask[t]:copy_fromh(tmp)
end
- -- calculate border
+ -- calculate max length
self.max_length = 0
+ for i = 1, self.batch_size do
+ self.max_length = math.max(self.max_length, self.info.seq_length[i])
+ end
+
+ -- calculate border
self.timestamp = self.timestamp + 1
for i = 1, self.batch_size do
- if self.info.seq_length[i] > self.max_length then
- self.max_length = self.info.seq_length[i]
- end
local border = self.info.seq_length[i]
for d = 1, self.delay do
for t = border + 1, border + d do
- if t > self.chunk_size then
+ if t > self.max_length then
break
end
for j = 1, #self.output_edge[-d] do
@@ -532,23 +536,41 @@ function network:mini_batch_init(info)
end
end
+ -- flush border gradient
+ if self.info.do_train then
+ local border = self.max_length
+ for d = 1, self.delay do
+ for t = border + 1, border + d do
+ if t > self.chunk_size then
+ break
+ end
+ for j = 1, #self.input_edge[d] do
+ local edge = self.input_edge[d][j]
+ local id, port = edge[1], edge[2]
+ self.err_output[t][id][port]:fill(0)
+ end
+ end
+ end
+ end
+
-- copy legacy
for d = 1, self.delay do
for t = 1 - d, 0 do
- for i = 1, #self.output_edge[d] do
- local edge = self.output_edge[d][i]
- local id, port = edge[1], edge[2]
- if t + self.chunk_size >= 1 and self.output_conn[id][port][1] ~= 0 then
- self.legacy[t][id][port]:copy_from(self.output[t + self.chunk_size][id][port])
- end
- for j = 1, #self.info.new_seq do
- local batch = self.info.new_seq[j]
- self.legacy[t][id][port][batch - 1]:fill(self.nn_act_default)
+ if t + self.chunk_size >= 1 then
+ for i = 1, #self.output_edge[d] do
+ local edge = self.output_edge[d][i]
+ local id, port = edge[1], edge[2]
+ if self.output_conn[id][port][1] ~= 0 then
+ self.legacy[t][id][port]:copy_from(self.output[t + self.chunk_size][id][port])
+ end
+ for j = 1, #self.info.new_seq do
+ local batch = self.info.new_seq[j]
+ self.legacy[t][id][port][batch - 1]:fill(self.nn_act_default)
+ end
end
end
end
end
-
end
function network:propagate()