From ae051c0195374ddfab6a0e693d2c2cfa34f24e99 Mon Sep 17 00:00:00 2001 From: txh18 Date: Tue, 19 Jan 2016 19:28:42 +0800 Subject: added lm_sampler, todo:test it --- nerv/examples/lmptb/lm_sampler.lua | 98 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 nerv/examples/lmptb/lm_sampler.lua (limited to 'nerv/examples/lmptb/lm_sampler.lua') diff --git a/nerv/examples/lmptb/lm_sampler.lua b/nerv/examples/lmptb/lm_sampler.lua new file mode 100644 index 0000000..c165127 --- /dev/null +++ b/nerv/examples/lmptb/lm_sampler.lua @@ -0,0 +1,98 @@ +local LMSampler = nerv.class('nerv.LMSampler') + +function LMSampler:__init(global_conf) + self.log_pre = "LMSampler" + self.gconf = global_conf + self.vocab = self.gconf.vocab + self.sen_end_token = self.vocab.sen_end_token + self.sen_end_id = self.vocab:get_word_str(self.sen_end_token).id +end + +function LMSampler:load_dagL(dagL) + self.batch_size = self.gconf.batch_size + self.chunk_size = self.gconf.chunk_size + + nerv.printf("%s loading dagL\n", self.log_pre) + + self.dagL = dagL + + self.dagL_inputs = {} + self.dagL_inputs[1] = global_conf.cumat_type(global_conf.batch_size, 1) + self.dagL_inputs[1]:fill(self.sen_end_id - 1) + self.dagL_inputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size) + self.dagL_inputs[2]:fill(0) + + self.dagL_outputs = {} + self.dagL_outputs[1] = global_conf.cumat_type(global_conf.batch_size, global_conf.vocab:size()) + self.dagL_outputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size) + + self.smout_d = global_conf.cumat_type(self.batch_size, self.vocab:size()) + self.smout_h = global_conf.mmat_type(self.batch_size, self.vocab:size()) + + self.store = {} + for i = 1, self.batch_size do + self.store[i] = {} + self.store[i][1] = {} + self.store[i][1].w = self.sen_end_token + self.store[i][1].id = self.sen_end_id + self.store[i][1].p = 0 + end +end + +function LMSampler:sample_to_store(smout) + for i = 1, self.batch_size do + local ran = math.random() + local s = 0, id = self.vocab:size() + for j = 0, self.vocab:size() - 1 do + s = s + smout[i][j] + if s >= ran then + id = j + 1 + break + end + end + if #self.store[i] >= self.chunk_size - 2 then + id = self.sen_end_id + end + local tmp = {} + tmp.w = self.vocab:get_word_id(id).str + tmp.id = id + tmp.p = smout[i][id] + table.insert(self.store[i], tmp) + end +end + +--Returns: LMResult +function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf) + local dagL = self.dagL + local inputs = self.dagL_inputs + local outputs = self.dagL_outputs + + local res = {} + while #res < sample_num do + dagL:propagate(inputs, outputs) + inputs[2]:copy_fromd(outputs[2]) --copy hidden activation + + self.smout_d:softmax(outputs[1]) + self.smout_d:copy_toh(self.smout_h) + + self:sample_to_store(self.smout_h) + for i = 1, self.batch_size do + inputs[1][i - 1][0] = self.store[i][#self.store[i]].id - 1 + if self.store[i][#self.store[i]].id == self.sen_end_id then --meet a sentence end + if #self.store[i] >= 3 then + res[#res + 1] = self.store[i] + end + inputs[2][i - 1]:fill(0) + self.store[i] = {} + self.store[i][1] = {} + self.store[i][1].w = self.sen_end_token + self.store[i][1].id = self.sen_end_id + self.store[i][1].p = 0 + end + end + + collectgarbage("collect") + end + + return res +end -- cgit v1.2.3 From 37dec2610c92d03813c4e91ed58791ab60da6646 Mon Sep 17 00:00:00 2001 From: txh18 Date: Tue, 19 Jan 2016 21:30:26 +0800 Subject: big fixes about lm_sampler --- nerv/examples/lmptb/lm_sampler.lua | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) (limited to 'nerv/examples/lmptb/lm_sampler.lua') diff --git a/nerv/examples/lmptb/lm_sampler.lua b/nerv/examples/lmptb/lm_sampler.lua index c165127..c25a75c 100644 --- a/nerv/examples/lmptb/lm_sampler.lua +++ b/nerv/examples/lmptb/lm_sampler.lua @@ -37,14 +37,16 @@ function LMSampler:load_dagL(dagL) self.store[i][1].id = self.sen_end_id self.store[i][1].p = 0 end + self.repo = {} end function LMSampler:sample_to_store(smout) for i = 1, self.batch_size do local ran = math.random() - local s = 0, id = self.vocab:size() + local s = 0 + local id = self.vocab:size() for j = 0, self.vocab:size() - 1 do - s = s + smout[i][j] + s = s + smout[i - 1][j] if s >= ran then id = j + 1 break @@ -56,7 +58,7 @@ function LMSampler:sample_to_store(smout) local tmp = {} tmp.w = self.vocab:get_word_id(id).str tmp.id = id - tmp.p = smout[i][id] + tmp.p = smout[i - 1][id - 1] table.insert(self.store[i], tmp) end end @@ -66,9 +68,8 @@ function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf) local dagL = self.dagL local inputs = self.dagL_inputs local outputs = self.dagL_outputs - - local res = {} - while #res < sample_num do + + while #self.repo < sample_num do dagL:propagate(inputs, outputs) inputs[2]:copy_fromd(outputs[2]) --copy hidden activation @@ -80,7 +81,7 @@ function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf) inputs[1][i - 1][0] = self.store[i][#self.store[i]].id - 1 if self.store[i][#self.store[i]].id == self.sen_end_id then --meet a sentence end if #self.store[i] >= 3 then - res[#res + 1] = self.store[i] + self.repo[#self.repo + 1] = self.store[i] end inputs[2][i - 1]:fill(0) self.store[i] = {} @@ -94,5 +95,10 @@ function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf) collectgarbage("collect") end + local res = {} + for i = 1, sample_num do + res[i] = self.repo[#self.repo] + self.repo[#self.repo] = nil + end return res end -- cgit v1.2.3