aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/examples/lmptb/lmptb/lmseqreader.lua15
-rw-r--r--nerv/examples/lmptb/m-tests/lmseqreader_test.lua4
2 files changed, 15 insertions, 4 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua
index ead8d4c..40471d5 100644
--- a/nerv/examples/lmptb/lmptb/lmseqreader.lua
+++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua
@@ -39,7 +39,8 @@ function LMReader:open_file(fn)
for i = 1, self.batch_size, 1 do
self.streams[i] = {["store"] = {}, ["head"] = 1, ["tail"] = 0}
end
-
+ self.stat = {} --stat collected during file reading
+ self.stat.al_sen_start = true --check whether it's always sentence_start at the begining of a minibatch
self.bak_inputs_m = {} --backup MMatrix for temporary storey(then copy to TNN CuMatrix)
for j = 1, self.chunk_size, 1 do
self.bak_inputs_m[j] = {}
@@ -169,7 +170,17 @@ function LMReader:get_batch(feeds)
inputs_m[j][1]:copy_fromh(self.bak_inputs_m[j][1])
end
- if (got_new == false) then
+ --check for self.al_sen_start
+ for i = 1, self.batch_size do
+ if inputs_s[1][i] ~= self.vocab.sen_end_token and inputs_s[1][i] ~= self.vocab.null_token then
+ self.stat.al_sen_start = false
+ end
+ end
+
+ if got_new == false then
+ nerv.info("lmseqreader file ends, printing stats...")
+ print("al_sen_start:", self.stat.al_sen_start)
+
return false
else
return true
diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
index b98ff95..9127559 100644
--- a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
+++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
@@ -7,7 +7,7 @@ local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-te
--local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt"
local vocab = nerv.LMVocab()
vocab:build_file(test_fn)
-local chunk_size = 5
+local chunk_size = 20
local batch_size = 3
local global_conf = {
lrate = 1, wcost = 1e-6, momentum = 0,
@@ -30,7 +30,7 @@ local global_conf = {
vocab = vocab
}
-local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, vocab)
+local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, vocab, {["se_mode"] = true})
reader:open_file(test_fn)
local feeds = {}
feeds.flags_now = {}