--- Sampler that draws whole sentences from a trained RNN language model.
-- Sampling runs in batches: each of the `batch_size` lanes independently
-- grows one sentence token-by-token until it emits the sentence-end token,
-- at which point the finished sentence is moved into a repository.
local LMSampler = nerv.class('nerv.LMSampler')

--- Constructor.
-- @param global_conf table with batch_size, chunk_size, vocab,
--        cumat_type, mmat_type, hidden_size (the latter three are used
--        lazily in load_dagL).
function LMSampler:__init(global_conf)
    self.log_pre = "LMSampler"
    self.gconf = global_conf
    self.batch_size = self.gconf.batch_size
    self.chunk_size = self.gconf.chunk_size --largest sample sentence length
    self.vocab = self.gconf.vocab
    self.sen_end_token = self.vocab.sen_end_token
    self.sen_end_id = self.vocab:get_word_str(self.sen_end_token).id
end

--- Attach a propagation DAG and allocate all device/host buffers.
-- @param dagL a nerv DAG layer; will be (re)initialized for batch_size.
function LMSampler:load_dagL(dagL)
    nerv.printf("%s loading dagL\n", self.log_pre)
    -- BUGFIX: the original read the bare global `global_conf` here instead
    -- of the config captured in __init; go through self.gconf so the
    -- sampler does not depend on an accidental global being in scope.
    local gconf = self.gconf
    self.dagL = dagL
    self.dagL:init(self.batch_size)

    -- inputs[1]: previous word id per lane (0-based on device);
    -- inputs[2]: recurrent hidden state, zeroed at sentence start.
    self.dagL_inputs = {}
    self.dagL_inputs[1] = gconf.cumat_type(gconf.batch_size, 1)
    self.dagL_inputs[1]:fill(self.sen_end_id - 1)
    self.dagL_inputs[2] = gconf.cumat_type(gconf.batch_size, gconf.hidden_size)
    self.dagL_inputs[2]:fill(0)

    -- outputs[1]: pre-softmax scores over the vocabulary;
    -- outputs[2]: next hidden state (copied back into inputs[2] each step).
    self.dagL_outputs = {}
    self.dagL_outputs[1] = gconf.cumat_type(gconf.batch_size, gconf.vocab:size())
    self.dagL_outputs[2] = gconf.cumat_type(gconf.batch_size, gconf.hidden_size)

    -- smout_d: softmax probabilities (device);
    -- ssout_d: row-wise prefix sums, i.e. the CDF (device);
    -- ssout_h: host copy of the CDF that sampling reads.
    self.smout_d = gconf.cumat_type(self.batch_size, self.vocab:size())
    self.ssout_d = gconf.cumat_type(self.batch_size, self.vocab:size())
    self.ssout_h = gconf.mmat_type(self.batch_size, self.vocab:size())

    -- store[i]: the sentence lane i is currently building, seeded with the
    -- sentence-end token (which doubles as the sentence-begin marker).
    self.store = {}
    for i = 1, self.batch_size do
        self.store[i] = {}
        self.store[i][1] = {}
        self.store[i][1].w = self.sen_end_token
        self.store[i][1].id = self.sen_end_id
        self.store[i][1].p = 0
    end
    -- repo: finished sentences waiting to be handed out.
    self.repo = {}
end

--- Draw one word per batch lane from the host-side CDF and append it to
-- each lane's sentence under construction.
-- @param ssout host matrix; row i-1 holds the inclusive prefix sums of the
--        softmax output for lane i (so the last column should be ~1.0).
function LMSampler:sample_to_store(ssout)
    for i = 1, self.batch_size do
        local ran = math.random()
        local id = 1
        local low = 0
        local high = ssout:ncol() - 1
        -- Sanity check: the CDF must end at ~1; a bad sum means the
        -- softmax/prefix-sum pipeline upstream is broken.
        if ssout[i - 1][high] < 0.9999 or ssout[i - 1][high] > 1.0001 then
            nerv.error("%s ERROR, softmax output summation(%f) seems to have some problem", self.log_pre, ssout[i - 1][high])
        end
        if ssout[i - 1][low] < ran then
            -- Binary search for the first column whose CDF value >= ran;
            -- invariant: ssout[i-1][low] < ran <= ssout[i-1][high].
            while low + 1 < high do
                local mid = math.floor((low + high) / 2)
                if ssout[i - 1][mid] < ran then
                    low = mid
                else
                    high = mid
                end
            end
            id = high + 1 -- word ids are 1-based, columns 0-based
        end
        --[[ equivalent linear scan kept for reference:
        local s = 0
        local id = self.vocab:size()
        for j = 0, self.vocab:size() - 1 do
            s = s + smout[i - 1][j]
            if s >= ran then id = j + 1 break end
        end
        ]]--
        -- Force sentence end when the lane is about to exceed chunk_size.
        if #self.store[i] >= self.chunk_size - 2 then
            id = self.sen_end_id
        end
        local tmp = {}
        tmp.w = self.vocab:get_word_id(id).str
        tmp.id = id
        -- Recover the word's own probability from the CDF difference.
        if id == 1 then
            tmp.p = ssout[i - 1][id - 1]
        else
            tmp.p = ssout[i - 1][id - 1] - ssout[i - 1][id - 2]
        end
        table.insert(self.store[i], tmp)
    end
end

--Returns: LMResult
--- Sample at least `sample_num` sentences from the RNN and return them.
-- Repeatedly propagates the DAG, samples one token per lane, and harvests
-- lanes that hit the sentence-end token (only sentences with >= 3 tokens,
-- i.e. at least one real word between the two end markers, are kept).
-- @param sample_num number of sentences to return.
-- @param p_conf extra options (currently unused; kept for interface
--        compatibility with other lm_* entry points).
-- @return array of sentences; each sentence is an array of
--         {w = word string, id = word id, p = probability} records.
function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf)
    local dagL = self.dagL
    local inputs = self.dagL_inputs
    local outputs = self.dagL_outputs
    while #self.repo < sample_num do
        dagL:propagate(inputs, outputs)
        inputs[2]:copy_fromd(outputs[2]) --copy hidden activation
        self.smout_d:softmax(outputs[1])
        self.ssout_d:prefixsum_row(self.smout_d)
        self.ssout_d:copy_toh(self.ssout_h)
        self:sample_to_store(self.ssout_h)
        for i = 1, self.batch_size do
            -- Feed the sampled word (0-based on device) back as next input.
            inputs[1][i - 1][0] = self.store[i][#self.store[i]].id - 1
            if self.store[i][#self.store[i]].id == self.sen_end_id then
                --meet a sentence end
                if #self.store[i] >= 3 then
                    self.repo[#self.repo + 1] = self.store[i]
                end
                -- Reset this lane: zero hidden state, reseed the sentence
                -- with the end/begin marker.
                inputs[2][i - 1]:fill(0)
                self.store[i] = {}
                self.store[i][1] = {}
                self.store[i][1].w = self.sen_end_token
                self.store[i][1].id = self.sen_end_id
                self.store[i][1].p = 0
            end
        end
        collectgarbage("collect")
    end
    -- Hand out sample_num sentences, popping from the repository's tail.
    local res = {}
    for i = 1, sample_num do
        res[i] = self.repo[#self.repo]
        self.repo[#self.repo] = nil
    end
    return res
end