Diffstat (limited to 'nerv/examples/lmptb/lm_sampler.lua')
-rw-r--r--  nerv/examples/lmptb/lm_sampler.lua | 60 +++++++++++++++++++++++++++++++++++++++++++++++---------------
1 file changed, 45 insertions(+), 15 deletions(-)
diff --git a/nerv/examples/lmptb/lm_sampler.lua b/nerv/examples/lmptb/lm_sampler.lua
index c25a75c..c9adf85 100644
--- a/nerv/examples/lmptb/lm_sampler.lua
+++ b/nerv/examples/lmptb/lm_sampler.lua
@@ -3,31 +3,34 @@ local LMSampler = nerv.class('nerv.LMSampler')
 
 function LMSampler:__init(global_conf)
     self.log_pre = "LMSampler"
     self.gconf = global_conf
+    self.batch_size = self.gconf.batch_size
+    self.chunk_size = self.gconf.chunk_size --maximum length of a sampled sentence
     self.vocab = self.gconf.vocab
     self.sen_end_token = self.vocab.sen_end_token
     self.sen_end_id = self.vocab:get_word_str(self.sen_end_token).id
+
+    self.loaded = false
 end
 
-function LMSampler:load_dagL(dagL)
-    self.batch_size = self.gconf.batch_size
-    self.chunk_size = self.gconf.chunk_size
-
+function LMSampler:load_dagL(dagL)
     nerv.printf("%s loading dagL\n", self.log_pre)
     self.dagL = dagL
+    self.dagL:init(self.batch_size)
 
     self.dagL_inputs = {}
-    self.dagL_inputs[1] = global_conf.cumat_type(global_conf.batch_size, 1)
+    self.dagL_inputs[1] = self.gconf.cumat_type(self.gconf.batch_size, 1)
     self.dagL_inputs[1]:fill(self.sen_end_id - 1)
-    self.dagL_inputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size)
+    self.dagL_inputs[2] = self.gconf.cumat_type(self.gconf.batch_size, self.gconf.hidden_size)
     self.dagL_inputs[2]:fill(0)
 
     self.dagL_outputs = {}
-    self.dagL_outputs[1] = global_conf.cumat_type(global_conf.batch_size, global_conf.vocab:size())
-    self.dagL_outputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size)
+    self.dagL_outputs[1] = self.gconf.cumat_type(self.gconf.batch_size, self.gconf.vocab:size())
+    self.dagL_outputs[2] = self.gconf.cumat_type(self.gconf.batch_size, self.gconf.hidden_size)
 
-    self.smout_d = global_conf.cumat_type(self.batch_size, self.vocab:size())
-    self.smout_h = global_conf.mmat_type(self.batch_size, self.vocab:size())
+    self.smout_d = self.gconf.cumat_type(self.batch_size, self.vocab:size())
+    self.ssout_d = self.gconf.cumat_type(self.batch_size, self.vocab:size())
+    self.ssout_h = self.gconf.mmat_type(self.batch_size, self.vocab:size())
 
     self.store = {}
     for i = 1, self.batch_size do
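
The ssout_d/ssout_h pair added above holds a row-wise prefix sum of the softmax output: ssout_d is the device buffer, ssout_h the host copy that sample_to_store later reads. A minimal pure-Lua sketch of the semantics assumed for prefixsum_row (an inclusive cumulative sum along each row; the helper name and the 1-indexed table-of-tables layout are illustrative, not nerv API):

    -- Assumed semantics of cumat:prefixsum_row, as a plain-Lua reference:
    -- out[i][j] = m[i][1] + ... + m[i][j] (inclusive cumulative sum per row).
    local function prefixsum_row_ref(m)
        local out = {}
        for i = 1, #m do
            local s = 0
            out[i] = {}
            for j = 1, #m[i] do
                s = s + m[i][j]
                out[i][j] = s
            end
        end
        return out
    end

    local cdf = prefixsum_row_ref({{0.2, 0.5, 0.3}})
    -- cdf[1] is {0.2, 0.7, 1.0}: each row ends at ~1, which is exactly
    -- the sanity check sample_to_store performs on its last column.
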
@@ -38,11 +41,31 @@ function LMSampler:load_dagL(dagL)
         self.store[i][1].p = 0
     end
     self.repo = {}
+
+    self.loaded = true
 end
 
-function LMSampler:sample_to_store(smout)
+function LMSampler:sample_to_store(ssout) --private
     for i = 1, self.batch_size do
         local ran = math.random()
+        local id = 1
+        local low = 0
+        local high = ssout:ncol() - 1
+        if ssout[i - 1][high] < 0.9999 or ssout[i - 1][high] > 1.0001 then
+            nerv.error("%s ERROR: softmax output sum (%f) is not close to 1", self.log_pre, ssout[i - 1][high])
+        end
+        if ssout[i - 1][low] < ran then
+            while low + 1 < high do
+                local mid = math.floor((low + high) / 2)
+                if ssout[i - 1][mid] < ran then
+                    low = mid
+                else
+                    high = mid
+                end
+            end
+            id = high + 1
+        end
+        --[[
         local s = 0
         local id = self.vocab:size()
         for j = 0, self.vocab:size() - 1 do
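
The loop above replaces the old linear scan (kept commented out) with a binary search over the cumulative row: draw ran uniformly in [0, 1), then find the smallest index whose cumulative mass reaches ran. Per sampled token this drops the cost from O(V) to O(log V) for vocabulary size V. The same inverse-CDF search on a 1-indexed Lua table, as a self-contained sketch (illustrative; the commit works on 0-indexed nerv matrix rows and converts with id = high + 1):

    -- Draw an index from a categorical distribution given its inclusive
    -- CDF (ascending, cdf[#cdf] ~= 1): returns the smallest id with
    -- cdf[id] >= ran, mirroring the low/high loop in sample_to_store.
    local function sample_from_cdf(cdf)
        local ran = math.random()
        if cdf[1] >= ran then
            return 1
        end
        local low, high = 1, #cdf -- invariant: cdf[low] < ran <= cdf[high]
        while low + 1 < high do
            local mid = math.floor((low + high) / 2)
            if cdf[mid] < ran then
                low = mid
            else
                high = mid
            end
        end
        return high
    end

    print(sample_from_cdf({0.2, 0.7, 1.0})) -- 1, 2 or 3 with p = 0.2/0.5/0.3
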
@@ -52,19 +75,25 @@ function LMSampler:sample_to_store(smout)
                 break
             end
         end
+        ]]--
         if #self.store[i] >= self.chunk_size - 2 then
             id = self.sen_end_id
         end
         local tmp = {}
         tmp.w = self.vocab:get_word_id(id).str
         tmp.id = id
-        tmp.p = smout[i - 1][id - 1]
+        if id == 1 then
+            tmp.p = ssout[i - 1][id - 1]
+        else
+            tmp.p = ssout[i - 1][id - 1] - ssout[i - 1][id - 2]
+        end
         table.insert(self.store[i], tmp)
     end
 end
 
---Returns: LMResult
 function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf)
+    assert(self.loaded == true)
+
     local dagL = self.dagL
     local inputs = self.dagL_inputs
     local outputs = self.dagL_outputs
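
Because sample_to_store now receives the cumulative row rather than the raw softmax, the probability recorded for a sampled word is recovered by differencing adjacent CDF entries, as in the tmp.p branch above: p(1) = cdf(1), and p(id) = cdf(id) - cdf(id - 1) otherwise. In 1-indexed Lua terms (illustrative helper, not nerv API):

    -- Recover a single token probability from an inclusive CDF row.
    local function prob_from_cdf(cdf, id)
        if id == 1 then
            return cdf[1]
        end
        return cdf[id] - cdf[id - 1]
    end

    assert(math.abs(prob_from_cdf({0.2, 0.7, 1.0}, 2) - 0.5) < 1e-9)
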
@@ -74,9 +103,10 @@ function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf)
         inputs[2]:copy_fromd(outputs[2]) --copy hidden activation
 
         self.smout_d:softmax(outputs[1])
-        self.smout_d:copy_toh(self.smout_h)
+        self.ssout_d:prefixsum_row(self.smout_d)
+        self.ssout_d:copy_toh(self.ssout_h)
 
-        self:sample_to_store(self.smout_h)
+        self:sample_to_store(self.ssout_h)
         for i = 1, self.batch_size do
             inputs[1][i - 1][0] = self.store[i][#self.store[i]].id - 1
             if self.store[i][#self.store[i]].id == self.sen_end_id then --reached a sentence end
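
Per step, the sampling loop now runs: propagate the RNN, softmax the logits on the device, take the row-wise prefix sum, copy the cumulative rows to the host, binary-search each row for the next token, and feed the sampled ids back in as the next input. A self-contained toy version of that scheme in plain Lua (all helper names are illustrative, not nerv API):

    -- Toy end-to-end run: softmax a score row, build its inclusive CDF,
    -- then repeatedly sample token ids by binary search.
    local function softmax(row)
        local m = -math.huge
        for _, v in ipairs(row) do if v > m then m = v end end
        local exps, s = {}, 0
        for i, v in ipairs(row) do exps[i] = math.exp(v - m); s = s + exps[i] end
        for i = 1, #exps do exps[i] = exps[i] / s end
        return exps
    end

    local function cumsum(row)
        local out, s = {}, 0
        for i, v in ipairs(row) do s = s + v; out[i] = s end
        return out
    end

    local cdf = cumsum(softmax({1.0, 2.0, 0.5}))
    for _ = 1, 5 do
        local ran, low, high = math.random(), 1, #cdf
        if cdf[1] >= ran then
            high = 1
        else
            while low + 1 < high do
                local mid = math.floor((low + high) / 2)
                if cdf[mid] < ran then low = mid else high = mid end
            end
        end
        io.write(high, " ") -- sampled token id, weighted by softmax mass
    end
    print()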