diff options
Diffstat (limited to 'nerv/examples/lmptb/lmptb/lmseqreader.lua')
-rw-r--r-- | nerv/examples/lmptb/lmptb/lmseqreader.lua | 163 |
1 files changed, 163 insertions, 0 deletions
-- nerv/examples/lmptb/lmptb/lmseqreader.lua
-- Sequence reader for language-model text: splits a corpus file across
-- `batch_size` parallel streams and hands it out in chunks of `chunk_size`
-- time steps.

require 'lmptb.lmvocab'

local LMReader = nerv.class("nerv.LMSeqReader")

local printf = nerv.printf

--Constructor.
--global_conf: table
--batch_size: int, number of parallel streams
--chunk_size: int, number of time steps filled per get_batch call
--vocab: nerv.LMVocab
function LMReader:__init(global_conf, batch_size, chunk_size, vocab)
    self.gconf = global_conf
    self.fh = nil --file handle to read, nil means currently no file
    self.batch_size = batch_size
    self.chunk_size = chunk_size
    self.log_pre = "[LOG]LMSeqReader:"
    self.vocab = vocab
    self.streams = nil --created by open_file
end

--fn: string
--Open the file and initialize all streams (one empty queue per batch slot).
--Raises an error if a file is already open or if fn cannot be opened.
function LMReader:open_file(fn)
    if (self.fh ~= nil) then
        nerv.error("%s error: in open_file(fn is %s), file handle not nil.", self.log_pre, fn)
    end
    printf("%s opening file %s...\n", self.log_pre, fn)
    print("batch_size:", self.batch_size, "chunk_size", self.chunk_size)

    --io.open returns nil plus a message on failure; a nil handle would
    --otherwise be indistinguishable from "file already exhausted".
    local fh, err = io.open(fn, "r")
    if (fh == nil) then
        nerv.error("%s error: cannot open file %s: %s", self.log_pre, fn, tostring(err))
    end
    self.fh = fh

    self.streams = {}
    for i = 1, self.batch_size, 1 do
        --store holds pending tokens at indices head..tail
        self.streams[i] = {["store"] = {}, ["head"] = 1, ["tail"] = 0}
    end
end

--id: int
--Refresh stream id: if the stream has no pending token, read one line from
--the file and append its tokens. Checks that the line is cntklm-style, i.e.
--begins and ends with the sentence-end token and has none in between.
--When the file is exhausted the handle is closed and set to nil.
function LMReader:refresh_stream(id)
    if (self.streams[id] == nil) then
        nerv.error("stream %d does not exist.", id)
    end
    local st = self.streams[id]
    if (st.store[st.head] ~= nil) then return end --still has pending tokens
    if (self.fh == nil) then return end --no file left to read from

    local list = self.vocab:read_line(self.fh)
    if (list == nil) then --file has ended
        printf("%s file expires, closing.\n", self.log_pre)
        self.fh:close()
        self.fh = nil
        return
    end

    --some sanity check
    local sen_end = self.vocab.sen_end_token
    if (list[1] ~= sen_end or list[#list] ~= sen_end) then --check for cntklm style input
        nerv.error("%s sentence not begin or end with </s> : %s", self.log_pre, table.tostring(list))
    end
    for i = 2, #list - 1, 1 do
        if (list[i] == sen_end) then
            nerv.error("%s Got </s> in the middle of a line(%s) in file", self.log_pre, table.tostring(list))
        end
    end

    for i = 1, #list, 1 do
        st.tail = st.tail + 1
        st.store[st.tail] = list[i]
    end
end

--feeds: a table that will be filled by the reader. The reader creates
--  feeds.inputs_s / feeds.labels_s (chunk_size x batch_size word strings) and
--  writes into feeds.inputs_m, feeds.flags_now and feeds.flagsPack_now, which
--  the caller must supply already allocated.
--  NOTE(review): inputs_m[j][1] / inputs_m[j][2] are indexed 0-based here,
--  presumably nerv matrices -- confirm against the caller.
--Returns: bool, true if at least one real (non-null) token was produced
function LMReader:get_batch(feeds)
    if (type(feeds) ~= "table") then
        nerv.error("feeds is not a table")
    end
    if (self.streams == nil) then
        nerv.error("%s error: get_batch called before open_file", self.log_pre)
    end

    local vocab = self.vocab
    local null_token = vocab.null_token
    local sen_end = vocab.sen_end_token

    local inputs_s = {}
    local labels_s = {}
    feeds["inputs_s"] = inputs_s
    feeds["labels_s"] = labels_s
    for j = 1, self.chunk_size, 1 do
        inputs_s[j] = {}
        labels_s[j] = {}
    end

    local inputs_m = feeds.inputs_m --port 1 : word_id, port 2 : label
    local flags = feeds.flags_now
    local flagsPack = feeds.flagsPack_now

    local got_new = false
    for i = 1, self.batch_size, 1 do
        local st = self.streams[i]
        for j = 1, self.chunk_size, 1 do
            flags[j][i] = 0
            self:refresh_stream(i)

            --input word: the token at the head of the stream, or null
            local cur = st.store[st.head]
            if (cur ~= nil) then
                inputs_s[j][i] = cur
                inputs_m[j][1][i - 1][0] = vocab:get_word_str(cur).id - 1
            else
                inputs_s[j][i] = null_token
                inputs_m[j][1][i - 1][0] = 0
            end

            --label: the token following the input, written as a one-hot row
            inputs_m[j][2][i - 1]:fill(0)
            local nxt = st.store[st.head + 1]
            if (nxt ~= nil) then
                labels_s[j][i] = nxt
                inputs_m[j][2][i - 1][vocab:get_word_str(nxt).id - 1] = 1
            else
                if (inputs_s[j][i] ~= null_token) then
                    nerv.error("reader error : input not null but label is null_token")
                end
                labels_s[j][i] = null_token
            end

            if (inputs_s[j][i] ~= null_token) then
                if (labels_s[j][i] == null_token) then
                    nerv.error("reader error : label is null while input is not null")
                end
                flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_NORM)
                got_new = true
                --consume the input token
                st.store[st.head] = nil
                st.head = st.head + 1
                if (labels_s[j][i] == sen_end) then
                    flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_END)
                    st.store[st.head] = nil --sentence end is passed
                    st.head = st.head + 1
                end
                if (inputs_s[j][i] == sen_end) then
                    flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START)
                end
            end
        end
    end

    --per time step, OR together the flags of all streams
    for j = 1, self.chunk_size, 1 do
        flagsPack[j] = 0
        for i = 1, self.batch_size, 1 do
            flagsPack[j] = bit.bor(flagsPack[j], flags[j][i])
        end
    end

    return got_new
end

--NOTE(review): the disabled scratch test below refers to nerv.LMFeeder and
--calls get_batch() without a feeds table, so it does not match this class.
--[[
do
    local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/some-text"
    --local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt"
    local vocab = nerv.LMVocab()
    vocab:build_file(test_fn)
    local batch_size = 3
    local feeder = nerv.LMFeeder({}, batch_size, vocab)
    feeder:open_file(test_fn)
    while (1) do
        local list = feeder:get_batch()
        if (list == nil) then break end
        for i = 1, batch_size, 1 do
            printf("%s(%d) ", list[i], vocab:get_word_str(list[i]).id)
        end
        printf("\n")
    end
end
]]--