Diffstat (limited to 'nerv/examples/lmptb/lm_sampler.lua')
-rw-r--r-- | nerv/examples/lmptb/lm_sampler.lua | 60
1 file changed, 45 insertions, 15 deletions
diff --git a/nerv/examples/lmptb/lm_sampler.lua b/nerv/examples/lmptb/lm_sampler.lua
index c25a75c..c9adf85 100644
--- a/nerv/examples/lmptb/lm_sampler.lua
+++ b/nerv/examples/lmptb/lm_sampler.lua
@@ -3,31 +3,34 @@ local LMSampler = nerv.class('nerv.LMSampler')
 function LMSampler:__init(global_conf)
     self.log_pre = "LMSampler"
     self.gconf = global_conf
+    self.batch_size = self.gconf.batch_size
+    self.chunk_size = self.gconf.chunk_size --largest sample sentence length
     self.vocab = self.gconf.vocab
     self.sen_end_token = self.vocab.sen_end_token
     self.sen_end_id = self.vocab:get_word_str(self.sen_end_token).id
+
+    self.loaded = false
 end
 
-function LMSampler:load_dagL(dagL)
-    self.batch_size = self.gconf.batch_size
-    self.chunk_size = self.gconf.chunk_size
-
+function LMSampler:load_dagL(dagL)
     nerv.printf("%s loading dagL\n", self.log_pre)
 
     self.dagL = dagL
+    self.dagL:init(self.batch_size)
 
     self.dagL_inputs = {}
-    self.dagL_inputs[1] = global_conf.cumat_type(global_conf.batch_size, 1)
+    self.dagL_inputs[1] = self.gconf.cumat_type(self.gconf.batch_size, 1)
     self.dagL_inputs[1]:fill(self.sen_end_id - 1)
-    self.dagL_inputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size)
+    self.dagL_inputs[2] = self.gconf.cumat_type(self.gconf.batch_size, self.gconf.hidden_size)
     self.dagL_inputs[2]:fill(0)
 
     self.dagL_outputs = {}
-    self.dagL_outputs[1] = global_conf.cumat_type(global_conf.batch_size, global_conf.vocab:size())
-    self.dagL_outputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size)
+    self.dagL_outputs[1] = self.gconf.cumat_type(self.gconf.batch_size, self.gconf.vocab:size())
+    self.dagL_outputs[2] = self.gconf.cumat_type(self.gconf.batch_size, self.gconf.hidden_size)
 
-    self.smout_d = global_conf.cumat_type(self.batch_size, self.vocab:size())
-    self.smout_h = global_conf.mmat_type(self.batch_size, self.vocab:size())
+    self.smout_d = self.gconf.cumat_type(self.batch_size, self.vocab:size())
+    self.ssout_d = self.gconf.cumat_type(self.batch_size, self.vocab:size())
+    self.ssout_h = self.gconf.mmat_type(self.batch_size, self.vocab:size())
 
     self.store = {}
     for i = 1, self.batch_size do
@@ -38,11 +41,31 @@ function LMSampler:load_dagL(dagL)
         self.store[i][1].p = 0
     end
     self.repo = {}
+
+    self.loaded = true
 end
 
-function LMSampler:sample_to_store(smout)
+function LMSampler:sample_to_store(ssout) --private
     for i = 1, self.batch_size do
         local ran = math.random()
+        local id = 1
+        local low = 0
+        local high = ssout:ncol() - 1
+        if ssout[i - 1][high] < 0.9999 or ssout[i - 1][high] > 1.0001 then
+            nerv.error("%s ERROR, softmax output summation(%f) seems to have some problem", self.log_pre, ssout[i - 1][high])
+        end
+        if ssout[i - 1][low] < ran then
+            while low + 1 < high do
+                local mid = math.floor((low + high) / 2)
+                if ssout[i - 1][mid] < ran then
+                    low = mid
+                else
+                    high = mid
+                end
+            end
+            id = high + 1
+        end
+        --[[
         local s = 0
         local id = self.vocab:size()
         for j = 0, self.vocab:size() - 1 do
@@ -52,19 +75,25 @@ function LMSampler:sample_to_store(smout)
                 break
             end
         end
+        ]]--
         if #self.store[i] >= self.chunk_size - 2 then
            id = self.sen_end_id
         end
         local tmp = {}
         tmp.w = self.vocab:get_word_id(id).str
         tmp.id = id
-        tmp.p = smout[i - 1][id - 1]
+        if id == 1 then
+            tmp.p = ssout[i - 1][id - 1]
+        else
+            tmp.p = ssout[i - 1][id - 1] - ssout[i - 1][id - 2]
+        end
         table.insert(self.store[i], tmp)
     end
 end
 
---Returns: LMResult
 function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf)
+    assert(self.loaded == true)
+
     local dagL = self.dagL
     local inputs = self.dagL_inputs
     local outputs = self.dagL_outputs
@@ -74,9 +103,10 @@ function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf)
         inputs[2]:copy_fromd(outputs[2]) --copy hidden activation
 
         self.smout_d:softmax(outputs[1])
-        self.smout_d:copy_toh(self.smout_h)
+        self.ssout_d:prefixsum_row(self.smout_d)
+        self.ssout_d:copy_toh(self.ssout_h)
 
-        self:sample_to_store(self.smout_h)
+        self:sample_to_store(self.ssout_h)
         for i = 1, self.batch_size do
             inputs[1][i - 1][0] = self.store[i][#self.store[i]].id - 1
             if self.store[i][#self.store[i]].id == self.sen_end_id then --meet a sentence end
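
Note on the change: the patch replaces the linear scan over the host-side softmax row (the block now commented out with --[[ ... ]]--) with a prefix sum over the softmax output on the device (prefixsum_row) followed by a binary search on the cumulative row, and it recovers each sampled word's probability as the difference of adjacent cumulative entries. The sketch below illustrates the same inverse-CDF sampling scheme with plain Lua tables, outside nerv; prefix_sum, sample_from_cumulative, and prob_of are illustrative names only, not part of the nerv API, and indexing is 1-based here while the patch indexes matrix columns from 0.

-- Minimal standalone sketch of cumulative-distribution sampling (plain Lua).
local function prefix_sum(probs)
    -- cumulative distribution over one softmax row
    local cum = {}
    local s = 0
    for i = 1, #probs do
        s = s + probs[i]
        cum[i] = s
    end
    return cum
end

local function sample_from_cumulative(cum)
    -- find the first index whose cumulative value reaches a uniform draw
    local ran = math.random()
    if cum[1] >= ran then
        return 1
    end
    local low, high = 1, #cum
    while low + 1 < high do
        local mid = math.floor((low + high) / 2)
        if cum[mid] < ran then
            low = mid
        else
            high = mid
        end
    end
    return high
end

local function prob_of(cum, id)
    -- per-word probability from adjacent cumulative entries (cf. tmp.p in the patch)
    if id == 1 then
        return cum[id]
    end
    return cum[id] - cum[id - 1]
end

local cum = prefix_sum({0.1, 0.6, 0.3})
local id = sample_from_cumulative(cum)
print(id, prob_of(cum, id))

The binary search makes each draw cost O(log V) in the vocabulary size instead of the O(V) scan it replaces; the 0.9999/1.0001 check in the patched sample_to_store guards against a cumulative row whose last entry does not sum to roughly 1.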