From f26288ba61d3d16866e1b227a71e7d9c46923436 Mon Sep 17 00:00:00 2001 From: Qi Liu Date: Fri, 11 Mar 2016 13:32:00 +0800 Subject: update mini_batch_init --- nerv/nn/network.lua | 63 +++++++++++++++++++++++++++++++---------------------- 1 file changed, 37 insertions(+), 26 deletions(-) (limited to 'nerv/nn/network.lua') diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua index 39df5f0..35e11e3 100644 --- a/nerv/nn/network.lua +++ b/nerv/nn/network.lua @@ -2,8 +2,9 @@ local network = nerv.class('nerv.Network') function network:__init(id, global_conf, network_conf) self.id = id - self.dim_in = network_conf.network.dim_in - self.dim_out = network_conf.network.dim_out + self.network = network_conf.network + self.dim_in = self.network.dim_in + self.dim_out = self.network.dim_out self.gconf = global_conf if self.gconf.use_cpu then self.mat_type = self.gconf.mmat_type @@ -18,7 +19,6 @@ function network:__init(id, global_conf, network_conf) self.layers = {} self.input_conn = {} self.output_conn = {} - self.network = network_conf.network self.socket = self:compile(self.network) for i = 1, #self.dim_in do local edge = self.socket.inputs[i] @@ -368,8 +368,21 @@ function network:set_err_output(err_output) end end -function network:mini_batch_init(information) - self.info = information +--[[ + [info] is a table that contains information of current mini-batch. 
These fields must be contained: + [input], [output] : matrix arrays which store the network input and output + [seq_length] : a table containing the length of every sequence + [new_seq]: a table containing the batch numbers of new sequences + [do_train]: a bool value indicating whether to train or not + if [do_train] is true, these fields must also be contained: + [err_input], [err_output] : matrix arrays which store the network err_input and err_output +--]] +function network:mini_batch_init(info) + self.info = info + self:set_input(self.info.input) + self:set_output(self.info.output) + + -- calculate border self.max_length = 0 self.border = {} for i = 1, self.chunk_size do @@ -387,6 +400,7 @@ function network:mini_batch_init(information) table.insert(self.border[chunk], i) end end + -- copy legacy for t = 1 - self.delay, 0 do for i = 1, #self.layers do @@ -402,23 +416,27 @@ function network:mini_batch_init(information) end end end - -- flush border gradient - for t = self.max_length + 1, self.max_length + self.delay do - if t > self.chunk_size then - break - end - for i = 1, #self.layers do - local dim_in, _ = self.layers[i]:get_dim() - for j = 1, #dim_in do - self.err_output[t][i][j]:fill(0) + + if self.info.do_train then + self:set_err_input(self.info.err_input) + self:set_err_output(self.info.err_output) + + -- flush border gradient + for t = self.max_length + 1, self.max_length + self.delay do + if t > self.chunk_size then + break + end + for i = 1, #self.layers do + local dim_in, _ = self.layers[i]:get_dim() + for j = 1, #dim_in do + self.err_output[t][i][j]:fill(0) + end end end end end -function network:propagate(input, output) - self:set_input(input) - self:set_output(output) +function network:propagate() for i = 1, #self.queue do local t, id = self.queue[i].chunk, self.queue[i].id if t <= self.max_length then @@ -435,11 +453,7 @@ function network:propagate(input, output) end end -function network:back_propagate(bp_err, next_bp_err, input, output) - 
self:set_input(input) - self:set_output(output) - self:set_err_input(bp_err) - self:set_err_output(next_bp_err) +function network:back_propagate() for i = #self.queue, 1, -1 do local t, id = self.queue[i].chunk, self.queue[i].id if t <= self.max_length then @@ -462,10 +476,7 @@ function network:back_propagate(bp_err, next_bp_err, input, output) end end -function network:update(bp_err, input, output) - self:set_input(input) - self:set_output(output) - self:set_err_input(bp_err) +function network:update() for i = 1, #self.queue do local t, id = self.queue[i].chunk, self.queue[i].id if t <= self.max_length then -- cgit v1.2.3-70-g09d2