aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lm_sampler.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/lm_sampler.lua')
-rw-r--r--nerv/examples/lmptb/lm_sampler.lua35
1 files changed, 30 insertions, 5 deletions
diff --git a/nerv/examples/lmptb/lm_sampler.lua b/nerv/examples/lmptb/lm_sampler.lua
index c25a75c..2a4f1c3 100644
--- a/nerv/examples/lmptb/lm_sampler.lua
+++ b/nerv/examples/lmptb/lm_sampler.lua
@@ -27,7 +27,8 @@ function LMSampler:load_dagL(dagL)
self.dagL_outputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size)
self.smout_d = global_conf.cumat_type(self.batch_size, self.vocab:size())
- self.smout_h = global_conf.mmat_type(self.batch_size, self.vocab:size())
+ self.ssout_d = global_conf.cumat_type(self.batch_size, self.vocab:size())
+ self.ssout_h = global_conf.mmat_type(self.batch_size, self.vocab:size())
self.store = {}
for i = 1, self.batch_size do
@@ -40,9 +41,27 @@ function LMSampler:load_dagL(dagL)
self.repo = {}
end
-function LMSampler:sample_to_store(smout)
+function LMSampler:sample_to_store(ssout)
for i = 1, self.batch_size do
local ran = math.random()
+ local id = 1
+ local low = 0
+ local high = ssout:ncol() - 1
+ if ssout[i - 1][high] < 0.9999 or ssout[i - 1][high] > 1.0001 then
+ nerv.error("%s ERROR, softmax output summation(%f) seems to have some problem", self.log_pre, ssout[i - 1][high])
+ end
+ if ssout[i - 1][low] < ran then
+ while low + 1 < high do
+ local mid = math.floor((low + high) / 2)
+ if ssout[i - 1][mid] < ran then
+ low = mid
+ else
+ high = mid
+ end
+ end
+ id = high + 1
+ end
+ --[[
local s = 0
local id = self.vocab:size()
for j = 0, self.vocab:size() - 1 do
@@ -52,13 +71,18 @@ function LMSampler:sample_to_store(smout)
break
end
end
+ ]]--
if #self.store[i] >= self.chunk_size - 2 then
id = self.sen_end_id
end
local tmp = {}
tmp.w = self.vocab:get_word_id(id).str
tmp.id = id
- tmp.p = smout[i - 1][id - 1]
+ if id == 1 then
+ tmp.p = ssout[i - 1][id - 1]
+ else
+ tmp.p = ssout[i - 1][id - 1] - ssout[i - 1][id - 2]
+ end
table.insert(self.store[i], tmp)
end
end
@@ -74,9 +98,10 @@ function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf)
inputs[2]:copy_fromd(outputs[2]) --copy hidden activation
self.smout_d:softmax(outputs[1])
- self.smout_d:copy_toh(self.smout_h)
+ self.ssout_d:prefixsum_row(self.smout_d)
+ self.ssout_d:copy_toh(self.ssout_h)
- self:sample_to_store(self.smout_h)
+ self:sample_to_store(self.ssout_h)
for i = 1, self.batch_size do
inputs[1][i - 1][0] = self.store[i][#self.store[i]].id - 1
if self.store[i][#self.store[i]].id == self.sen_end_id then --meet a sentence end