diff options
Diffstat (limited to 'nerv/examples/lmptb/lmptb/lmseqreader.lua')
-rw-r--r-- | nerv/examples/lmptb/lmptb/lmseqreader.lua | 15 |
1 files changed, 13 insertions, 2 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua index ead8d4c..40471d5 100644 --- a/nerv/examples/lmptb/lmptb/lmseqreader.lua +++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua @@ -39,7 +39,8 @@ function LMReader:open_file(fn) for i = 1, self.batch_size, 1 do self.streams[i] = {["store"] = {}, ["head"] = 1, ["tail"] = 0} end - + self.stat = {} --stat collected during file reading + self.stat.al_sen_start = true --check whether it's always sentence_start at the begining of a minibatch self.bak_inputs_m = {} --backup MMatrix for temporary storey(then copy to TNN CuMatrix) for j = 1, self.chunk_size, 1 do self.bak_inputs_m[j] = {} @@ -169,7 +170,17 @@ function LMReader:get_batch(feeds) inputs_m[j][1]:copy_fromh(self.bak_inputs_m[j][1]) end - if (got_new == false) then + --check for self.al_sen_start + for i = 1, self.batch_size do + if inputs_s[1][i] ~= self.vocab.sen_end_token and inputs_s[1][i] ~= self.vocab.null_token then + self.stat.al_sen_start = false + end + end + + if got_new == false then + nerv.info("lmseqreader file ends, printing stats...") + print("al_sen_start:", self.stat.al_sen_start) + return false else return true |