diff options
Diffstat (limited to 'nerv/examples/lmptb/lm_sampler.lua')
-rw-r--r-- | nerv/examples/lmptb/lm_sampler.lua | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/nerv/examples/lmptb/lm_sampler.lua b/nerv/examples/lmptb/lm_sampler.lua index 2a4f1c3..d194af9 100644 --- a/nerv/examples/lmptb/lm_sampler.lua +++ b/nerv/examples/lmptb/lm_sampler.lua @@ -3,18 +3,19 @@ local LMSampler = nerv.class('nerv.LMSampler') function LMSampler:__init(global_conf) self.log_pre = "LMSampler" self.gconf = global_conf + self.batch_size = self.gconf.batch_size + self.chunk_size = self.gconf.chunk_size --largest sample sentence length self.vocab = self.gconf.vocab self.sen_end_token = self.vocab.sen_end_token self.sen_end_id = self.vocab:get_word_str(self.sen_end_token).id end function LMSampler:load_dagL(dagL) - self.batch_size = self.gconf.batch_size - self.chunk_size = self.gconf.chunk_size - + nerv.printf("%s loading dagL\n", self.log_pre) self.dagL = dagL + self.dagL:init(self.batch_size) self.dagL_inputs = {} self.dagL_inputs[1] = global_conf.cumat_type(global_conf.batch_size, 1) |