aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lm_sampler.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/lm_sampler.lua')
-rw-r--r--nerv/examples/lmptb/lm_sampler.lua104
1 files changed, 104 insertions, 0 deletions
diff --git a/nerv/examples/lmptb/lm_sampler.lua b/nerv/examples/lmptb/lm_sampler.lua
new file mode 100644
index 0000000..c25a75c
--- /dev/null
+++ b/nerv/examples/lmptb/lm_sampler.lua
@@ -0,0 +1,104 @@
+local LMSampler = nerv.class('nerv.LMSampler')
+
+function LMSampler:__init(global_conf)
+ self.log_pre = "LMSampler"
+ self.gconf = global_conf
+ self.vocab = self.gconf.vocab
+ self.sen_end_token = self.vocab.sen_end_token
+ self.sen_end_id = self.vocab:get_word_str(self.sen_end_token).id
+end
+
+function LMSampler:load_dagL(dagL)
+ self.batch_size = self.gconf.batch_size
+ self.chunk_size = self.gconf.chunk_size
+
+ nerv.printf("%s loading dagL\n", self.log_pre)
+
+ self.dagL = dagL
+
+ self.dagL_inputs = {}
+ self.dagL_inputs[1] = global_conf.cumat_type(global_conf.batch_size, 1)
+ self.dagL_inputs[1]:fill(self.sen_end_id - 1)
+ self.dagL_inputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size)
+ self.dagL_inputs[2]:fill(0)
+
+ self.dagL_outputs = {}
+ self.dagL_outputs[1] = global_conf.cumat_type(global_conf.batch_size, global_conf.vocab:size())
+ self.dagL_outputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size)
+
+ self.smout_d = global_conf.cumat_type(self.batch_size, self.vocab:size())
+ self.smout_h = global_conf.mmat_type(self.batch_size, self.vocab:size())
+
+ self.store = {}
+ for i = 1, self.batch_size do
+ self.store[i] = {}
+ self.store[i][1] = {}
+ self.store[i][1].w = self.sen_end_token
+ self.store[i][1].id = self.sen_end_id
+ self.store[i][1].p = 0
+ end
+ self.repo = {}
+end
+
+function LMSampler:sample_to_store(smout)
+ for i = 1, self.batch_size do
+ local ran = math.random()
+ local s = 0
+ local id = self.vocab:size()
+ for j = 0, self.vocab:size() - 1 do
+ s = s + smout[i - 1][j]
+ if s >= ran then
+ id = j + 1
+ break
+ end
+ end
+ if #self.store[i] >= self.chunk_size - 2 then
+ id = self.sen_end_id
+ end
+ local tmp = {}
+ tmp.w = self.vocab:get_word_id(id).str
+ tmp.id = id
+ tmp.p = smout[i - 1][id - 1]
+ table.insert(self.store[i], tmp)
+ end
+end
+
+--Returns: LMResult
+function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf)
+ local dagL = self.dagL
+ local inputs = self.dagL_inputs
+ local outputs = self.dagL_outputs
+
+ while #self.repo < sample_num do
+ dagL:propagate(inputs, outputs)
+ inputs[2]:copy_fromd(outputs[2]) --copy hidden activation
+
+ self.smout_d:softmax(outputs[1])
+ self.smout_d:copy_toh(self.smout_h)
+
+ self:sample_to_store(self.smout_h)
+ for i = 1, self.batch_size do
+ inputs[1][i - 1][0] = self.store[i][#self.store[i]].id - 1
+ if self.store[i][#self.store[i]].id == self.sen_end_id then --meet a sentence end
+ if #self.store[i] >= 3 then
+ self.repo[#self.repo + 1] = self.store[i]
+ end
+ inputs[2][i - 1]:fill(0)
+ self.store[i] = {}
+ self.store[i][1] = {}
+ self.store[i][1].w = self.sen_end_token
+ self.store[i][1].id = self.sen_end_id
+ self.store[i][1].p = 0
+ end
+ end
+
+ collectgarbage("collect")
+ end
+
+ local res = {}
+ for i = 1, sample_num do
+ res[i] = self.repo[#self.repo]
+ self.repo[#self.repo] = nil
+ end
+ return res
+end