From f26288ba61d3d16866e1b227a71e7d9c46923436 Mon Sep 17 00:00:00 2001 From: Qi Liu Date: Fri, 11 Mar 2016 13:32:00 +0800 Subject: update mini_batch_init --- lua/network.lua | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) (limited to 'lua/network.lua') diff --git a/lua/network.lua b/lua/network.lua index 6280f24..0c11321 100644 --- a/lua/network.lua +++ b/lua/network.lua @@ -57,12 +57,11 @@ function nn:get_data(data) ret[i].err_output[t][1] = err_output[t] ret[i].err_output[t][2] = softmax_output[t] end - ret[i].info = {} - ret[i].info.seq_length = data[i].seq_len - ret[i].info.new_seq = {} + ret[i].seq_length = data[i].seq_len + ret[i].new_seq = {} for j = 1, self.gconf.batch_size do if data[i].seq_start[j] then - table.insert(ret[i].info.new_seq, j) + table.insert(ret[i].new_seq, j) end end end @@ -70,34 +69,39 @@ function nn:get_data(data) end function nn:process(data, do_train) + local timer = self.gconf.timer local total_err = 0 local total_frame = 0 for id = 1, #data do if do_train then self.gconf.dropout_rate = self.gconf.dropout + data[id].do_train = true else self.gconf.dropout_rate = 0 + data[id].do_train = false end - self.network:mini_batch_init(data[id].info) - local input = {} - for t = 1, self.gconf.chunk_size do - input[t] = {data[id].input[t][1], data[id].input[t][2]:decompress(self.gconf.vocab_size)} - end - self.network:propagate(input, data[id].output) + timer:tic('network') + self.network:mini_batch_init(data[id]) + self.network:propagate() + timer:toc('network') for t = 1, self.gconf.chunk_size do local tmp = data[id].output[t][1]:new_to_host() for i = 1, self.gconf.batch_size do - if t <= data[id].info.seq_length[i] then + if t <= data[id].seq_length[i] then total_err = total_err + math.log10(math.exp(tmp[i - 1][0])) total_frame = total_frame + 1 end end end if do_train then - self.network:back_propagate(data[id].err_input, data[id].err_output, input, data[id].output) - self.network:update(data[id].err_input, input, data[id].output) + timer:tic('network') + self.network:back_propagate() + self.network:update() + timer:toc('network') end + timer:tic('gc') collectgarbage('collect') + timer:toc('gc') end return math.pow(10, - total_err / total_frame) end -- cgit v1.2.3