diff options
Diffstat (limited to 'nerv/nn')
-rw-r--r-- | nerv/nn/network.lua | 60 |
1 files changed, 41 insertions, 19 deletions
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua index 6f7fe10..cd80b1e 100644 --- a/nerv/nn/network.lua +++ b/nerv/nn/network.lua @@ -398,10 +398,12 @@ function network:make_initial_store() end for d = 1, self.delay do for t = 1 - d, 0 do - for i = 1, #self.output_edge[d] do - local edge = self.output_edge[d][i] - local id, port = edge[1], edge[2] - self.legacy[t][id][port] = memory[t][id][port] + if t + self.chunk_size >= 1 then + for i = 1, #self.output_edge[d] do + local edge = self.output_edge[d][i] + local id, port = edge[1], edge[2] + self.legacy[t][id][port] = memory[t][id][port] + end end end end @@ -486,17 +488,19 @@ function network:mini_batch_init(info) self.gconf.mask[t]:copy_fromh(tmp) end - -- calculate border + -- calculate max length self.max_length = 0 + for i = 1, self.batch_size do + self.max_length = math.max(self.max_length, self.info.seq_length[i]) + end + + -- calculate border self.timestamp = self.timestamp + 1 for i = 1, self.batch_size do - if self.info.seq_length[i] > self.max_length then - self.max_length = self.info.seq_length[i] - end local border = self.info.seq_length[i] for d = 1, self.delay do for t = border + 1, border + d do - if t > self.chunk_size then + if t > self.max_length then break end for j = 1, #self.output_edge[-d] do @@ -532,23 +536,41 @@ function network:mini_batch_init(info) end end + -- flush border gradient + if self.info.do_train then + local border = self.max_length + for d = 1, self.delay do + for t = border + 1, border + d do + if t > self.chunk_size then + break + end + for j = 1, #self.input_edge[d] do + local edge = self.input_edge[d][j] + local id, port = edge[1], edge[2] + self.err_output[t][id][port]:fill(0) + end + end + end + end + -- copy legacy for d = 1, self.delay do for t = 1 - d, 0 do - for i = 1, #self.output_edge[d] do - local edge = self.output_edge[d][i] - local id, port = edge[1], edge[2] - if t + self.chunk_size >= 1 and self.output_conn[id][port][1] ~= 0 then - self.legacy[t][id][port]:copy_from(self.output[t + self.chunk_size][id][port]) - end - for j = 1, #self.info.new_seq do - local batch = self.info.new_seq[j] - self.legacy[t][id][port][batch - 1]:fill(self.nn_act_default) + if t + self.chunk_size >= 1 then + for i = 1, #self.output_edge[d] do + local edge = self.output_edge[d][i] + local id, port = edge[1], edge[2] + if self.output_conn[id][port][1] ~= 0 then + self.legacy[t][id][port]:copy_from(self.output[t + self.chunk_size][id][port]) + end + for j = 1, #self.info.new_seq do + local batch = self.info.new_seq[j] + self.legacy[t][id][port][batch - 1]:fill(self.nn_act_default) + end end end end end - end function network:propagate() |