aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lmptb/lmseqreader.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/lmptb/lmseqreader.lua')
-rw-r--r--nerv/examples/lmptb/lmptb/lmseqreader.lua163
1 files changed, 163 insertions, 0 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua
new file mode 100644
index 0000000..e0dcd95
--- /dev/null
+++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua
@@ -0,0 +1,163 @@
+require 'lmptb.lmvocab'
+
+local LMReader = nerv.class("nerv.LMSeqReader")
+
+local printf = nerv.printf
+
+--Construct a sequence reader for LM training data.
+--global_conf: table, global configuration
+--batch_size: int, number of parallel word streams
+--chunk_size: int, number of time steps per batch
+--vocab: nerv.LMVocab
+function LMReader:__init(global_conf, batch_size, chunk_size, vocab)
+    self.gconf = global_conf
+    self.vocab = vocab
+    self.batch_size = batch_size
+    self.chunk_size = chunk_size
+    self.fh = nil --file handle to read, nil means currently no file
+    self.streams = nil --per-stream word buffers, created by open_file()
+    self.log_pre = "[LOG]LMSeqReader:"
+end
+
+--fn: string, path of the text file to read
+--Open the data file and initialize all stream buffers.
+--Raises an error if a file is already open or if fn cannot be opened.
+function LMReader:open_file(fn)
+    if (self.fh ~= nil) then
+        nerv.error("%s error: in open_file(fn is %s), file handle not nil.", self.log_pre, fn)
+    end
+    printf("%s opening file %s...\n", self.log_pre, fn)
+    print("batch_size:", self.batch_size, "chunk_size", self.chunk_size)
+    self.fh = io.open(fn, "r")
+    if (self.fh == nil) then --io.open returns nil on failure; fail loudly instead of acting like an empty file
+        nerv.error("%s error: can not open file %s for reading.", self.log_pre, fn)
+    end
+    self.streams = {}
+    for i = 1, self.batch_size, 1 do
+        self.streams[i] = {["store"] = {}, ["head"] = 1, ["tail"] = 0}
+    end
+end
+
+--id: int, index of the stream to refresh
+--Refresh stream id: if its buffer is empty and the file is still open,
+--read one line and append its words to the buffer. Verifies the line is
+--cntklm-style: it must begin and end with the sentence-end token and
+--contain none in between.
+function LMReader:refresh_stream(id)
+    if (self.streams[id] == nil) then
+        nerv.error("stream %d does not exist.", id) --fixed typo: "exit" -> "exist"
+    end
+    local st = self.streams[id]
+    if (st.store[st.head] ~= nil) then return end --buffer still has words, nothing to do
+    if (self.fh == nil) then return end --file already exhausted
+    local list = self.vocab:read_line(self.fh)
+    if (list == nil) then --file has end
+        printf("%s file expires, closing.\n", self.log_pre)
+        self.fh:close()
+        self.fh = nil
+        return
+    end
+
+    --sanity check: enforce cntklm-style "</s> w1 w2 ... </s>" lines
+    if (list[1] ~= self.vocab.sen_end_token or list[#list] ~= self.vocab.sen_end_token) then --check for cntklm style input
+        nerv.error("%s sentence not begin or end with </s> : %s", self.log_pre, table.tostring(list));
+    end
+    for i = 2, #list - 1, 1 do
+        if (list[i] == self.vocab.sen_end_token) then
+            nerv.error("%s Got </s> in the middle of a line(%s) in file", self.log_pre, table.tostring(list))
+        end
+    end
+
+    --append the words to the tail of the stream buffer
+    for i = 1, #list, 1 do
+        st.tail = st.tail + 1
+        st.store[st.tail] = list[i]
+    end
+end
+
+--feeds: a table that will be filled by the reader with the next batch.
+--  Expects feeds.inputs_m (port 1: word-id matrix, port 2: one-hot label
+--  matrix) plus feeds.flags_now and feeds.flagsPack_now to be pre-allocated
+--  by the caller -- NOTE(review): presumably by the TNN setup, confirm.
+--Returns: bool, true iff at least one stream produced a new word
+function LMReader:get_batch(feeds)
+    if (feeds == nil or type(feeds) ~= "table") then
+        nerv.error("feeds is not a table")
+    end
+
+    --string views of the batch (chunk_size x batch_size), rebuilt every call
+    feeds["inputs_s"] = {}
+    feeds["labels_s"] = {}
+    local inputs_s = feeds.inputs_s
+    local labels_s = feeds.labels_s
+    for i = 1, self.chunk_size, 1 do
+        inputs_s[i] = {}
+        labels_s[i] = {}
+    end
+
+    local inputs_m = feeds.inputs_m --port 1 : word_id, port 2 : label
+    local flags = feeds.flags_now
+    local flagsPack = feeds.flagsPack_now
+
+    local got_new = false
+    for i = 1, self.batch_size, 1 do
+        local st = self.streams[i]
+        for j = 1, self.chunk_size, 1 do
+            flags[j][i] = 0
+            self:refresh_stream(i) --pull a new line into the stream if its buffer is empty
+            --input is the word at the head of the stream, or null_token when exhausted
+            if (st.store[st.head] ~= nil) then
+                inputs_s[j][i] = st.store[st.head]
+                inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1 --matrix indices are 0-based, vocab ids 1-based
+            else
+                inputs_s[j][i] = self.vocab.null_token
+                inputs_m[j][1][i - 1][0] = 0
+            end
+            inputs_m[j][2][i - 1]:fill(0) --clear this row before setting the one-hot label
+            --label is the next word in the stream (next-word prediction)
+            if (st.store[st.head + 1] ~= nil) then
+                labels_s[j][i] = st.store[st.head + 1]
+                inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1
+            else
+                if (inputs_s[j][i] ~= self.vocab.null_token) then
+                    nerv.error("reader error : input not null but label is null_token")
+                end
+                labels_s[j][i] = self.vocab.null_token
+            end
+            if (inputs_s[j][i] ~= self.vocab.null_token) then
+                if (labels_s[j][i] == self.vocab.null_token) then
+                    nerv.error("reader error : label is null while input is not null")
+                end
+                flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_NORM) --real data at this step
+                got_new = true
+                st.store[st.head] = nil --consume the word just emitted
+                st.head = st.head + 1
+                if (labels_s[j][i] == self.vocab.sen_end_token) then
+                    flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_END)
+                    st.store[st.head] = nil --sentence end is passed
+                    st.head = st.head + 1
+                end
+                if (inputs_s[j][i] == self.vocab.sen_end_token) then
+                    flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START)
+                end
+            end
+        end
+    end
+
+    --OR the per-stream flags into one packed flag value per time step
+    for j = 1, self.chunk_size, 1 do
+        flagsPack[j] = 0
+        for i = 1, self.batch_size, 1 do
+            flagsPack[j] = bit.bor(flagsPack[j], flags[j][i])
+        end
+    end
+
+    if (got_new == false) then
+        return false
+    else
+        return true
+    end
+end
+
+--[[
+do
+ local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/some-text"
+ --local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt"
+ local vocab = nerv.LMVocab()
+ vocab:build_file(test_fn)
+ local batch_size = 3
+ local feeder = nerv.LMFeeder({}, batch_size, vocab)
+ feeder:open_file(test_fn)
+ while (1) do
+ local list = feeder:get_batch()
+ if (list == nil) then break end
+ for i = 1, batch_size, 1 do
+ printf("%s(%d) ", list[i], vocab:get_word_str(list[i]).id)
+ end
+ printf("\n")
+ end
+end
+]]--