From f26288ba61d3d16866e1b227a71e7d9c46923436 Mon Sep 17 00:00:00 2001 From: Qi Liu Date: Fri, 11 Mar 2016 13:32:00 +0800 Subject: update mini_batch_init --- lua/config.lua | 4 ++-- lua/main.lua | 4 +++- lua/network.lua | 30 +++++++++++++++++------------- lua/reader.lua | 3 ++- 4 files changed, 24 insertions(+), 17 deletions(-) (limited to 'lua') diff --git a/lua/config.lua b/lua/config.lua index 9d73b64..1ec1198 100644 --- a/lua/config.lua +++ b/lua/config.lua @@ -12,7 +12,7 @@ function get_global_conf() layer_num = 1, chunk_size = 15, batch_size = 20, - max_iter = 1, + max_iter = 3, param_random = function() return (math.random() / 5 - 0.1) end, dropout = 0.5, timer = nerv.Timer(), @@ -34,7 +34,7 @@ function get_layers(global_conf) output = {dim_in = {global_conf.hidden_size}, dim_out = {global_conf.vocab_size}, pr = pr} }, ['nerv.SoftmaxCELayer'] = { - softmax = {dim_in = {global_conf.vocab_size, global_conf.vocab_size}, dim_out = {1}}, + softmax = {dim_in = {global_conf.vocab_size, global_conf.vocab_size}, dim_out = {1}, compressed = true}, }, } for i = 1, global_conf.layer_num do diff --git a/lua/main.lua b/lua/main.lua index 684efac..ce0270a 100644 --- a/lua/main.lua +++ b/lua/main.lua @@ -9,7 +9,7 @@ local timer = global_conf.timer timer:tic('IO') local data_path = 'nerv/nerv/examples/lmptb/PTBdata/' -local train_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.train.txt.adds') +local train_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.valid.txt.adds') local val_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.valid.txt.adds') local train_data = train_reader:get_all_batch(global_conf) @@ -41,3 +41,5 @@ for i = 1, global_conf.max_iter do end timer:toc('global') timer:check('global') +timer:check('network') +timer:check('gc') diff --git a/lua/network.lua b/lua/network.lua index 6280f24..0c11321 100644 --- a/lua/network.lua +++ b/lua/network.lua @@ -57,12 +57,11 @@ function nn:get_data(data) ret[i].err_output[t][1] = err_output[t] ret[i].err_output[t][2] = softmax_output[t] end - ret[i].info = {} - ret[i].info.seq_length = data[i].seq_len - ret[i].info.new_seq = {} + ret[i].seq_length = data[i].seq_len + ret[i].new_seq = {} for j = 1, self.gconf.batch_size do if data[i].seq_start[j] then - table.insert(ret[i].info.new_seq, j) + table.insert(ret[i].new_seq, j) end end end @@ -70,34 +69,39 @@ function nn:get_data(data) end function nn:process(data, do_train) + local timer = self.gconf.timer local total_err = 0 local total_frame = 0 for id = 1, #data do if do_train then self.gconf.dropout_rate = self.gconf.dropout + data[id].do_train = true else self.gconf.dropout_rate = 0 + data[id].do_train = false end - self.network:mini_batch_init(data[id].info) - local input = {} - for t = 1, self.gconf.chunk_size do - input[t] = {data[id].input[t][1], data[id].input[t][2]:decompress(self.gconf.vocab_size)} - end - self.network:propagate(input, data[id].output) + timer:tic('network') + self.network:mini_batch_init(data[id]) + self.network:propagate() + timer:toc('network') for t = 1, self.gconf.chunk_size do local tmp = data[id].output[t][1]:new_to_host() for i = 1, self.gconf.batch_size do - if t <= data[id].info.seq_length[i] then + if t <= data[id].seq_length[i] then total_err = total_err + math.log10(math.exp(tmp[i - 1][0])) total_frame = total_frame + 1 end end end if do_train then - self.network:back_propagate(data[id].err_input, data[id].err_output, input, data[id].output) - self.network:update(data[id].err_input, input, data[id].output) + timer:tic('network') + self.network:back_propagate() + self.network:update() + timer:toc('network') end + timer:tic('gc') collectgarbage('collect') + timer:toc('gc') end return math.pow(10, - total_err / total_frame) end diff --git a/lua/reader.lua b/lua/reader.lua index 2e51a9c..0c7bcb6 100644 --- a/lua/reader.lua +++ b/lua/reader.lua @@ -58,7 +58,8 @@ function Reader:get_all_batch(global_conf) for i = 1, global_conf.batch_size do pos[i] = nil end - while true do + --while true do + for i = 1, 100 do local input = {} local output = {} for i = 1, global_conf.chunk_size do -- cgit v1.2.3-70-g09d2