aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lmptb/lmseqreader.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/lmptb/lmseqreader.lua')
-rw-r--r--nerv/examples/lmptb/lmptb/lmseqreader.lua113
1 files changed, 79 insertions, 34 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua
index e0dcd95..40471d5 100644
--- a/nerv/examples/lmptb/lmptb/lmseqreader.lua
+++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua
@@ -1,4 +1,5 @@
require 'lmptb.lmvocab'
+--require 'tnn.init'
local LMReader = nerv.class("nerv.LMSeqReader")
@@ -7,7 +8,7 @@ local printf = nerv.printf
--global_conf: table
--batch_size: int
--vocab: nerv.LMVocab
-function LMReader:__init(global_conf, batch_size, chunk_size, vocab)
+function LMReader:__init(global_conf, batch_size, chunk_size, vocab, r_conf)
self.gconf = global_conf
self.fh = nil --file handle to read, nil means currently no file
self.batch_size = batch_size
@@ -15,6 +16,13 @@ function LMReader:__init(global_conf, batch_size, chunk_size, vocab)
self.log_pre = "[LOG]LMSeqReader:"
self.vocab = vocab
self.streams = nil
+ if r_conf == nil then
+ r_conf = {}
+ end
+ self.se_mode = false --sentence-end mode: once a sentence end is met, the rest of that stream in the current chunk is filled with null tokens
+ if r_conf.se_mode == true then
+ self.se_mode = true
+ end
end
--fn: string
@@ -24,12 +32,21 @@ function LMReader:open_file(fn)
nerv.error("%s error: in open_file(fn is %s), file handle not nil.", self.log_pre, fn)
end
printf("%s opening file %s...\n", self.log_pre, fn)
- print("batch_size:", self.batch_size, "chunk_size", self.chunk_size)
+ print(self.log_pre, "batch_size:", self.batch_size, "chunk_size", self.chunk_size)
+ print(self.log_pre, "se_mode:", self.se_mode)
self.fh = io.open(fn, "r")
self.streams = {}
for i = 1, self.batch_size, 1 do
self.streams[i] = {["store"] = {}, ["head"] = 1, ["tail"] = 0}
end
+ self.stat = {} --stat collected during file reading
+ self.stat.al_sen_start = true --check whether it's always sentence_start at the beginning of a minibatch
+ self.bak_inputs_m = {} --backup MMatrix for temporary storage (later copied to the TNN CuMatrix)
+ for j = 1, self.chunk_size, 1 do
+ self.bak_inputs_m[j] = {}
+ self.bak_inputs_m[j][1] = self.gconf.mmat_type(self.batch_size, 1)
+ --self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, self.vocab:size()) --since MMatrix does not yet have fill, this m[j][2] is not used
+ end
end
--id: int
@@ -78,7 +95,7 @@ function LMReader:get_batch(feeds)
local labels_s = feeds.labels_s
for i = 1, self.chunk_size, 1 do
inputs_s[i] = {}
- labels_s[i] = {}
+ labels_s[i] = {}
end
local inputs_m = feeds.inputs_m --port 1 : word_id, port 2 : label
@@ -86,45 +103,62 @@ function LMReader:get_batch(feeds)
local flagsPack = feeds.flagsPack_now
local got_new = false
+ for j = 1, self.chunk_size, 1 do
+ inputs_m[j][2]:fill(0)
+ end
for i = 1, self.batch_size, 1 do
local st = self.streams[i]
+ local end_stream = false --used for se_mode, indicating that this stream is ended
for j = 1, self.chunk_size, 1 do
flags[j][i] = 0
- self:refresh_stream(i)
- if (st.store[st.head] ~= nil) then
- inputs_s[j][i] = st.store[st.head]
- inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
- else
+ if end_stream == true then
+ if self.se_mode == false then
+ nerv.error("lmseqreader:getbatch: error, end_stream is true while se_mode is false")
+ end
inputs_s[j][i] = self.vocab.null_token
- inputs_m[j][1][i - 1][0] = 0
- end
- inputs_m[j][2][i - 1]:fill(0)
- if (st.store[st.head + 1] ~= nil) then
- labels_s[j][i] = st.store[st.head + 1]
- inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1
+ self.bak_inputs_m[j][1][i - 1][0] = 0
+ labels_s[j][i] = self.vocab.null_token
else
- if (inputs_s[j][i] ~= self.vocab.null_token) then
- nerv.error("reader error : input not null but label is null_token")
+ self:refresh_stream(i)
+ if st.store[st.head] ~= nil then
+ inputs_s[j][i] = st.store[st.head]
+ --inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
+ self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
+ else
+ inputs_s[j][i] = self.vocab.null_token
+ --inputs_m[j][1][i - 1][0] = 0
+ self.bak_inputs_m[j][1][i - 1][0] = 0
end
- labels_s[j][i] = self.vocab.null_token
- end
- if (inputs_s[j][i] ~= self.vocab.null_token) then
- if (labels_s[j][i] == self.vocab.null_token) then
- nerv.error("reader error : label is null while input is not null")
+ if st.store[st.head + 1] ~= nil then
+ labels_s[j][i] = st.store[st.head + 1]
+ inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1
+ else
+ if (inputs_s[j][i] ~= self.vocab.null_token) then
+ nerv.error("reader error : input not null but label is null_token")
+ end
+ labels_s[j][i] = self.vocab.null_token
end
- flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_NORM)
- got_new = true
- st.store[st.head] = nil
- st.head = st.head + 1
- if (labels_s[j][i] == self.vocab.sen_end_token) then
- flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_END)
- st.store[st.head] = nil --sentence end is passed
+ if inputs_s[j][i] ~= self.vocab.null_token then
+ if labels_s[j][i] == self.vocab.null_token then
+ nerv.error("reader error : label is null while input is not null")
+ end
+ flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_NORM) --has both input and label
+ got_new = true
+ st.store[st.head] = nil
st.head = st.head + 1
- end
- if (inputs_s[j][i] == self.vocab.sen_end_token) then
- flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START)
- end
- end
+ if labels_s[j][i] == self.vocab.sen_end_token then
+ flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_END)
+ st.store[st.head] = nil --sentence end is passed
+ st.head = st.head + 1
+ if self.se_mode == true then
+ end_stream = true --meet sentence end, this stream ends now
+ end
+ end
+ if inputs_s[j][i] == self.vocab.sen_end_token then
+ flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START)
+ end
+ end
+ end
end
end
@@ -133,9 +167,20 @@ function LMReader:get_batch(feeds)
for i = 1, self.batch_size, 1 do
flagsPack[j] = bit.bor(flagsPack[j], flags[j][i])
end
+ inputs_m[j][1]:copy_fromh(self.bak_inputs_m[j][1])
end
- if (got_new == false) then
+ --check for self.stat.al_sen_start
+ for i = 1, self.batch_size do
+ if inputs_s[1][i] ~= self.vocab.sen_end_token and inputs_s[1][i] ~= self.vocab.null_token then
+ self.stat.al_sen_start = false
+ end
+ end
+
+ if got_new == false then
+ nerv.info("lmseqreader file ends, printing stats...")
+ print("al_sen_start:", self.stat.al_sen_start)
+
return false
else
return true