aboutsummaryrefslogtreecommitdiff
path: root/nerv/nn/network.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/nn/network.lua')
-rw-r--r--nerv/nn/network.lua30
1 files changed, 27 insertions, 3 deletions
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua
index 910cdad..b06028e 100644
--- a/nerv/nn/network.lua
+++ b/nerv/nn/network.lua
@@ -228,9 +228,11 @@ function network:make_initial_store()
err_memory[t][i][j] = self.mat_type(self.batch_size, dim_in[j])
err_memory[t][i][j]:fill(0)
end
- for j = 1, #dim_out do
- memory[t][i][j] = self.mat_type(self.batch_size, dim_out[j])
- memory[t][i][j]:fill(self.nn_act_default)
+ if t < 1 or t > self.chunk_size or not nerv.is_type(self.layers[i], 'nerv.DuplicateLayer') then
+ for j = 1, #dim_out do
+ memory[t][i][j] = self.mat_type(self.batch_size, dim_out[j])
+ memory[t][i][j]:fill(self.nn_act_default)
+ end
end
end
if t < 1 or t > self.chunk_size then
@@ -288,6 +290,28 @@ function network:make_initial_store()
end
end
+ -- reference copy for duplicate layer
+ for i = 1, #self.queue do
+ local t, id = self.queue[i].chunk, self.queue[i].id
+ if nerv.is_type(self.layers[id], 'nerv.DuplicateLayer') then
+ local _, dim_out = self.layers[id]:get_dim()
+ for j = 1, #dim_out do
+ if self.output[t][id][j] ~= nil then
+ nerv.error('duplicate output reference not nil')
+ end
+ self.output[t][id][j] = self.input[t][id][1]
+ local edge = self.output_conn[id][j]
+ local to, port, time = edge[1], edge[2], edge[3] + t
+ if time >= 1 and time <= self.chunk_size then
+ if self.input[time][to][port] ~= nil then
+ nerv.error('duplicate input reference not nil')
+ end
+ self.input[time][to][port] = self.output[t][id][j]
+ end
+ end
+ end
+ end
+
-- check dangling reference
for t = 1, self.chunk_size do
for i = 1, #self.dim_in do