diff options
author | Qi Liu <[email protected]> | 2016-03-15 10:56:35 +0800 |
---|---|---|
committer | Qi Liu <[email protected]> | 2016-03-15 10:56:35 +0800 |
commit | 51a3beef4a7cbd94278a406664212b6597aedd93 (patch) | |
tree | 298d976792164102cf7e12fb1f351f2f50857326 /nerv/nn | |
parent | b08da1fef90e93b188704056cdae651d7865f98d (diff) |
speed up duplicate layer
Diffstat (limited to 'nerv/nn')
-rw-r--r-- | nerv/nn/network.lua | 30 |
1 file changed, 27 insertions, 3 deletions
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua index 910cdad..b06028e 100644 --- a/nerv/nn/network.lua +++ b/nerv/nn/network.lua @@ -228,9 +228,11 @@ function network:make_initial_store() err_memory[t][i][j] = self.mat_type(self.batch_size, dim_in[j]) err_memory[t][i][j]:fill(0) end - for j = 1, #dim_out do - memory[t][i][j] = self.mat_type(self.batch_size, dim_out[j]) - memory[t][i][j]:fill(self.nn_act_default) + if t < 1 or t > self.chunk_size or not nerv.is_type(self.layers[i], 'nerv.DuplicateLayer') then + for j = 1, #dim_out do + memory[t][i][j] = self.mat_type(self.batch_size, dim_out[j]) + memory[t][i][j]:fill(self.nn_act_default) + end end end if t < 1 or t > self.chunk_size then @@ -288,6 +290,28 @@ function network:make_initial_store() end end + -- reference copy for duplicate layer + for i = 1, #self.queue do + local t, id = self.queue[i].chunk, self.queue[i].id + if nerv.is_type(self.layers[id], 'nerv.DuplicateLayer') then + local _, dim_out = self.layers[id]:get_dim() + for j = 1, #dim_out do + if self.output[t][id][j] ~= nil then + nerv.error('duplicate output reference not nil') + end + self.output[t][id][j] = self.input[t][id][1] + local edge = self.output_conn[id][j] + local to, port, time = edge[1], edge[2], edge[3] + t + if time >= 1 and time <= self.chunk_size then + if self.input[time][to][port] ~= nil then + nerv.error('duplicate input reference not nil') + end + self.input[time][to][port] = self.output[t][id][j] + end + end + end + end + -- check dangling reference for t = 1, self.chunk_size do for i = 1, #self.dim_in do |