nerv.include('select_linear.lua')

--- Pairs a nerv.Network with pre-packed training and validation mini-batches
-- and exposes a one-call-per-epoch training loop.
local nn = nerv.class('nerv.NN')

--- Build the network and pack both data sets.
-- @param global_conf shared configuration table; kept as `self.gconf` and
--        mutated later (`dropout_rate` is rewritten by this class)
-- @param train_data list of raw training mini-batches (see `nn:get_data`)
-- @param val_data list of raw validation mini-batches (see `nn:get_data`)
-- @param layers layer specification handed to `nerv.LayerRepo`
-- @param connections connection specification handed to `nerv.GraphLayer`
function nn:__init(global_conf, train_data, val_data, layers, connections)
    self.gconf = global_conf
    self.network = self:get_network(layers, connections)
    self.train_data = self:get_data(train_data)
    self.val_data = self:get_data(val_data)
end

--- Assemble the layer graph and wrap it in an initialised nerv.Network.
-- Sets `gconf.dropout_rate` to 0 before any layer is constructed.
-- @param layers layer specification handed to `nerv.LayerRepo`
-- @param connections connection specification handed to `nerv.GraphLayer`
-- @return the network, already initialised with `gconf.batch_size` and
--         `gconf.chunk_size`
function nn:get_network(layers, connections)
    local gconf = self.gconf
    gconf.dropout_rate = 0

    local layer_repo = nerv.LayerRepo(layers, gconf.pr, gconf)
    local graph = nerv.GraphLayer('graph', gconf, {
        dim_in = {1, gconf.vocab_size},
        dim_out = {1},
        layer_repo = layer_repo,
        connections = connections,
    })
    local network = nerv.Network('network', gconf, {
        network = graph,
        clip = gconf.clip,
    })
    network:init(gconf.batch_size, gconf.chunk_size)
    return network
end

--- Convert raw mini-batches into the table layout given to
-- `network:mini_batch_init`.
--
-- Each returned entry holds, for every time step `t` in 1..chunk_size:
--   * `input[t]      = {data[i].input[t], data[i].output[t]}`
--   * `output[t]     = {batch_size x 1 device matrix}`
--   * `err_input[t]  = {batch_size x 1 device matrix}`, 1 for rows whose
--                      sequence is still running at `t`, 0 otherwise
--   * `err_output[t] = {batch_size x 1, batch_size x vocab_size}` device
--                      matrices
-- plus `seq_length` (the raw `seq_len` table) and `new_seq` (row indices `j`
-- for which `seq_start[j]` is truthy).
--
-- The `output` and `err_output` matrices are allocated once per time step and
-- the same objects are referenced by every mini-batch.
-- @param data list of raw mini-batches with fields `input`, `output`,
--        `seq_len` and `seq_start`
-- @return list of packed mini-batches, same length as `data`
function nn:get_data(data)
    local gconf = self.gconf

    -- Per-time-step device buffers shared across all mini-batches.
    local err_output = {}
    local softmax_output = {}
    local output = {}
    for t = 1, gconf.chunk_size do
        err_output[t] = gconf.cumat_type(gconf.batch_size, 1)
        softmax_output[t] = gconf.cumat_type(gconf.batch_size, gconf.vocab_size)
        output[t] = gconf.cumat_type(gconf.batch_size, 1)
    end

    local ret = {}
    for i = 1, #data do
        local raw = data[i]
        local batch = {input = {}, output = {}, err_input = {}, err_output = {}}
        ret[i] = batch

        for t = 1, gconf.chunk_size do
            batch.input[t] = {raw.input[t], raw.output[t]}
            batch.output[t] = {output[t]}

            -- Host matrices are indexed from 0, batch rows from 1.
            local mask = gconf.mmat_type(gconf.batch_size, 1)
            for j = 1, gconf.batch_size do
                mask[j - 1][0] = (t <= raw.seq_len[j]) and 1 or 0
            end
            local mask_dev = gconf.cumat_type.new_from_host(mask)
            batch.err_input[t] = {mask_dev}

            batch.err_output[t] = {err_output[t], softmax_output[t]}
        end

        batch.seq_length = raw.seq_len

        local new_seq = {}
        for j = 1, gconf.batch_size do
            if raw.seq_start[j] then
                table.insert(new_seq, j)
            end
        end
        batch.new_seq = new_seq
    end
    return ret
end

--- Run the network over every mini-batch once, optionally training.
--
-- For each mini-batch this sets `gconf.dropout_rate` (to `gconf.dropout` when
-- training, 0 otherwise) and `data[id].do_train`, propagates, and sums the
-- values read from `output[t][1]` for every (t, row) with
-- `t <= seq_length[row]`, converted from natural-log to base-10 units.
-- When `do_train` is truthy it then back-propagates and updates.
-- A full garbage collection is forced after each mini-batch.
--
-- NOTE(review): the conversion assumes the network output holds natural-log
-- probabilities, which would make the result a perplexity — confirm against
-- the output layer.
--
-- @param data list of packed mini-batches as produced by `nn:get_data`
-- @param do_train truthy to enable dropout, back-propagation and update
-- @return `10 ^ (-sum / count)` over all counted frames; NaN when no frame
--         was counted (0/0)
function nn:process(data, do_train)
    local gconf = self.gconf
    local timer = gconf.timer
    -- log10(exp(x)) == x / ln(10). Dividing directly avoids the underflow of
    -- exp(x) to 0 (and hence -inf) for very negative x, and does not rely on
    -- math.log10, which is deprecated after Lua 5.1.
    local ln10 = math.log(10)
    local total_err = 0
    local total_frame = 0

    for id = 1, #data do
        local batch = data[id]
        if do_train then
            gconf.dropout_rate = gconf.dropout
            batch.do_train = true
        else
            gconf.dropout_rate = 0
            batch.do_train = false
        end

        timer:tic('network')
        self.network:mini_batch_init(batch)
        self.network:propagate()
        timer:toc('network')

        for t = 1, gconf.chunk_size do
            -- Host matrices are indexed from 0, batch rows from 1.
            local host_out = batch.output[t][1]:new_to_host()
            for row = 1, gconf.batch_size do
                if t <= batch.seq_length[row] then
                    total_err = total_err + host_out[row - 1][0] / ln10
                    total_frame = total_frame + 1
                end
            end
        end

        if do_train then
            timer:tic('network')
            self.network:back_propagate()
            self.network:update()
            timer:toc('network')
        end

        timer:tic('gc')
        collectgarbage('collect')
        timer:toc('gc')
    end

    -- `^` instead of math.pow: same result, and available in every Lua version.
    return 10 ^ (-total_err / total_frame)
end

--- Run one training pass followed by one validation pass.
-- @return the `nn:process` result on the training data, then the result on
--         the validation data
function nn:epoch()
    local train_error = self:process(self.train_data, true)
    local val_error = self:process(self.val_data, false)
    return train_error, val_error
end