nerv.include('select_linear.lua')

-- log10(exp(x)) == x / ln(10). Dividing directly avoids overflow: for large
-- |x|, math.exp(x) saturates to inf and the old log10(exp(x)) form returned
-- inf, poisoning the accumulated error. Hoisted out of the hot loop.
local LN10 = math.log(10)

-- Language-model wrapper around a nerv.Network: builds the compute graph
-- once, then runs training / evaluation passes that return perplexity.
local nn = nerv.class('nerv.NN')

--- Build the network and pre-allocate per-timestep output buffers.
-- @param global_conf table; fields used here: batch_size, chunk_size,
--        vocab_size, cumat_type, mmat_type, pr, clip, timer, dropout_rate
-- @param layers      layer declarations handed to nerv.LayerRepo
-- @param connections connection list for the graph layer
function nn:__init(global_conf, layers, connections)
    self.gconf = global_conf
    self.network = self:get_network(layers, connections)
    -- One slot per timestep. err_output carries two matrices because the
    -- graph has two inputs (word id and label) and each needs an error sink.
    self.output = {}
    self.err_output = {}
    for t = 1, self.gconf.chunk_size do
        self.output[t] = {self.gconf.cumat_type(self.gconf.batch_size, 1)}
        self.err_output[t] = {
            self.gconf.cumat_type(self.gconf.batch_size, 1),
            self.gconf.cumat_type(self.gconf.batch_size, 1),
        }
    end
end

--- Instantiate the layer repository and graph, wrap them in a nerv.Network,
-- and initialize it for the configured batch/chunk geometry.
-- @return an initialized nerv.Network
function nn:get_network(layers, connections)
    local layer_repo = nerv.LayerRepo(layers, self.gconf.pr, self.gconf)
    local graph = nerv.GraphLayer('graph', self.gconf, {
        dim_in = {1, self.gconf.vocab_size},
        dim_out = {1},
        layer_repo = layer_repo,
        connections = connections,
    })
    local network = nerv.Network('network', self.gconf, {
        network = graph,
        clip = self.gconf.clip,
    })
    network:init(self.gconf.batch_size, self.gconf.chunk_size)
    return network
end

--- Run one full pass over the sequences served by `reader`.
-- @param data     legacy argument; overwritten by buffer:get_data() before
--                 any use, so its value is irrelevant (callers pass nil-able
--                 self.train_data / self.val_data)
-- @param do_train when true, also back-propagate and update the weights
-- @param reader   sequence reader, consumed through a nerv.SeqBuffer
-- @return perplexity over all frames processed (10 ^ (-avg log10 prob))
function nn:process(data, do_train, reader)
    local timer = self.gconf.timer
    local buffer = nerv.SeqBuffer(self.gconf, {
        batch_size = self.gconf.batch_size,
        chunk_size = self.gconf.chunk_size,
        readers = {reader},
    })
    local total_err = 0
    local total_frame = 0
    self.network:epoch_init()
    while true do
        timer:tic('IO')
        data = buffer:get_data()
        if data == nil then
            break
        end
        local err_input = {}
        if do_train then
            -- Build a 0/1 mask per timestep: gradient flows only for frames
            -- inside each sequence's actual length. mmat indexing is 0-based.
            for t = 1, self.gconf.chunk_size do
                local mask = self.gconf.mmat_type(self.gconf.batch_size, 1)
                for i = 1, self.gconf.batch_size do
                    if t <= data.seq_length[i] then
                        mask[i - 1][0] = 1
                    else
                        mask[i - 1][0] = 0
                    end
                end
                err_input[t] = {self.gconf.cumat_type.new_from_host(mask)}
            end
        end
        local info = {
            input = {},
            output = self.output,
            err_input = err_input,
            do_train = do_train,
            err_output = self.err_output,
            seq_length = data.seq_length,
            new_seq = data.new_seq,
        }
        for t = 1, self.gconf.chunk_size do
            info.input[t] = {data.data['input'][t], data.data['label'][t]}
        end
        timer:toc('IO')

        timer:tic('network')
        self.network:mini_batch_init(info)
        self.network:propagate()
        timer:toc('network')

        timer:tic('IO')
        for t = 1, self.gconf.chunk_size do
            local out = info.output[t][1]:new_to_host()
            for i = 1, self.gconf.batch_size do
                -- out holds natural-log probabilities (the original code
                -- unwrapped them with exp before log10); convert the base
                -- directly instead of round-tripping through exp, which
                -- overflowed to inf for large magnitudes.
                total_err = total_err + out[i - 1][0] / LN10
            end
        end
        for i = 1, self.gconf.batch_size do
            total_frame = total_frame + info.seq_length[i]
        end
        timer:toc('IO')

        timer:tic('network')
        if do_train then
            self.network:back_propagate()
            self.network:update()
        end
        timer:toc('network')

        timer:tic('gc')
        collectgarbage('collect')
        timer:toc('gc')
    end
    return math.pow(10, - total_err / total_frame)
end

--- Run one training pass followed by a validation pass with dropout
-- disabled, restoring the configured dropout rate afterwards.
-- NOTE(review): self.train_data / self.val_data are never assigned in this
-- file; process() ignores its first argument, so nil is harmless here —
-- confirm no subclass relies on them.
-- @return train perplexity, validation perplexity
function nn:epoch(train_reader, val_reader)
    local train_error = self:process(self.train_data, true, train_reader)
    local saved_dropout = self.gconf.dropout_rate
    self.gconf.dropout_rate = 0
    local val_error = self:process(self.val_data, false, val_reader)
    self.gconf.dropout_rate = saved_dropout
    return train_error, val_error
end