diff options
-rw-r--r-- | nerv/examples/lmptb/lm_sampler.lua | 35 | ||||
-rw-r--r-- | nerv/examples/lmptb/sample_grulm_ptb_main.lua | 4 |
2 files changed, 32 insertions, 7 deletions
diff --git a/nerv/examples/lmptb/lm_sampler.lua b/nerv/examples/lmptb/lm_sampler.lua index c25a75c..2a4f1c3 100644 --- a/nerv/examples/lmptb/lm_sampler.lua +++ b/nerv/examples/lmptb/lm_sampler.lua @@ -27,7 +27,8 @@ function LMSampler:load_dagL(dagL) self.dagL_outputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size) self.smout_d = global_conf.cumat_type(self.batch_size, self.vocab:size()) - self.smout_h = global_conf.mmat_type(self.batch_size, self.vocab:size()) + self.ssout_d = global_conf.cumat_type(self.batch_size, self.vocab:size()) + self.ssout_h = global_conf.mmat_type(self.batch_size, self.vocab:size()) self.store = {} for i = 1, self.batch_size do @@ -40,9 +41,27 @@ function LMSampler:load_dagL(dagL) self.repo = {} end -function LMSampler:sample_to_store(smout) +function LMSampler:sample_to_store(ssout) for i = 1, self.batch_size do local ran = math.random() + local id = 1 + local low = 0 + local high = ssout:ncol() - 1 + if ssout[i - 1][high] < 0.9999 or ssout[i - 1][high] > 1.0001 then + nerv.error("%s ERROR, softmax output summation(%f) seems to have some problem", self.log_pre, ssout[i - 1][high]) + end + if ssout[i - 1][low] < ran then + while low + 1 < high do + local mid = math.floor((low + high) / 2) + if ssout[i - 1][mid] < ran then + low = mid + else + high = mid + end + end + id = high + 1 + end + --[[ local s = 0 local id = self.vocab:size() for j = 0, self.vocab:size() - 1 do @@ -52,13 +71,18 @@ function LMSampler:sample_to_store(smout) break end end + ]]-- if #self.store[i] >= self.chunk_size - 2 then id = self.sen_end_id end local tmp = {} tmp.w = self.vocab:get_word_id(id).str tmp.id = id - tmp.p = smout[i - 1][id - 1] + if id == 1 then + tmp.p = ssout[i - 1][id - 1] + else + tmp.p = ssout[i - 1][id - 1] - ssout[i - 1][id - 2] + end table.insert(self.store[i], tmp) end end @@ -74,9 +98,10 @@ function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf) inputs[2]:copy_fromd(outputs[2]) --copy hidden activation self.smout_d:softmax(outputs[1]) - self.smout_d:copy_toh(self.smout_h) + self.ssout_d:prefixsum_row(self.smout_d) + self.ssout_d:copy_toh(self.ssout_h) - self:sample_to_store(self.smout_h) + self:sample_to_store(self.ssout_h) for i = 1, self.batch_size do inputs[1][i - 1][0] = self.store[i][#self.store[i]].id - 1 if self.store[i][#self.store[i]].id == self.sen_end_id then --meet a sentence end diff --git a/nerv/examples/lmptb/sample_grulm_ptb_main.lua b/nerv/examples/lmptb/sample_grulm_ptb_main.lua index 9a13d36..30dfe26 100644 --- a/nerv/examples/lmptb/sample_grulm_ptb_main.lua +++ b/nerv/examples/lmptb/sample_grulm_ptb_main.lua @@ -424,11 +424,11 @@ if commands["sampling"] == 1 then local dagL = load_net_dagL(global_conf, global_conf.fn_to_sample) local sampler = nerv.LMSampler(global_conf) sampler:load_dagL(dagL) - for k = 1, 5 do + for k = 1, 1 do local res = sampler:lm_sample_rnn_dagL(10, {}) for i = 1, #res do for j = 1, #res[i] do - nerv.printf("%s ", res[i][j].w) + nerv.printf("%s(%f) ", res[i][j].w, res[i][j].p) end nerv.printf("\n") end |