diff options
-rw-r--r-- | nerv/examples/lmptb/lmptb/lmseqreader.lua | 15 | ||||
-rw-r--r-- | nerv/examples/lmptb/m-tests/lmseqreader_test.lua | 4 |
2 files changed, 15 insertions, 4 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua index ead8d4c..40471d5 100644 --- a/nerv/examples/lmptb/lmptb/lmseqreader.lua +++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua @@ -39,7 +39,8 @@ function LMReader:open_file(fn) for i = 1, self.batch_size, 1 do self.streams[i] = {["store"] = {}, ["head"] = 1, ["tail"] = 0} end - + self.stat = {} --stat collected during file reading + self.stat.al_sen_start = true --check whether it's always sentence_start at the begining of a minibatch self.bak_inputs_m = {} --backup MMatrix for temporary storey(then copy to TNN CuMatrix) for j = 1, self.chunk_size, 1 do self.bak_inputs_m[j] = {} @@ -169,7 +170,17 @@ function LMReader:get_batch(feeds) inputs_m[j][1]:copy_fromh(self.bak_inputs_m[j][1]) end - if (got_new == false) then + --check for self.al_sen_start + for i = 1, self.batch_size do + if inputs_s[1][i] ~= self.vocab.sen_end_token and inputs_s[1][i] ~= self.vocab.null_token then + self.stat.al_sen_start = false + end + end + + if got_new == false then + nerv.info("lmseqreader file ends, printing stats...") + print("al_sen_start:", self.stat.al_sen_start) + return false else return true diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua index b98ff95..9127559 100644 --- a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua +++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua @@ -7,7 +7,7 @@ local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-te --local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt" local vocab = nerv.LMVocab() vocab:build_file(test_fn) -local chunk_size = 5 +local chunk_size = 20 local batch_size = 3 local global_conf = { lrate = 1, wcost = 1e-6, momentum = 0, @@ -30,7 +30,7 @@ local global_conf = { vocab = vocab } -local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, vocab) +local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, vocab, {["se_mode"] = true}) reader:open_file(test_fn) local feeds = {} feeds.flags_now = {} |