diff options
author | TianxingHe <htx_2006@hotmail.com> | 2016-01-19 20:13:49 -0800 |
---|---|---|
committer | TianxingHe <htx_2006@hotmail.com> | 2016-01-19 20:13:49 -0800 |
commit | dcad8a3f80fc55ca93984d981f9b829d2e4ea728 (patch) | |
tree | 61b9bc1d043883bb5d85dcb86cfb621396d75c41 /nerv/examples/lmptb/lm_sampler.lua | |
parent | 7449dd19c4d1669b483693f61add9d574e46f0b2 (diff) | |
parent | 37dec2610c92d03813c4e91ed58791ab60da6646 (diff) |
Merge pull request #21 from cloudygoose/txh18/rnnlm
Txh18/rnnlm new changes to lm side
Diffstat (limited to 'nerv/examples/lmptb/lm_sampler.lua')
-rw-r--r-- | nerv/examples/lmptb/lm_sampler.lua | 104 |
1 files changed, 104 insertions, 0 deletions
diff --git a/nerv/examples/lmptb/lm_sampler.lua b/nerv/examples/lmptb/lm_sampler.lua new file mode 100644 index 0000000..c25a75c --- /dev/null +++ b/nerv/examples/lmptb/lm_sampler.lua @@ -0,0 +1,104 @@ +local LMSampler = nerv.class('nerv.LMSampler') + +function LMSampler:__init(global_conf) + self.log_pre = "LMSampler" + self.gconf = global_conf + self.vocab = self.gconf.vocab + self.sen_end_token = self.vocab.sen_end_token + self.sen_end_id = self.vocab:get_word_str(self.sen_end_token).id +end + +function LMSampler:load_dagL(dagL) + self.batch_size = self.gconf.batch_size + self.chunk_size = self.gconf.chunk_size + + nerv.printf("%s loading dagL\n", self.log_pre) + + self.dagL = dagL + + self.dagL_inputs = {} + self.dagL_inputs[1] = global_conf.cumat_type(global_conf.batch_size, 1) + self.dagL_inputs[1]:fill(self.sen_end_id - 1) + self.dagL_inputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size) + self.dagL_inputs[2]:fill(0) + + self.dagL_outputs = {} + self.dagL_outputs[1] = global_conf.cumat_type(global_conf.batch_size, global_conf.vocab:size()) + self.dagL_outputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size) + + self.smout_d = global_conf.cumat_type(self.batch_size, self.vocab:size()) + self.smout_h = global_conf.mmat_type(self.batch_size, self.vocab:size()) + + self.store = {} + for i = 1, self.batch_size do + self.store[i] = {} + self.store[i][1] = {} + self.store[i][1].w = self.sen_end_token + self.store[i][1].id = self.sen_end_id + self.store[i][1].p = 0 + end + self.repo = {} +end + +function LMSampler:sample_to_store(smout) + for i = 1, self.batch_size do + local ran = math.random() + local s = 0 + local id = self.vocab:size() + for j = 0, self.vocab:size() - 1 do + s = s + smout[i - 1][j] + if s >= ran then + id = j + 1 + break + end + end + if #self.store[i] >= self.chunk_size - 2 then + id = self.sen_end_id + end + local tmp = {} + tmp.w = self.vocab:get_word_id(id).str + tmp.id = id + tmp.p = smout[i - 1][id - 1] + table.insert(self.store[i], tmp) + end +end + +--Returns: LMResult +function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf) + local dagL = self.dagL + local inputs = self.dagL_inputs + local outputs = self.dagL_outputs + + while #self.repo < sample_num do + dagL:propagate(inputs, outputs) + inputs[2]:copy_fromd(outputs[2]) --copy hidden activation + + self.smout_d:softmax(outputs[1]) + self.smout_d:copy_toh(self.smout_h) + + self:sample_to_store(self.smout_h) + for i = 1, self.batch_size do + inputs[1][i - 1][0] = self.store[i][#self.store[i]].id - 1 + if self.store[i][#self.store[i]].id == self.sen_end_id then --meet a sentence end + if #self.store[i] >= 3 then + self.repo[#self.repo + 1] = self.store[i] + end + inputs[2][i - 1]:fill(0) + self.store[i] = {} + self.store[i][1] = {} + self.store[i][1].w = self.sen_end_token + self.store[i][1].id = self.sen_end_id + self.store[i][1].p = 0 + end + end + + collectgarbage("collect") + end + + local res = {} + for i = 1, sample_num do + res[i] = self.repo[#self.repo] + self.repo[#self.repo] = nil + end + return res +end |