aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortxh18 <[email protected]>2016-02-06 21:33:41 +0800
committertxh18 <[email protected]>2016-02-06 21:33:41 +0800
commit152c89dc8af3d5d7ace79f65616f192a71b96b0d (patch)
tree9c238d3cddbbcf9093ff35a4ad232ffc89a75bbb
parent8782772c4a68d45a403e610efa75a1c8f401c7e7 (diff)
bug fixes in lm_sampler
-rw-r--r--nerv/examples/lmptb/lm_sampler.lua15
1 files changed, 7 insertions, 8 deletions
diff --git a/nerv/examples/lmptb/lm_sampler.lua b/nerv/examples/lmptb/lm_sampler.lua
index 9d31f17..c9adf85 100644
--- a/nerv/examples/lmptb/lm_sampler.lua
+++ b/nerv/examples/lmptb/lm_sampler.lua
@@ -10,7 +10,6 @@ function LMSampler:__init(global_conf)
self.sen_end_id = self.vocab:get_word_str(self.sen_end_token).id
self.loaded = false
-
end
function LMSampler:load_dagL(dagL)
@@ -20,18 +19,18 @@ function LMSampler:load_dagL(dagL)
self.dagL:init(self.batch_size)
self.dagL_inputs = {}
- self.dagL_inputs[1] = global_conf.cumat_type(global_conf.batch_size, 1)
+ self.dagL_inputs[1] = self.gconf.cumat_type(self.gconf.batch_size, 1)
self.dagL_inputs[1]:fill(self.sen_end_id - 1)
- self.dagL_inputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size)
+ self.dagL_inputs[2] = self.gconf.cumat_type(self.gconf.batch_size, self.gconf.hidden_size)
self.dagL_inputs[2]:fill(0)
self.dagL_outputs = {}
- self.dagL_outputs[1] = global_conf.cumat_type(global_conf.batch_size, global_conf.vocab:size())
- self.dagL_outputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size)
+ self.dagL_outputs[1] = self.gconf.cumat_type(self.gconf.batch_size, self.gconf.vocab:size())
+ self.dagL_outputs[2] = self.gconf.cumat_type(self.gconf.batch_size, self.gconf.hidden_size)
- self.smout_d = global_conf.cumat_type(self.batch_size, self.vocab:size())
- self.ssout_d = global_conf.cumat_type(self.batch_size, self.vocab:size())
- self.ssout_h = global_conf.mmat_type(self.batch_size, self.vocab:size())
+ self.smout_d = self.gconf.cumat_type(self.batch_size, self.vocab:size())
+ self.ssout_d = self.gconf.cumat_type(self.batch_size, self.vocab:size())
+ self.ssout_h = self.gconf.mmat_type(self.batch_size, self.vocab:size())
self.store = {}
for i = 1, self.batch_size do