diff options
Diffstat (limited to 'nerv/examples/network_debug/network.lua')
-rw-r--r-- | nerv/examples/network_debug/network.lua | 120 |
1 files changed, 57 insertions, 63 deletions
diff --git a/nerv/examples/network_debug/network.lua b/nerv/examples/network_debug/network.lua index 5518e27..386c3b0 100644 --- a/nerv/examples/network_debug/network.lua +++ b/nerv/examples/network_debug/network.lua @@ -2,11 +2,17 @@ nerv.include('select_linear.lua') local nn = nerv.class('nerv.NN') -function nn:__init(global_conf, train_data, val_data, layers, connections) +function nn:__init(global_conf, layers, connections) self.gconf = global_conf self.network = self:get_network(layers, connections) - self.train_data = self:get_data(train_data) - self.val_data = self:get_data(val_data) + + self.output = {} + self.err_output = {} + for i = 1, self.gconf.chunk_size do + self.output[i] = {self.gconf.cumat_type(self.gconf.batch_size, 1)} + self.err_output[i] = {self.gconf.cumat_type(self.gconf.batch_size, 1)} + self.err_output[i][2] = self.gconf.cumat_type(self.gconf.batch_size, 1) + end end function nn:get_network(layers, connections) @@ -20,79 +26,67 @@ function nn:get_network(layers, connections) return network end -function nn:get_data(data) - local err_output = {} - local softmax_output = {} - local output = {} - for i = 1, self.gconf.chunk_size do - err_output[i] = self.gconf.cumat_type(self.gconf.batch_size, 1) - softmax_output[i] = self.gconf.cumat_type(self.gconf.batch_size, self.gconf.vocab_size) - output[i] = self.gconf.cumat_type(self.gconf.batch_size, 1) - end - local ret = {} - for i = 1, #data do - ret[i] = {} - ret[i].input = {} - ret[i].output = {} - ret[i].err_input = {} - ret[i].err_output = {} - for t = 1, self.gconf.chunk_size do - ret[i].input[t] = {} - ret[i].output[t] = {} - ret[i].err_input[t] = {} - ret[i].err_output[t] = {} - ret[i].input[t][1] = data[i].input[t] - ret[i].input[t][2] = data[i].output[t] - ret[i].output[t][1] = output[t] - local err_input = self.gconf.mmat_type(self.gconf.batch_size, 1) - for j = 1, self.gconf.batch_size do - if t <= data[i].seq_len[j] then - err_input[j - 1][0] = 1 - else - err_input[j - 1][0] = 0 +function nn:process(data, do_train, reader) + local timer = self.gconf.timer + local buffer = nerv.SeqBuffer(self.gconf, { + batch_size = self.gconf.batch_size, chunk_size = self.gconf.chunk_size, + readers = {reader}, + }) + local total_err = 0 + local total_frame = 0 + self.network:epoch_init() + while true do + timer:tic('IO') + data = buffer:get_data() + if data == nil then + break + end + local err_input = {} + if do_train then + for t = 1, self.gconf.chunk_size do + local tmp = self.gconf.mmat_type(self.gconf.batch_size, 1) + for i = 1, self.gconf.batch_size do + if t <= data.seq_length[i] then + tmp[i - 1][0] = 1 + else + tmp[i - 1][0] = 0 + end end + err_input[t] = {self.gconf.cumat_type.new_from_host(tmp)} end - ret[i].err_input[t][1] = self.gconf.cumat_type.new_from_host(err_input) - ret[i].err_output[t][1] = err_output[t] - ret[i].err_output[t][2] = softmax_output[t] end - ret[i].seq_length = data[i].seq_len - ret[i].new_seq = {} - for j = 1, self.gconf.batch_size do - if data[i].seq_start[j] then - table.insert(ret[i].new_seq, j) - end + local info = {input = {}, output = self.output, err_input = err_input, do_train = do_train, + err_output = self.err_output, seq_length = data.seq_length, new_seq = data.new_seq} + for t = 1, self.gconf.chunk_size do + info.input[t] = {data.data['input'][t]} + info.input[t][2] = data.data['label'][t] end - end - return ret -end + timer:toc('IO') -function nn:process(data, do_train) - local timer = self.gconf.timer - local total_err = 0 - local total_frame = 0 - self.network:epoch_init() - for id = 1, #data do - data[id].do_train = do_train timer:tic('network') - self.network:mini_batch_init(data[id]) + self.network:mini_batch_init(info) self.network:propagate() timer:toc('network') + + timer:tic('IO') for t = 1, self.gconf.chunk_size do - local tmp = data[id].output[t][1]:new_to_host() + local tmp = info.output[t][1]:new_to_host() for i = 1, self.gconf.batch_size do - if t <= data[id].seq_length[i] then - total_err = total_err + math.log10(math.exp(tmp[i - 1][0])) - total_frame = total_frame + 1 - end + total_err = total_err + math.log10(math.exp(tmp[i - 1][0])) end end + for i = 1, self.gconf.batch_size do + total_frame = total_frame + info.seq_length[i] + end + timer:toc('IO') + + timer:tic('network') if do_train then - timer:tic('network') self.network:back_propagate() self.network:update() - timer:toc('network') end + timer:toc('network') + timer:tic('gc') collectgarbage('collect') timer:toc('gc') @@ -100,11 +94,11 @@ function nn:process(data, do_train) return math.pow(10, - total_err / total_frame) end -function nn:epoch() - local train_error = self:process(self.train_data, true) +function nn:epoch(train_reader, val_reader) + local train_error = self:process(self.train_data, true, train_reader) local tmp = self.gconf.dropout_rate self.gconf.dropout_rate = 0 - local val_error = self:process(self.val_data, false) + local val_error = self:process(self.val_data, false, val_reader) self.gconf.dropout_rate = tmp return train_error, val_error end |