nerv.include('select_linear.lua')

-- Reader that hands pre-packed data chunks to a TNN, one chunk per call.
local reader = nerv.class('nerv.TNNReader')

function reader:__init(global_conf, data)
    self.gconf = global_conf
    self.offset = 0     -- index of the chunk served most recently
    self.data = data
end

-- Copy the next chunk into the network's feed matrices.
-- Returns false once all chunks have been consumed.
function reader:get_batch(feeds)
    self.offset = self.offset + 1
    if self.offset > #self.data then
        return false
    end
    local chunk = self.data[self.offset]
    for t = 1, self.gconf.chunk_size do
        -- input 1: word ids; input 2: one-hot targets decompressed to vocab size
        feeds.inputs_m[t][1]:copy_from(chunk.input[t])
        feeds.inputs_m[t][2]:copy_from(chunk.output[t]:decompress(self.gconf.vocab_size))
    end
    feeds.flags_now = chunk.flags
    feeds.flagsPack_now = chunk.flagsPack
    return true
end

-- True while timestep t is still inside sequence i of the current chunk.
function reader:has_data(t, i)
    return t <= self.data[self.offset].seq_len[i]
end

-- Per-timestep error-input matrices precomputed by nn:get_data().
function reader:get_err_input()
    return self.data[self.offset].err_input
end

-- Thin training driver: builds the TNN and preprocesses train/val data.
local nn = nerv.class('nerv.NN')

function nn:__init(global_conf, train_data, val_data, layers, connections)
    self.gconf = global_conf
    self.tnn = self:get_tnn(layers, connections)
    self.train_data = self:get_data(train_data)
    self.val_data = self:get_data(val_data)
end

-- Construct and initialize the TNN from layer specs and connections.
function nn:get_tnn(layers, connections)
    self.gconf.dropout_rate = 0
    local layer_repo = nerv.LayerRepo(layers, self.gconf.pr, self.gconf)
    local tnn = nerv.TNN('TNN', self.gconf, {
        dim_in = {1, self.gconf.vocab_size},
        dim_out = {1},
        sub_layers = layer_repo,
        connections = connections,
        clip = self.gconf.clip,
    })
    tnn:init(self.gconf.batch_size, self.gconf.chunk_size)
    return tnn
end

-- Precompute, per chunk: per-timestep sequence flags, their packed OR,
-- and device-side error-input masks (1 inside a sequence, 0 in padding).
function nn:get_data(data)
    local gconf = self.gconf
    local ret = {}
    for i = 1, #data do
        local src = data[i]
        local entry = {}
        ret[i] = entry
        entry.input = src.input
        entry.output = src.output
        entry.flags = {}
        entry.err_input = {}
        for t = 1, gconf.chunk_size do
            entry.flags[t] = {}
            -- host-side column vector: 1 where (t, j) is real data, else 0
            local mask = gconf.mmat_type(gconf.batch_size, 1)
            for j = 1, gconf.batch_size do
                if t <= src.seq_len[j] then
                    entry.flags[t][j] = nerv.TNN.FC.SEQ_NORM
                    mask[j - 1][0] = 1
                else
                    entry.flags[t][j] = 0
                    mask[j - 1][0] = 0
                end
            end
            entry.err_input[t] = gconf.cumat_type.new_from_host(mask)
        end
        -- mark sequence boundaries: start flag at t = 1, end flag at seq_len
        for j = 1, gconf.batch_size do
            if src.seq_start[j] then
                entry.flags[1][j] = bit.bor(entry.flags[1][j], nerv.TNN.FC.SEQ_START)
            end
            if src.seq_end[j] then
                local last = src.seq_len[j]
                entry.flags[last][j] = bit.bor(entry.flags[last][j], nerv.TNN.FC.SEQ_END)
            end
        end
        -- flagsPack[t] is the OR of all per-stream flags at timestep t
        entry.flagsPack = {}
        for t = 1, gconf.chunk_size do
            local packed = 0
            for j = 1, gconf.batch_size do
                packed = bit.bor(packed, entry.flags[t][j])
            end
            entry.flagsPack[t] = packed
        end
        entry.seq_len = src.seq_len
    end
    return ret
end

-- One pass over `data`; backpropagates and updates only when do_train is set.
-- Returns 10^(-avg log10 likelihood) over real frames (perplexity-style
-- score; assumes the net's output is a natural-log likelihood — the
-- log10(exp(...)) below converts it to base 10).
function nn:process(data, do_train)
    local total_err = 0
    local total_frame = 0
    local reader = nerv.TNNReader(self.gconf, data)
    while true do
        local ok, _ = self.tnn:getfeed_from_reader(reader)
        if not ok then
            break
        end
        -- dropout is active during training only
        if do_train then
            self.gconf.dropout_rate = self.gconf.dropout
        else
            self.gconf.dropout_rate = 0
        end
        self.tnn:net_propagate()
        for t = 1, self.gconf.chunk_size do
            local out = self.tnn.outputs_m[t][1]:new_to_host()
            for i = 1, self.gconf.batch_size do
                -- count only frames that belong to a real sequence
                if reader:has_data(t, i) then
                    total_err = total_err + math.log10(math.exp(out[i - 1][0]))
                    total_frame = total_frame + 1
                end
            end
        end
        if do_train then
            local err_input = reader:get_err_input()
            for t = 1, self.gconf.chunk_size do
                self.tnn.err_inputs_m[t][1]:copy_from(err_input[t])
            end
            self.tnn:net_backpropagate(false)   -- compute gradients
            self.tnn:net_backpropagate(true)    -- apply updates
        end
        collectgarbage('collect')
    end
    return math.pow(10, - total_err / total_frame)
end

-- Run one training epoch followed by a validation pass.
function nn:epoch()
    local train_error = self:process(self.train_data, true)
    local val_error = self:process(self.val_data, false)
    return train_error, val_error
end