From f26288ba61d3d16866e1b227a71e7d9c46923436 Mon Sep 17 00:00:00 2001 From: Qi Liu Date: Fri, 11 Mar 2016 13:32:00 +0800 Subject: update mini_batch_init --- nerv/main.lua | 73 ----------------------------------------------------- nerv/nn/network.lua | 63 ++++++++++++++++++++++++++------------------- 2 files changed, 37 insertions(+), 99 deletions(-) delete mode 100644 nerv/main.lua (limited to 'nerv') diff --git a/nerv/main.lua b/nerv/main.lua deleted file mode 100644 index 7c82ebf..0000000 --- a/nerv/main.lua +++ /dev/null @@ -1,73 +0,0 @@ -local global_conf = { - cumat_type = nerv.CuMatrixFloat, - param_random = function() return 0 end, - lrate = 0.1, - wcost = 0, - momentum = 0.9, - batch_size = 2, -} - -local layer_repo = nerv.LayerRepo( - { - ['nerv.RNNLayer'] = { - rnn1 = {dim_in = {23}, dim_out = {26}}, - rnn2 = {dim_in = {26}, dim_out = {26}}, - }, - ['nerv.AffineLayer'] = { - input = {dim_in = {62}, dim_out = {23}}, - output = {dim_in = {26, 79}, dim_out = {79}}, - }, - ['nerv.SigmoidLayer'] = { - sigmoid = {dim_in = {23}, dim_out = {23}}, - }, - ['nerv.IdentityLayer'] = { - softmax = {dim_in = {79}, dim_out = {79}}, - }, - ['nerv.DuplicateLayer'] = { - dup = {dim_in = {79}, dim_out = {79, 79}}, - }, - }, nerv.ParamRepo(), global_conf) - -local connections = { - {'[1]', 'input[1]', 0}, - {'input[1]', 'sigmoid[1]', 0}, - {'sigmoid[1]', 'rnn1[1]', 0}, - {'rnn1[1]', 'rnn2[1]', 0}, - {'rnn2[1]', 'output[1]', 0}, - {'output[1]', 'dup[1]', 0}, - {'dup[1]', 'output[2]', -1}, - {'dup[2]', 'softmax[1]', 0}, - {'softmax[1]', '[1]', 0}, -} - -local graph = nerv.GraphLayer('graph', global_conf, {dim_in = {62}, dim_out = {79}, layer_repo = layer_repo, connections = connections}) - -local network = nerv.Network('network', global_conf, {network = graph}) - -local batch = global_conf.batch_size -local chunk = 5 -network:init(batch, chunk) - -local input = {} -local output = {} -local err_input = {} -local err_output = {} -local input_size = 62 -local output_size = 79 -for i = 1, chunk do - input[i] = {global_conf.cumat_type(batch, input_size)} - output[i] = {global_conf.cumat_type(batch, output_size)} - err_input[i] = {global_conf.cumat_type(batch, output_size)} - err_output[i] = {global_conf.cumat_type(batch, input_size)} -end - -for i = 1, 100 do - network:mini_batch_init({seq_length = {5, 3}, new_seq = {2}}) - network:propagate(input, output) - network:back_propagate(err_input, err_output, input, output) - network:update(err_input, input, output) -end - -local tmp = network:get_params() - -tmp:export('../../workspace/test.param') diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua index 39df5f0..35e11e3 100644 --- a/nerv/nn/network.lua +++ b/nerv/nn/network.lua @@ -2,8 +2,9 @@ local network = nerv.class('nerv.Network') function network:__init(id, global_conf, network_conf) self.id = id - self.dim_in = network_conf.network.dim_in - self.dim_out = network_conf.network.dim_out + self.network = network_conf.network + self.dim_in = self.network.dim_in + self.dim_out = self.network.dim_out self.gconf = global_conf if self.gconf.use_cpu then self.mat_type = self.gconf.mmat_type @@ -18,7 +19,6 @@ function network:__init(id, global_conf, network_conf) self.layers = {} self.input_conn = {} self.output_conn = {} - self.network = network_conf.network self.socket = self:compile(self.network) for i = 1, #self.dim_in do local edge = self.socket.inputs[i] @@ -368,8 +368,21 @@ function network:set_err_output(err_output) end end -function network:mini_batch_init(information) - self.info = information +--[[ + [info] is a table that contains information of current mini-batch. These fields must be contained: + [input], [output] : matrix array which stores the network input and output + [seq_length] : a table contains the length of every sequences + [new_seq]: a table contains the batch number of new sequences + [do_train]: a bool value indicates do train or not + if [do_train] is true, these fileds also must be contained: + [err_input], [err_output] : matrix array which stores the network err_input and err_output +--]] +function network:mini_batch_init(info) + self.info = info + self:set_input(self.info.input) + self:set_output(self.info.output) + + -- calculate border self.max_length = 0 self.border = {} for i = 1, self.chunk_size do @@ -387,6 +400,7 @@ function network:mini_batch_init(information) table.insert(self.border[chunk], i) end end + -- copy legacy for t = 1 - self.delay, 0 do for i = 1, #self.layers do @@ -402,23 +416,27 @@ function network:mini_batch_init(information) end end end - -- flush border gradient - for t = self.max_length + 1, self.max_length + self.delay do - if t > self.chunk_size then - break - end - for i = 1, #self.layers do - local dim_in, _ = self.layers[i]:get_dim() - for j = 1, #dim_in do - self.err_output[t][i][j]:fill(0) + + if self.info.do_train then + self:set_err_input(self.info.err_input) + self:set_err_output(self.info.err_output) + + -- flush border gradient + for t = self.max_length + 1, self.max_length + self.delay do + if t > self.chunk_size then + break + end + for i = 1, #self.layers do + local dim_in, _ = self.layers[i]:get_dim() + for j = 1, #dim_in do + self.err_output[t][i][j]:fill(0) + end end end end end -function network:propagate(input, output) - self:set_input(input) - self:set_output(output) +function network:propagate() for i = 1, #self.queue do local t, id = self.queue[i].chunk, self.queue[i].id if t <= self.max_length then @@ -435,11 +453,7 @@ function network:propagate(input, output) end end -function network:back_propagate(bp_err, next_bp_err, input, output) - self:set_input(input) - self:set_output(output) - self:set_err_input(bp_err) - self:set_err_output(next_bp_err) +function network:back_propagate() for i = #self.queue, 1, -1 do local t, id = self.queue[i].chunk, self.queue[i].id if t <= self.max_length then @@ -462,10 +476,7 @@ function network:back_propagate(bp_err, next_bp_err, input, output) end end -function network:update(bp_err, input, output) - self:set_input(input) - self:set_output(output) - self:set_err_input(bp_err) +function network:update() for i = 1, #self.queue do local t, id = self.queue[i].chunk, self.queue[i].id if t <= self.max_length then -- cgit v1.2.3-70-g09d2