From 3dd235c8b6ea7ef275381866d11d1be828d27a06 Mon Sep 17 00:00:00 2001 From: Qi Liu Date: Tue, 15 Mar 2016 12:48:07 +0800 Subject: fix duplicate bug on & --- nerv/examples/network_debug/config.lua | 10 +++++++-- nerv/examples/network_debug/reader.lua | 4 ++-- nerv/nn/network.lua | 40 ++++++++++++++++++++++++++++++++-- 3 files changed, 48 insertions(+), 6 deletions(-) diff --git a/nerv/examples/network_debug/config.lua b/nerv/examples/network_debug/config.lua index e20d5a9..0429e9a 100644 --- a/nerv/examples/network_debug/config.lua +++ b/nerv/examples/network_debug/config.lua @@ -35,6 +35,10 @@ function get_layers(global_conf) ['nerv.SoftmaxCELayer'] = { softmax = {dim_in = {global_conf.vocab_size, global_conf.vocab_size}, dim_out = {1}, compressed = true}, }, + ['nerv.DuplicateLayer'] = { + dup1 = {dim_in = {1}, dim_out = {1}}, + dup2 = {dim_in = {1}, dim_out = {1}}, + }, } for i = 1, global_conf.layer_num do layers['nerv.LSTMLayer']['lstm' .. i] = {dim_in = {global_conf.hidden_size}, dim_out = {global_conf.hidden_size}, pr = pr} @@ -45,12 +49,14 @@ end function get_connections(global_conf) local connections = { - {'[1]', 'select[1]', 0}, + {'[1]', 'dup1[1]', 0}, + {'dup1[1]', 'select[1]', 0}, {'select[1]', 'lstm1[1]', 0}, {'dropout' .. global_conf.layer_num .. '[1]', 'output[1]', 0}, {'output[1]', 'softmax[1]', 0}, {'[2]', 'softmax[2]', 0}, - {'softmax[1]', '[1]', 0}, + {'softmax[1]', 'dup2[1]', 0}, + {'dup2[1]', '[1]', 0}, } for i = 1, global_conf.layer_num do table.insert(connections, {'lstm' .. i .. '[1]', 'dropout' .. i .. '[1]', 0}) diff --git a/nerv/examples/network_debug/reader.lua b/nerv/examples/network_debug/reader.lua index 76a78cf..70c0c97 100644 --- a/nerv/examples/network_debug/reader.lua +++ b/nerv/examples/network_debug/reader.lua @@ -32,8 +32,8 @@ end function Reader:get_seq(input_file) local f = io.open(input_file, 'r') self.seq = {} - while true do - -- for i = 1, 26 do + -- while true do + for i = 1, 26 do local seq = f:read() if seq == nil then break diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua index b06028e..6f7fe10 100644 --- a/nerv/nn/network.lua +++ b/nerv/nn/network.lua @@ -27,7 +27,17 @@ function network:__init(id, global_conf, network_conf) if self.input_conn[id][port] ~= nil then nerv.error('duplicate edge') end - self.input_conn[id][port] = {0, i, time} + if nerv.is_type(self.layers[id], 'nerv.DuplicateLayer') then + local tmp = nerv.IdentityLayer('', self.gconf, {dim_in = {self.dim_in[i]}, dim_out = {self.dim_in[i]}}) + table.insert(self.layers, tmp) + local new_id = #self.layers + self.input_conn[new_id] = {{0, i, time}} + self.output_conn[new_id] = {{id, port, 0}} + self.input_conn[id][port] = {new_id, 1, 0} + self.socket.inputs[i] = {new_id, 1, time} + else + self.input_conn[id][port] = {0, i, time} + end end for i = 1, #self.dim_out do local edge = self.socket.outputs[i] @@ -35,7 +45,17 @@ function network:__init(id, global_conf, network_conf) if self.output_conn[id][port] ~= nil then nerv.error('duplicate edge') end - self.output_conn[id][port] = {0, i, time} + if nerv.is_type(self.layers[id], 'nerv.DuplicateLayer') then + local tmp = nerv.IdentityLayer('', self.gconf, {dim_in = {self.dim_out[i]}, dim_out = {self.dim_out[i]}}) + table.insert(self.layers, tmp) + local new_id = #self.layers + self.input_conn[new_id] = {{id, port, 0}} + self.output_conn[new_id] = {{0, i, time}} + self.output_conn[id][port] = {new_id, 1, 0} + self.socket.outputs[i] = {new_id, 1, time} + else + self.output_conn[id][port] = {0, i, time} + end end self.delay = 0 @@ -140,8 +160,10 @@ function network:init(batch_size, chunk_size) collectgarbage('collect') self.flush = {} + self.gconf.mask = {} for t = 1, self.chunk_size do self.flush[t] = {} + self.gconf.mask[t] = self.mat_type(self.batch_size, 1) end end @@ -348,6 +370,7 @@ function network:make_initial_store() local dim_in, dim_out = self.layers[i]:get_dim() for j = 1, #dim_in do if self.input[t][i][j] == nil then + print(t,i,j,self.layers[i].id) nerv.error('input reference dangling') end if self.err_output[t][i][j] == nil then @@ -450,6 +473,19 @@ function network:mini_batch_init(info) self:set_err_output(self.info.err_output) end + -- calculate mask + for t = 1, self.chunk_size do + local tmp = self.gconf.mmat_type(self.batch_size, 1) + for i = 1, self.batch_size do + if t <= self.info.seq_length[i] then + tmp[i - 1][0] = 1 + else + tmp[i - 1][0] = 0 + end + end + self.gconf.mask[t]:copy_fromh(tmp) + end + -- calculate border self.max_length = 0 self.timestamp = self.timestamp + 1 -- cgit v1.2.3-70-g09d2