aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortxh18 <cloudygooseg@gmail.com>2015-10-28 17:53:46 +0800
committertxh18 <cloudygooseg@gmail.com>2015-10-28 17:53:46 +0800
commit2f410245517a62a37c3229174117ae319a042bff (patch)
treeec9e1331c627699d93173e027589bffbf631e426
parente0fa1a48cb9f91bfcfc60b732b6f137a7a2071ba (diff)
TODO:implement lmseqreader
-rw-r--r--nerv/examples/lmptb/lmptb/lmseqreader.lua96
-rw-r--r--nerv/examples/lmptb/m-tests/lmseqreader_test.lua19
-rw-r--r--nerv/examples/lmptb/m-tests/some-text10
3 files changed, 125 insertions, 0 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua
new file mode 100644
index 0000000..0952842
--- /dev/null
+++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua
@@ -0,0 +1,96 @@
+require 'lmptb.lmvocab'
+
+local LMReader = nerv.class("nerv.LMSeqReader")
+
+local printf = nerv.printf
+
+--global_conf: table
+--batch_size: int
+--vocab: nerv.LMVocab
+function LMReader:__init(global_conf, batch_size, vocab)
+ self.gconf = global_conf
+ self.fh = nil --file handle to read, nil means currently no file
+ self.batch_size = batch_size
+ self.log_pre = "[LOG]LMFeeder:"
+ self.vocab = vocab
+ self.streams = nil
+end
+
+--fn: string
+--Initialize all streams
+function LMReader:open_file(fn)
+ if (self.fh ~= nil) then
+ nerv.error("%s error: in open_file, file handle not nil.")
+ end
+ printf("%s opening file %s...\n", self.log_pre, fn)
+ self.fh = io.open(fn, "r")
+ self.streams = {}
+ for i = 1, self.batch_size, 1 do
+ self.streams[i] = {["store"] = {self.vocab.sen_end_token}, ["head"] = 1, ["tail"] = 1}
+ end
+end
+
+--id: int
+--Refresh stream id, read a line from file
+function LMReader:refresh_stream(id)
+ if (self.streams[id] == nil) then
+ nerv.error("stream %d does not exit.", id)
+ end
+ local st = self.streams[id]
+ if (st.store[st.head] ~= nil) then return end
+ if (self.fh == nil) then return end
+ local list = self.vocab:read_line(self.fh)
+ if (list == nil) then --file has end
+ printf("%s file expires, closing.\n", self.log_pre)
+ self.fh:close()
+ self.fh = nil
+ return
+ end
+ for i = 1, #list, 1 do
+ st.tail = st.tail + 1
+ st.store[st.tail] = list[i]
+ end
+end
+
+--Returns: nil/table
+--If gets something, return a list of string, vocab.null_token indicates end of string
+function LMReader:get_batch()
+ local got_new = false
+ local list = {}
+ for i = 1, self.batch_size, 1 do
+ self:refresh_stream(i)
+ local st = self.streams[i]
+ list[i] = st.store[st.head]
+ if (list[i] == nil) then list[i] = self.vocab.null_token end
+ if (list[i] ~= nil and list[i] ~= self.vocab.null_token)then
+ got_new = true
+ st.store[st.head] = nil
+ st.head = st.head + 1
+ end
+ end
+ if (got_new == false) then
+ return nil
+ else
+ return list
+ end
+end
+
+--[[
+do
+ local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/some-text"
+ --local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt"
+ local vocab = nerv.LMVocab()
+ vocab:build_file(test_fn)
+ local batch_size = 3
+ local feeder = nerv.LMFeeder({}, batch_size, vocab)
+ feeder:open_file(test_fn)
+ while (1) do
+ local list = feeder:get_batch()
+ if (list == nil) then break end
+ for i = 1, batch_size, 1 do
+ printf("%s(%d) ", list[i], vocab:get_word_str(list[i]).id)
+ end
+ printf("\n")
+ end
+end
+]]--
diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
new file mode 100644
index 0000000..b90e651
--- /dev/null
+++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
@@ -0,0 +1,19 @@
+require 'lmptb.lmseqreader'
+
+local printf = nerv.printf
+
+local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text"
+--local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt"
+local vocab = nerv.LMVocab()
+vocab:build_file(test_fn)
+local batch_size = 3
+local reader = nerv.LMSeqReader({}, batch_size, vocab)
+reader:open_file(test_fn)
+while (1) do
+ local list = reader:get_batch()
+ if (list == nil) then break end
+ for i = 1, batch_size, 1 do
+ printf("%s(%d) ", list[i], vocab:get_word_str(list[i]).id)
+ end
+ printf("\n")
+end
diff --git a/nerv/examples/lmptb/m-tests/some-text b/nerv/examples/lmptb/m-tests/some-text
new file mode 100644
index 0000000..e905b60
--- /dev/null
+++ b/nerv/examples/lmptb/m-tests/some-text
@@ -0,0 +1,10 @@
+aa bb cc aa bb cc aa bb cc aa bb cc aa bb cc aa
+aa bb cc aa bb cc aa bb cc aa
+aa bbcc aa bb cc aa bb cc aa
+aa bb cc aa
+aa bb cc aa
+aa bb cc aa
+aa
+aa bb cc aa
+aa bb cc aa
+aa bb cc aa bb cc aa