From 3dd235c8b6ea7ef275381866d11d1be828d27a06 Mon Sep 17 00:00:00 2001 From: Qi Liu Date: Tue, 15 Mar 2016 12:48:07 +0800 Subject: fix duplicate bug on & --- nerv/nn/network.lua | 40 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) (limited to 'nerv/nn/network.lua') diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua index b06028e..6f7fe10 100644 --- a/nerv/nn/network.lua +++ b/nerv/nn/network.lua @@ -27,7 +27,17 @@ function network:__init(id, global_conf, network_conf) if self.input_conn[id][port] ~= nil then nerv.error('duplicate edge') end - self.input_conn[id][port] = {0, i, time} + if nerv.is_type(self.layers[id], 'nerv.DuplicateLayer') then + local tmp = nerv.IdentityLayer('', self.gconf, {dim_in = {self.dim_in[i]}, dim_out = {self.dim_in[i]}}) + table.insert(self.layers, tmp) + local new_id = #self.layers + self.input_conn[new_id] = {{0, i, time}} + self.output_conn[new_id] = {{id, port, 0}} + self.input_conn[id][port] = {new_id, 1, 0} + self.socket.inputs[i] = {new_id, 1, time} + else + self.input_conn[id][port] = {0, i, time} + end end for i = 1, #self.dim_out do local edge = self.socket.outputs[i] @@ -35,7 +45,17 @@ function network:__init(id, global_conf, network_conf) if self.output_conn[id][port] ~= nil then nerv.error('duplicate edge') end - self.output_conn[id][port] = {0, i, time} + if nerv.is_type(self.layers[id], 'nerv.DuplicateLayer') then + local tmp = nerv.IdentityLayer('', self.gconf, {dim_in = {self.dim_out[i]}, dim_out = {self.dim_out[i]}}) + table.insert(self.layers, tmp) + local new_id = #self.layers + self.input_conn[new_id] = {{id, port, 0}} + self.output_conn[new_id] = {{0, i, time}} + self.output_conn[id][port] = {new_id, 1, 0} + self.socket.outputs[i] = {new_id, 1, time} + else + self.output_conn[id][port] = {0, i, time} + end end self.delay = 0 @@ -140,8 +160,10 @@ function network:init(batch_size, chunk_size) collectgarbage('collect') self.flush = {} + self.gconf.mask = {} for t = 1, self.chunk_size do self.flush[t] = {} + self.gconf.mask[t] = self.mat_type(self.batch_size, 1) end end @@ -348,6 +370,7 @@ function network:make_initial_store() local dim_in, dim_out = self.layers[i]:get_dim() for j = 1, #dim_in do if self.input[t][i][j] == nil then + print(t,i,j,self.layers[i].id) nerv.error('input reference dangling') end if self.err_output[t][i][j] == nil then @@ -450,6 +473,19 @@ function network:mini_batch_init(info) self:set_err_output(self.info.err_output) end + -- calculate mask + for t = 1, self.chunk_size do + local tmp = self.gconf.mmat_type(self.batch_size, 1) + for i = 1, self.batch_size do + if t <= self.info.seq_length[i] then + tmp[i - 1][0] = 1 + else + tmp[i - 1][0] = 0 + end + end + self.gconf.mask[t]:copy_fromh(tmp) + end + -- calculate border self.max_length = 0 self.timestamp = self.timestamp + 1 -- cgit v1.2.3