--- Samples sentences from a step-wise RNN language model ("dagL").
-- Each batch row carries one in-progress sentence; finished sentences are
-- collected in self.repo and handed out by lm_sample_rnn_dagL.
local LMSampler = nerv.class('nerv.LMSampler')

--- Construct a sampler.
-- @tparam table global_conf configuration table. The constructor reads `vocab`;
--   load_dagL later reads `batch_size`, `chunk_size`, `hidden_size`,
--   `cumat_type` and `mmat_type` from the same table.
function LMSampler:__init(global_conf)
    self.log_pre = "LMSampler"
    self.gconf = global_conf
    self.vocab = self.gconf.vocab
    self.sen_end_token = self.vocab.sen_end_token
    self.sen_end_id = self.vocab:get_word_str(self.sen_end_token).id
end

--- Reset batch row `i` to a fresh sentence that holds only the
-- sentence-end token, which doubles as the start-of-sentence marker.
-- @tparam number i 1-based batch row
function LMSampler:_reset_store_row(i)
    self.store[i] = {
        { w = self.sen_end_token, id = self.sen_end_id, p = 0 },
    }
end

--- Attach a model and allocate the input/output buffers used for sampling.
-- @param dagL model object exposing propagate(inputs, outputs)
-- NOTE(review): the model is presumably expected to map {word index, hidden}
-- to {vocab-sized scores, hidden}, judging by the buffer shapes; confirm.
function LMSampler:load_dagL(dagL)
    -- Always use the configuration passed to the constructor (not a global).
    local gconf = self.gconf
    self.batch_size = gconf.batch_size
    self.chunk_size = gconf.chunk_size
    local vocab_size = self.vocab:size()

    nerv.printf("%s loading dagL\n", self.log_pre)

    self.dagL = dagL

    -- inputs[1]: current word per row as a 0-based column index (word id - 1),
    --            seeded with the sentence-end token;
    -- inputs[2]: recurrent hidden state, starting at zero.
    self.dagL_inputs = {
        gconf.cumat_type(self.batch_size, 1),
        gconf.cumat_type(self.batch_size, gconf.hidden_size),
    }
    self.dagL_inputs[1]:fill(self.sen_end_id - 1)
    self.dagL_inputs[2]:fill(0)

    self.dagL_outputs = {
        gconf.cumat_type(self.batch_size, vocab_size),
        gconf.cumat_type(self.batch_size, gconf.hidden_size),
    }

    -- Softmax buffer on the device, and its host copy used for sampling.
    self.smout_d = gconf.cumat_type(self.batch_size, vocab_size)
    self.smout_h = gconf.mmat_type(self.batch_size, vocab_size)

    self.store = {}
    for i = 1, self.batch_size do
        self:_reset_store_row(i)
    end
    self.repo = {}
end

--- Draw one word per batch row and append it to that row's sentence.
-- @param smout host matrix indexed 0-based as smout[row][col]; column j
--   corresponds to word id j + 1 and each row is treated as a distribution.
function LMSampler:sample_to_store(smout)
    local vocab_size = self.vocab:size()
    for i = 1, self.batch_size do
        local row = smout[i - 1]
        local ran = math.random()
        local s = 0
        -- If rounding leaves the cumulative sum below `ran`, fall back to the
        -- last word.
        local id = vocab_size
        for j = 0, vocab_size - 1 do
            s = s + row[j]
            if s >= ran then
                id = j + 1
                break
            end
        end
        -- Force a sentence end once the sentence nears chunk_size.
        local sentence = self.store[i]
        if #sentence >= self.chunk_size - 2 then
            id = self.sen_end_id
        end
        sentence[#sentence + 1] = {
            w = self.vocab:get_word_id(id).str,
            id = id,
            p = row[id - 1],
        }
    end
end

--- Sample `sample_num` complete sentences from the loaded model.
-- Steps the model across all batch rows until enough finished sentences are
-- collected. A sentence is kept only if it has at least one word between its
-- two sentence-end markers. Surplus sentences remain in self.repo and are
-- returned by later calls.
-- @tparam number sample_num number of sentences to return
-- @param p_conf unused; kept for interface compatibility
-- @treturn table array of sentences; each sentence is an array of
--   {w = word string, id = word id, p = sampled probability} entries that
--   starts and ends with the sentence-end token
function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf)
    assert(self.dagL ~= nil,
           "LMSampler: load_dagL must be called before sampling")
    local dagL = self.dagL
    local inputs = self.dagL_inputs
    local outputs = self.dagL_outputs

    while #self.repo < sample_num do
        dagL:propagate(inputs, outputs)
        inputs[2]:copy_fromd(outputs[2]) -- feed the new hidden state back in

        self.smout_d:softmax(outputs[1])
        self.smout_d:copy_toh(self.smout_h)

        self:sample_to_store(self.smout_h)
        for i = 1, self.batch_size do
            local sentence = self.store[i]
            local last = sentence[#sentence]
            inputs[1][i - 1][0] = last.id - 1
            if last.id == self.sen_end_id then -- met a sentence end
                -- Drop empty sentences (marker immediately followed by marker).
                if #sentence >= 3 then
                    self.repo[#self.repo + 1] = sentence
                end
                -- Restart this row: zero its hidden state, begin a new sentence.
                inputs[2][i - 1]:fill(0)
                self:_reset_store_row(i)
            end
        end
        -- NOTE(review): a full collection each step presumably releases
        -- temporary matrix userdata promptly; confirm before removing.
        collectgarbage("collect")
    end

    local res = {}
    for i = 1, sample_num do
        res[i] = self.repo[#self.repo]
        self.repo[#self.repo] = nil
    end
    return res
end