From c73e332126f2c89958b68b47f18fdb5ac0276bde Mon Sep 17 00:00:00 2001 From: txh18 Date: Thu, 3 Dec 2015 13:45:14 +0800 Subject: added al_sen_start stat for lmseqreader --- nerv/examples/lmptb/lmptb/lmseqreader.lua | 15 +++++++++++++-- nerv/examples/lmptb/m-tests/lmseqreader_test.lua | 4 ++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua index ead8d4c..40471d5 100644 --- a/nerv/examples/lmptb/lmptb/lmseqreader.lua +++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua @@ -39,7 +39,8 @@ function LMReader:open_file(fn) for i = 1, self.batch_size, 1 do self.streams[i] = {["store"] = {}, ["head"] = 1, ["tail"] = 0} end - + self.stat = {} --stat collected during file reading + self.stat.al_sen_start = true --check whether it's always sentence_start at the beginning of a minibatch self.bak_inputs_m = {} --backup MMatrix for temporary storey(then copy to TNN CuMatrix) for j = 1, self.chunk_size, 1 do self.bak_inputs_m[j] = {} @@ -169,7 +170,17 @@ function LMReader:get_batch(feeds) inputs_m[j][1]:copy_fromh(self.bak_inputs_m[j][1]) end - if (got_new == false) then + --check for self.stat.al_sen_start + for i = 1, self.batch_size do + if inputs_s[1][i] ~= self.vocab.sen_end_token and inputs_s[1][i] ~= self.vocab.null_token then + self.stat.al_sen_start = false + end + end + + if got_new == false then + nerv.info("lmseqreader file ends, printing stats...") + print("al_sen_start:", self.stat.al_sen_start) + return false else return true diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua index b98ff95..9127559 100644 --- a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua +++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua @@ -7,7 +7,7 @@ local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-te --local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt" local 
vocab = nerv.LMVocab() vocab:build_file(test_fn) -local chunk_size = 5 +local chunk_size = 20 local batch_size = 3 local global_conf = { lrate = 1, wcost = 1e-6, momentum = 0, @@ -30,7 +30,7 @@ local global_conf = { vocab = vocab } -local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, vocab) +local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, vocab, {["se_mode"] = true}) reader:open_file(test_fn) local feeds = {} feeds.flags_now = {} -- cgit v1.2.3