diff options
author | Qi Liu <[email protected]> | 2016-03-14 20:07:15 +0800 |
---|---|---|
committer | Qi Liu <[email protected]> | 2016-03-14 20:07:15 +0800 |
commit | b08da1fef90e93b188704056cdae651d7865f98d (patch) | |
tree | 4ea507d591621920e476c246c393049c8c22616b | |
parent | 35496b6a648d98dc41d6226c1d43650aba58cdfc (diff) |
speedup border flush
-rw-r--r-- | Makefile | 5 | ||||
-rw-r--r-- | nerv/examples/network_debug/config.lua | 2 | ||||
-rw-r--r-- | nerv/examples/network_debug/main.lua | 2 | ||||
-rw-r--r-- | nerv/examples/network_debug/reader.lua | 4 | ||||
-rw-r--r-- | nerv/nn/network.lua | 215 |
5 files changed, 141 insertions, 87 deletions
@@ -1,4 +1,4 @@ -.PHONY: all clean install luajit luarocks speech +.PHONY: all clean install luajit luarocks speech submodule ############## EDIT THESE LINES ##################### SHELL := /bin/bash PREFIX := $(CURDIR)/install/ @@ -26,7 +26,8 @@ export BLAS_LDFLAGS nerv-clean speech/speech_utils-clean speech/htk_io-clean speech/kaldi_io-clean speech/kaldi_decode-clean \ Penlight -all: luajit luarocks Penlight nerv +all: nerv +submodule: luajit luajit Penlight luajit: PREFIX=$(PREFIX) ./tools/build_luajit.sh luarocks: diff --git a/nerv/examples/network_debug/config.lua b/nerv/examples/network_debug/config.lua index 9025b78..093bde2 100644 --- a/nerv/examples/network_debug/config.lua +++ b/nerv/examples/network_debug/config.lua @@ -12,7 +12,7 @@ function get_global_conf() layer_num = 1, chunk_size = 15, batch_size = 20, - max_iter = 3, + max_iter = 1, param_random = function() return (math.random() / 5 - 0.1) end, dropout_rate = 0.5, timer = nerv.Timer(), diff --git a/nerv/examples/network_debug/main.lua b/nerv/examples/network_debug/main.lua index 1bee43c..bbcdb6c 100644 --- a/nerv/examples/network_debug/main.lua +++ b/nerv/examples/network_debug/main.lua @@ -20,12 +20,12 @@ for i = 1, global_conf.max_iter do local train_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.train.txt.adds') local val_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.valid.txt.adds') local train_ppl, val_ppl = NN:epoch(train_reader, val_reader) + nerv.printf('Epoch %d: %f %f %f\n', i, global_conf.lrate, train_ppl, val_ppl) if val_ppl < best_cv then best_cv = val_ppl else global_conf.lrate = global_conf.lrate / 2.0 end - nerv.printf('Epoch %d: %f %f %f\n', i, global_conf.lrate, train_ppl, val_ppl) timer:toc('Epoch' .. i) timer:check('Epoch' .. i) io.flush() diff --git a/nerv/examples/network_debug/reader.lua b/nerv/examples/network_debug/reader.lua index 70c0c97..76a78cf 100644 --- a/nerv/examples/network_debug/reader.lua +++ b/nerv/examples/network_debug/reader.lua @@ -32,8 +32,8 @@ end function Reader:get_seq(input_file) local f = io.open(input_file, 'r') self.seq = {} - -- while true do - for i = 1, 26 do + while true do + -- for i = 1, 26 do local seq = f:read() if seq == nil then break diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua index 2cb83ce..910cdad 100644 --- a/nerv/nn/network.lua +++ b/nerv/nn/network.lua @@ -16,6 +16,7 @@ function network:__init(id, global_conf, network_conf) if self.nn_act_default == nil then self.nn_act_default = 0 end + self.layers = {} self.input_conn = {} self.output_conn = {} @@ -36,16 +37,41 @@ function network:__init(id, global_conf, network_conf) end self.output_conn[id][port] = {0, i, time} end + self.delay = 0 for i = 1, #self.layers do local dim_in, _ = self.layers[i]:get_dim() for j = 1, #dim_in do + if self.input_conn[i][j] == nil then + nerv.error('dangling input') + end local time = self.input_conn[i][j][3] if math.abs(time) > self.delay then self.delay = math.abs(time) end end end + + self.input_edge = {} + self.output_edge = {} + for t = -self.delay, self.delay do + self.input_edge[t] = {} + self.output_edge[t] = {} + end + for i = 1, #self.layers do + local dim_in, dim_out = self.layers[i]:get_dim() + for j = 1, #dim_in do + local time = self.input_conn[i][j][3] + table.insert(self.input_edge[time], {i, j}) + end + for j = 1, #dim_out do + if self.output_conn[i][j] == nil then + nerv.error('dangling output') + end + local time = self.output_conn[i][j][3] + table.insert(self.output_edge[time], {i, j}) + end + end end function network:compile(layer) @@ -112,11 +138,20 @@ function network:init(batch_size, chunk_size) self:make_initial_store() collectgarbage('collect') + + self.flush = {} + for t = 1, self.chunk_size do + self.flush[t] = {} + end end function network:epoch_init() + self.timestamp = 0 for i = 1, #self.layers do self.layers[i]:init(self.batch_size, self.chunk_size) + for t = 1, self.chunk_size do + self.flush[t][i] = {timestamp = 0, input = {}, output = {}} + end end end @@ -134,12 +169,10 @@ function network:topsort() for i = 1, #self.layers do local _, dim_out = self.layers[i]:get_dim() for j = 1, #dim_out do - if self.output_conn[i][j] ~= nil then - local edge = self.output_conn[i][j] - local id, time = edge[1], edge[3] + t - if time >= 1 and time <= self.chunk_size and id ~= 0 then - degree[time][id] = degree[time][id] + 1 - end + local edge = self.output_conn[i][j] + local id, time = edge[1], edge[3] + t + if time >= 1 and time <= self.chunk_size and id ~= 0 then + degree[time][id] = degree[time][id] + 1 end end end @@ -161,15 +194,13 @@ function network:topsort() l = l + 1 local _, dim_out = self.layers[i]:get_dim() for j = 1, #dim_out do - if self.output_conn[i][j] ~= nil then - local edge = self.output_conn[i][j] - local id, time = edge[1], edge[3] + t - if time >= 1 and time <= self.chunk_size and id ~= 0 then - degree[time][id] = degree[time][id] - 1 - if degree[time][id] == 0 then - r = r + 1 - self.queue[r] = {chunk = time, id = id} - end + local edge = self.output_conn[i][j] + local id, time = edge[1], edge[3] + t + if time >= 1 and time <= self.chunk_size and id ~= 0 then + degree[time][id] = degree[time][id] - 1 + if degree[time][id] == 0 then + r = r + 1 + self.queue[r] = {chunk = time, id = id} end end end @@ -202,17 +233,19 @@ function network:make_initial_store() memory[t][i][j]:fill(self.nn_act_default) end end - -- memory[t][0] stores network input - memory[t][0] = {} - for j = 1, #self.dim_in do - memory[t][0][j] = self.mat_type(self.batch_size, self.dim_in[j]) - memory[t][0][j]:fill(self.nn_act_default) - end - -- err_memory[t][0] stores network err_input - err_memory[t][0] = {} - for j = 1, #self.dim_out do - err_memory[t][0][j] = self.mat_type(self.batch_size, self.dim_out[j]) - err_memory[t][0][j]:fill(0) + if t < 1 or t > self.chunk_size then + -- memory[t][0] stores network input + memory[t][0] = {} + for j = 1, #self.dim_in do + memory[t][0][j] = self.mat_type(self.batch_size, self.dim_in[j]) + memory[t][0][j]:fill(self.nn_act_default) + end + -- err_memory[t][0] stores network err_input + err_memory[t][0] = {} + for j = 1, #self.dim_out do + err_memory[t][0][j] = self.mat_type(self.batch_size, self.dim_out[j]) + err_memory[t][0][j]:fill(0) + end end end @@ -314,9 +347,14 @@ function network:make_initial_store() self.legacy[t] = {} for i = 1, #self.layers do self.legacy[t][i] = {} - local _, dim_out = self.layers[i]:get_dim() - for j = 1, #dim_out do - self.legacy[t][i][j] = memory[t][i][j] + end + end + for d = 1, self.delay do + for t = 1 - d, 0 do + for i = 1, #self.output_edge[d] do + local edge = self.output_edge[d][i] + local id, port = edge[1], edge[2] + self.legacy[t][id][port] = memory[t][id][port] end end end @@ -383,59 +421,74 @@ function network:mini_batch_init(info) self.info = info self:set_input(self.info.input) self:set_output(self.info.output) + if self.info.do_train then + self:set_err_input(self.info.err_input) + self:set_err_output(self.info.err_output) + end -- calculate border self.max_length = 0 - self.border = {} - for i = 1, self.chunk_size do - self.border[i] = {} - end + self.timestamp = self.timestamp + 1 for i = 1, self.batch_size do if self.info.seq_length[i] > self.max_length then self.max_length = self.info.seq_length[i] end - for t = 1, self.delay do - local chunk = self.info.seq_length[i] + t - if chunk > self.chunk_size then - break + local border = self.info.seq_length[i] + for d = 1, self.delay do + for t = border + 1, border + d do + if t > self.chunk_size then + break + end + for j = 1, #self.output_edge[-d] do + local edge = self.output_edge[-d][j] + local id, port = edge[1], edge[2] + local flush = self.flush[t][id] + if flush.timestamp ~= self.timestamp then + flush.timestamp = self.timestamp + flush.input = {} + flush.output = {} + end + table.insert(flush.output, {port, i}) + end + end + if self.info.do_train then + for t = border, border - d + 1, -1 do + if t < 1 then + break + end + for j = 1, #self.input_edge[-d] do + local edge = self.input_edge[-d][j] + local id, port = edge[1], edge[2] + local flush = self.flush[t][id] + if flush.timestamp ~= self.timestamp then + flush.timestamp = self.timestamp + flush.input = {} + flush.output = {} + end + table.insert(flush.input, {port, i}) + end + end end - table.insert(self.border[chunk], i) end end -- copy legacy - for t = 1 - self.delay, 0 do - for i = 1, #self.layers do - local _, dim_out = self.layers[i]:get_dim() - for j = 1, #dim_out do - if t + self.chunk_size >= 1 and self.output_conn[i][j][1] ~= 0 then - self.legacy[t][i][j]:copy_from(self.output[t + self.chunk_size][i][j]) + for d = 1, self.delay do + for t = 1 - d, 0 do + for i = 1, #self.output_edge[d] do + local edge = self.output_edge[d][i] + local id, port = edge[1], edge[2] + if t + self.chunk_size >= 1 and self.output_conn[id][port][1] ~= 0 then + self.legacy[t][id][port]:copy_from(self.output[t + self.chunk_size][id][port]) end - for k = 1, #self.info.new_seq do - local batch = self.info.new_seq[k] - self.legacy[t][i][j][batch - 1]:fill(self.nn_act_default) + for j = 1, #self.info.new_seq do + local batch = self.info.new_seq[j] + self.legacy[t][id][port][batch - 1]:fill(self.nn_act_default) end end end end - if self.info.do_train then - self:set_err_input(self.info.err_input) - self:set_err_output(self.info.err_output) - - -- flush border gradient - for t = self.max_length + 1, self.max_length + self.delay do - if t > self.chunk_size then - break - end - for i = 1, #self.layers do - local dim_in, _ = self.layers[i]:get_dim() - for j = 1, #dim_in do - self.err_output[t][i][j]:fill(0) - end - end - end - end end function network:propagate() @@ -445,11 +498,11 @@ function network:propagate() self.layers[id]:propagate(self.input[t][id], self.output[t][id], t) end -- flush border activation - for j = 1, #self.border[t] do - local batch = self.border[t][j] - local _, dim_out = self.layers[id]:get_dim() - for k = 1, #dim_out do - self.output[t][id][k][batch - 1]:fill(self.nn_act_default) + if self.flush[t][id].timestamp == self.timestamp then + for j = 1, #self.flush[t][id].output do + local border = self.flush[t][id].output[j] + local port, batch = border[1], border[2] + self.output[t][id][port][batch - 1]:fill(self.nn_act_default) end end end @@ -459,15 +512,8 @@ function network:back_propagate() for i = #self.queue, 1, -1 do local t, id = self.queue[i].chunk, self.queue[i].id if t <= self.max_length then - -- flush border gradient - for j = 1, #self.border[t] do - local batch = self.border[t][j] - local _, dim_out = self.layers[id]:get_dim() - for k = 1, #dim_out do - self.err_input[t][id][k][batch - 1]:fill(0) - end - end self.layers[id]:back_propagate(self.err_input[t][id], self.err_output[t][id], self.input[t][id], self.output[t][id], t) + -- gradient clip if self.clip ~= nil then local dim_in, _ = self.layers[id]:get_dim() for j = 1, #dim_in do @@ -475,14 +521,21 @@ function network:back_propagate() end end end + -- flush border gradient + if self.flush[t][id].timestamp == self.timestamp then + for j = 1, #self.flush[t][id].input do + local border = self.flush[t][id].input[j] + local port, batch = border[1], border[2] + self.err_output[t][id][port][batch - 1]:fill(0) + end + end end end function network:update() - for i = 1, #self.queue do - local t, id = self.queue[i].chunk, self.queue[i].id - if t <= self.max_length then - self.layers[id]:update(self.err_input[t][id], self.input[t][id], self.output[t][id], t) + for t = 1, self.max_length do + for i = 1, #self.layers do + self.layers[i]:update(self.err_input[t][i], self.input[t][i], self.output[t][i], t) end end end |