summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- nerv/examples/lmptb/lmptb/lmseqreader.lua 61
-rw-r--r-- nerv/examples/lmptb/m-tests/lmseqreader_test.lua 5
-rw-r--r-- nerv/examples/lmptb/m-tests/some-text 20
3 files changed, 54 insertions, 32 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua
index edc3ff4..26dc3be 100644
--- a/nerv/examples/lmptb/lmptb/lmseqreader.lua
+++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua
@@ -7,10 +7,12 @@ local printf = nerv.printf
--global_conf: table
--batch_size: int
+--seq_size: int
--vocab: nerv.LMVocab
-function LMReader:__init(global_conf, batch_size, vocab)
+function LMReader:__init(global_conf, batch_size, seq_size, vocab)
self.gconf = global_conf
self.fh = nil --file handle to read, nil means currently no file
self.batch_size = batch_size
+ self.seq_size = seq_size
self.log_pre = "[LOG]LMFeeder:"
self.vocab = vocab
self.streams = nil
@@ -26,12 +27,12 @@ function LMReader:open_file(fn)
self.fh = io.open(fn, "r")
self.streams = {}
for i = 1, self.batch_size, 1 do
- self.streams[i] = {["store"] = {self.vocab.sen_end_token}, ["head"] = 1, ["tail"] = 1}
+ self.streams[i] = {["store"] = {}, ["head"] = 1, ["tail"] = 0}
end
end
--id: int
---Refresh stream id, read a line from file
+--Refresh stream id, read a line from file, will check whether this line is cntklm-style
function LMReader:refresh_stream(id)
if (self.streams[id] == nil) then
nerv.error("stream %d does not exit.", id)
@@ -40,41 +41,61 @@ function LMReader:refresh_stream(id)
if (st.store[st.head] ~= nil) then return end
if (self.fh == nil) then return end
local list = self.vocab:read_line(self.fh)
- if (list[1] ~= self.vocab.sen_end_token or list[#list] ~= self.vocab.sen_end_token) then --check for cntk style input
- nerv.error("sentence not begin or end with </s> : %s", table.tostring(list));
- end
if (list == nil) then --file has end
printf("%s file expires, closing.\n", self.log_pre)
self.fh:close()
self.fh = nil
return
end
+
+ --some sanity check
+ if (list[1] ~= self.vocab.sen_end_token or list[#list] ~= self.vocab.sen_end_token) then --check for cntklm style input
+ nerv.error("%s sentence not begin or end with </s> : %s", self.log_pre, table.tostring(list));
+ end
+ for i = 2, #list - 1, 1 do
+ if (list[i] == self.vocab.sen_end_token) then
+ nerv.error("%s Got </s> in the middle of a line(%s) in file", self.log_pre, table.tostring(list))
+ end
+ end
+
for i = 1, #list, 1 do
st.tail = st.tail + 1
st.store[st.tail] = list[i]
end
end
---Returns: nil/table
---If gets something, return a list of string, vocab.null_token indicates end of string
-function LMReader:get_batch()
+--input: table indexed [1..seq_size][1..batch_size], filled in place with the current token of each stream
+--label: table of the same shape, filled in place with the token that follows each input token
+--Returns: boolean, true if at least one non-null input token was written
+--Stream i fills column i; positions with no token left get vocab.null_token
+function LMReader:get_batch(input, label)
local got_new = false
- local list = {}
for i = 1, self.batch_size, 1 do
- self:refresh_stream(i)
local st = self.streams[i]
- list[i] = st.store[st.head]
- if (list[i] == nil) then list[i] = self.vocab.null_token end
- if (list[i] ~= nil and list[i] ~= self.vocab.null_token)then
- got_new = true
- st.store[st.head] = nil
- st.head = st.head + 1
- end
+ for j = 1, self.seq_size, 1 do
+ self:refresh_stream(i)
+ if (st.store[st.head] ~= nil) then
+ input[j][i] = st.store[st.head]
+ else
+ input[j][i] = self.vocab.null_token
+ end
+ if (st.store[st.head + 1] ~= nil) then
+ label[j][i] = st.store[st.head + 1]
+ else
+ label[j][i] = self.vocab.null_token
+ end
+ if (input[j][i] ~= self.vocab.null_token) then
+ got_new = true
+ st.store[st.head] = nil
+ st.head = st.head + 1
+ --every line is checked to begin and end with </s>, so drop the closing one once it has been used as a label
+ if (label[j][i] == self.vocab.sen_end_token) then
+ st.store[st.head] = nil --sentence end is passed
+ st.head = st.head + 1
+ end
+ end
+ end
end
- if (got_new == false) then
- return nil
- else
- return list
- end
+ return got_new
end
diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
index b90e651..bdea740 100644
--- a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
+++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
@@ -6,8 +6,9 @@ local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-te
--local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt"
local vocab = nerv.LMVocab()
vocab:build_file(test_fn)
-local batch_size = 3
-local reader = nerv.LMSeqReader({}, batch_size, vocab)
+local batch_size = 5
+local seq_size = 3
+local reader = nerv.LMSeqReader({}, batch_size, seq_size, vocab)
reader:open_file(test_fn)
while (1) do
local list = reader:get_batch()
diff --git a/nerv/examples/lmptb/m-tests/some-text b/nerv/examples/lmptb/m-tests/some-text
index e905b60..cdfbd2c 100644
--- a/nerv/examples/lmptb/m-tests/some-text
+++ b/nerv/examples/lmptb/m-tests/some-text
@@ -1,10 +1,10 @@
-aa bb cc aa bb cc aa bb cc aa bb cc aa bb cc aa
-aa bb cc aa bb cc aa bb cc aa
-aa bbcc aa bb cc aa bb cc aa
-aa bb cc aa
-aa bb cc aa
-aa bb cc aa
-aa
-aa bb cc aa
-aa bb cc aa
-aa bb cc aa bb cc aa
+</s> aa bb cc aa bb cc aa bb cc aa bb cc aa bb cc aa </s>
+</s> aa bb cc aa bb cc aa bb cc aa </s>
+</s> aa bb cc aa bb cc aa bb cc aa </s>
+</s> aa bb cc aa </s>
+</s> aa bb cc aa </s>
+</s> aa bb cc aa </s>
+</s> aa </s>
+</s> aa bb cc aa </s>
+</s> aa bb cc aa </s>
+</s> aa bb cc aa bb cc aa </s>