aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lm_sampler.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/lm_sampler.lua')
-rw-r--r--nerv/examples/lmptb/lm_sampler.lua7
1 files changed, 4 insertions, 3 deletions
diff --git a/nerv/examples/lmptb/lm_sampler.lua b/nerv/examples/lmptb/lm_sampler.lua
index 2a4f1c3..d194af9 100644
--- a/nerv/examples/lmptb/lm_sampler.lua
+++ b/nerv/examples/lmptb/lm_sampler.lua
@@ -3,18 +3,19 @@ local LMSampler = nerv.class('nerv.LMSampler')
function LMSampler:__init(global_conf)
self.log_pre = "LMSampler"
self.gconf = global_conf
+ self.batch_size = self.gconf.batch_size
+ self.chunk_size = self.gconf.chunk_size --largest sample sentence length
self.vocab = self.gconf.vocab
self.sen_end_token = self.vocab.sen_end_token
self.sen_end_id = self.vocab:get_word_str(self.sen_end_token).id
end
function LMSampler:load_dagL(dagL)
- self.batch_size = self.gconf.batch_size
- self.chunk_size = self.gconf.chunk_size
-
+
nerv.printf("%s loading dagL\n", self.log_pre)
self.dagL = dagL
+ self.dagL:init(self.batch_size)
self.dagL_inputs = {}
self.dagL_inputs[1] = global_conf.cumat_type(global_conf.batch_size, 1)