From 2f410245517a62a37c3229174117ae319a042bff Mon Sep 17 00:00:00 2001 From: txh18 Date: Wed, 28 Oct 2015 17:53:46 +0800 Subject: TODO:implement lmseqreader --- nerv/examples/lmptb/lmptb/lmseqreader.lua | 96 ++++++++++++++++++++++++ nerv/examples/lmptb/m-tests/lmseqreader_test.lua | 19 +++++ nerv/examples/lmptb/m-tests/some-text | 10 +++ 3 files changed, 125 insertions(+) create mode 100644 nerv/examples/lmptb/lmptb/lmseqreader.lua create mode 100644 nerv/examples/lmptb/m-tests/lmseqreader_test.lua create mode 100644 nerv/examples/lmptb/m-tests/some-text diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua new file mode 100644 index 0000000..0952842 --- /dev/null +++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua @@ -0,0 +1,96 @@ +require 'lmptb.lmvocab' + +local LMReader = nerv.class("nerv.LMSeqReader") + +local printf = nerv.printf + +--global_conf: table +--batch_size: int +--vocab: nerv.LMVocab +function LMReader:__init(global_conf, batch_size, vocab) + self.gconf = global_conf + self.fh = nil --file handle to read, nil means currently no file + self.batch_size = batch_size + self.log_pre = "[LOG]LMFeeder:" + self.vocab = vocab + self.streams = nil +end + +--fn: string +--Initialize all streams +function LMReader:open_file(fn) + if (self.fh ~= nil) then + nerv.error("%s error: in open_file, file handle not nil.") + end + printf("%s opening file %s...\n", self.log_pre, fn) + self.fh = io.open(fn, "r") + self.streams = {} + for i = 1, self.batch_size, 1 do + self.streams[i] = {["store"] = {self.vocab.sen_end_token}, ["head"] = 1, ["tail"] = 1} + end +end + +--id: int +--Refresh stream id, read a line from file +function LMReader:refresh_stream(id) + if (self.streams[id] == nil) then + nerv.error("stream %d does not exit.", id) + end + local st = self.streams[id] + if (st.store[st.head] ~= nil) then return end + if (self.fh == nil) then return end + local list = self.vocab:read_line(self.fh) + if (list == nil) then --file has end + printf("%s file expires, closing.\n", self.log_pre) + self.fh:close() + self.fh = nil + return + end + for i = 1, #list, 1 do + st.tail = st.tail + 1 + st.store[st.tail] = list[i] + end +end + +--Returns: nil/table +--If gets something, return a list of string, vocab.null_token indicates end of string +function LMReader:get_batch() + local got_new = false + local list = {} + for i = 1, self.batch_size, 1 do + self:refresh_stream(i) + local st = self.streams[i] + list[i] = st.store[st.head] + if (list[i] == nil) then list[i] = self.vocab.null_token end + if (list[i] ~= nil and list[i] ~= self.vocab.null_token)then + got_new = true + st.store[st.head] = nil + st.head = st.head + 1 + end + end + if (got_new == false) then + return nil + else + return list + end +end + +--[[ +do + local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/some-text" + --local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt" + local vocab = nerv.LMVocab() + vocab:build_file(test_fn) + local batch_size = 3 + local feeder = nerv.LMFeeder({}, batch_size, vocab) + feeder:open_file(test_fn) + while (1) do + local list = feeder:get_batch() + if (list == nil) then break end + for i = 1, batch_size, 1 do + printf("%s(%d) ", list[i], vocab:get_word_str(list[i]).id) + end + printf("\n") + end +end +]]-- diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua new file mode 100644 index 0000000..b90e651 --- /dev/null +++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua @@ -0,0 +1,19 @@ +require 'lmptb.lmseqreader' + +local printf = nerv.printf + +local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text" +--local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt" +local vocab = nerv.LMVocab() +vocab:build_file(test_fn) +local batch_size = 3 +local reader = nerv.LMSeqReader({}, batch_size, vocab) +reader:open_file(test_fn) +while (1) do + local list = reader:get_batch() + if (list == nil) then break end + for i = 1, batch_size, 1 do + printf("%s(%d) ", list[i], vocab:get_word_str(list[i]).id) + end + printf("\n") +end diff --git a/nerv/examples/lmptb/m-tests/some-text b/nerv/examples/lmptb/m-tests/some-text new file mode 100644 index 0000000..e905b60 --- /dev/null +++ b/nerv/examples/lmptb/m-tests/some-text @@ -0,0 +1,10 @@ +aa bb cc aa bb cc aa bb cc aa bb cc aa bb cc aa +aa bb cc aa bb cc aa bb cc aa +aa bbcc aa bb cc aa bb cc aa +aa bb cc aa +aa bb cc aa +aa bb cc aa +aa +aa bb cc aa +aa bb cc aa +aa bb cc aa bb cc aa -- cgit v1.2.3