diff options
-rw-r--r-- | nerv/examples/lmptb/lm_sampler.lua | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/nerv/examples/lmptb/lm_sampler.lua b/nerv/examples/lmptb/lm_sampler.lua index d194af9..9d31f17 100644 --- a/nerv/examples/lmptb/lm_sampler.lua +++ b/nerv/examples/lmptb/lm_sampler.lua @@ -8,10 +8,12 @@ function LMSampler:__init(global_conf) self.vocab = self.gconf.vocab self.sen_end_token = self.vocab.sen_end_token self.sen_end_id = self.vocab:get_word_str(self.sen_end_token).id + + self.loaded = false + end -function LMSampler:load_dagL(dagL) - +function LMSampler:load_dagL(dagL) nerv.printf("%s loading dagL\n", self.log_pre) self.dagL = dagL @@ -40,9 +42,11 @@ function LMSampler:load_dagL(dagL) self.store[i][1].p = 0 end self.repo = {} + + self.loaded = true end -function LMSampler:sample_to_store(ssout) +function LMSampler:sample_to_store(ssout) --private for i = 1, self.batch_size do local ran = math.random() local id = 1 @@ -88,8 +92,9 @@ function LMSampler:sample_to_store(ssout) end end ---Returns: LMResult function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf) + assert(self.loaded == true) + local dagL = self.dagL local inputs = self.dagL_inputs local outputs = self.dagL_outputs |