aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortxh18 <[email protected]>2016-02-05 21:42:05 +0800
committertxh18 <[email protected]>2016-02-05 21:42:05 +0800
commit3d7a2be2d8ac3083617df2b7194921971f0ac94e (patch)
treebff8357e781ea594c99bfb52694bcadf54c90bb3
parenta1d8c0a2369ea72df77821f7b298903e9470e676 (diff)
..
-rw-r--r--nerv/examples/lmptb/lm_sampler.lua13
1 files changed, 9 insertions, 4 deletions
diff --git a/nerv/examples/lmptb/lm_sampler.lua b/nerv/examples/lmptb/lm_sampler.lua
index d194af9..9d31f17 100644
--- a/nerv/examples/lmptb/lm_sampler.lua
+++ b/nerv/examples/lmptb/lm_sampler.lua
@@ -8,10 +8,12 @@ function LMSampler:__init(global_conf)
self.vocab = self.gconf.vocab
self.sen_end_token = self.vocab.sen_end_token
self.sen_end_id = self.vocab:get_word_str(self.sen_end_token).id
+
+ self.loaded = false
+
end
-function LMSampler:load_dagL(dagL)
-
+function LMSampler:load_dagL(dagL)
nerv.printf("%s loading dagL\n", self.log_pre)
self.dagL = dagL
@@ -40,9 +42,11 @@ function LMSampler:load_dagL(dagL)
self.store[i][1].p = 0
end
self.repo = {}
+
+ self.loaded = true
end
-function LMSampler:sample_to_store(ssout)
+function LMSampler:sample_to_store(ssout) --private
for i = 1, self.batch_size do
local ran = math.random()
local id = 1
@@ -88,8 +92,9 @@ function LMSampler:sample_to_store(ssout)
end
end
---Returns: LMResult
function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf)
+ assert(self.loaded == true)
+
local dagL = self.dagL
local inputs = self.dagL_inputs
local outputs = self.dagL_outputs