From 4d48cb10ba5bcd6e441a1919d61a64d0a6b4bee9 Mon Sep 17 00:00:00 2001 From: txh18 Date: Fri, 30 Oct 2015 00:21:40 +0800 Subject: still working in lmseqreader --- nerv/examples/lmptb/lmptb/lmseqreader.lua | 61 ++++++++++++++++-------- nerv/examples/lmptb/m-tests/lmseqreader_test.lua | 5 +- nerv/examples/lmptb/m-tests/some-text | 20 ++++---- 3 files changed, 54 insertions(+), 32 deletions(-) diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua index edc3ff4..26dc3be 100644 --- a/nerv/examples/lmptb/lmptb/lmseqreader.lua +++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua @@ -7,10 +7,11 @@ local printf = nerv.printf --global_conf: table --batch_size: int --vocab: nerv.LMVocab -function LMReader:__init(global_conf, batch_size, vocab) +function LMReader:__init(global_conf, batch_size, seq_size, vocab) self.gconf = global_conf self.fh = nil --file handle to read, nil means currently no file self.batch_size = batch_size + self.seq_size = seq_size self.log_pre = "[LOG]LMFeeder:" self.vocab = vocab self.streams = nil @@ -26,12 +27,12 @@ function LMReader:open_file(fn) self.fh = io.open(fn, "r") self.streams = {} for i = 1, self.batch_size, 1 do - self.streams[i] = {["store"] = {self.vocab.sen_end_token}, ["head"] = 1, ["tail"] = 1} + self.streams[i] = {["store"] = {}, ["head"] = 1, ["tail"] = 0} end end --id: int ---Refresh stream id, read a line from file +--Refresh stream id, read a line from file, will check whether this line is cntklm-style function LMReader:refresh_stream(id) if (self.streams[id] == nil) then nerv.error("stream %d does not exit.", id) @@ -40,41 +41,61 @@ function LMReader:refresh_stream(id) if (st.store[st.head] ~= nil) then return end if (self.fh == nil) then return end local list = self.vocab:read_line(self.fh) - if (list[1] ~= self.vocab.sen_end_token or list[#list] ~= self.vocab.sen_end_token) then --check for cntk style input - nerv.error("sentence not begin or end with : %s", table.tostring(list)); - end if (list == nil) then --file has end printf("%s file expires, closing.\n", self.log_pre) self.fh:close() self.fh = nil return end + + --some sanity check + if (list[1] ~= self.vocab.sen_end_token or list[#list] ~= self.vocab.sen_end_token) then --check for cntklm style input + nerv.error("%s sentence not begin or end with : %s", self.log_pre, table.tostring(list)); + end + for i = 2, #list - 1, 1 do + if (list[i] == self.vocab.sen_end_token) then + nerv.error("%s Got in the middle of a line(%s) in file", self.log_pre, table.tostring(list)) + end + end + for i = 1, #list, 1 do st.tail = st.tail + 1 st.store[st.tail] = list[i] end end ---Returns: nil/table ---If gets something, return a list of string, vocab.null_token indicates end of string -function LMReader:get_batch() +function LMReader:get_batch(input, label) local got_new = false local list = {} - for i = 1, self.batch_size, 1 do - self:refresh_stream(i) + for i = 1, self.seq_size, 1 do local st = self.streams[i] - list[i] = st.store[st.head] - if (list[i] == nil) then list[i] = self.vocab.null_token end - if (list[i] ~= nil and list[i] ~= self.vocab.null_token)then - got_new = true - st.store[st.head] = nil - st.head = st.head + 1 - end + for j = 1, self.batch_size, 1 do + self:refresh_stream(i) + if (st.store[st.head] ~= nil) then + input[i][j] = st.store[st.head] + else + input[i][j] = self.vocab.null_token + end + if (st.store[st.head + 1] ~= nil) then + label[i][j] = st.store[st.head + 1] + else + label[i][j] = self.vocab.null_token + end + if (input[i][j] ~= self.vocab.null_token) then + got_new = true + st.store[st.head] = nil + st.head = st.head + 1 + if (label[i][j] == self.vocab.sen_end_token) then + st.store[st.head] = nil --sentence end is passed + st.head = st.head + 1 + end + end + end end if (got_new == false) then - return nil + return false else - return list + return true end end diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua index b90e651..bdea740 100644 --- a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua +++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua @@ -6,8 +6,9 @@ local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-te --local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt" local vocab = nerv.LMVocab() vocab:build_file(test_fn) -local batch_size = 3 -local reader = nerv.LMSeqReader({}, batch_size, vocab) +local batch_size = 5 +local seq_size = 3 +local reader = nerv.LMSeqReader({}, batch_size, seq_size, vocab) reader:open_file(test_fn) while (1) do local list = reader:get_batch() diff --git a/nerv/examples/lmptb/m-tests/some-text b/nerv/examples/lmptb/m-tests/some-text index e905b60..cdfbd2c 100644 --- a/nerv/examples/lmptb/m-tests/some-text +++ b/nerv/examples/lmptb/m-tests/some-text @@ -1,10 +1,10 @@ -aa bb cc aa bb cc aa bb cc aa bb cc aa bb cc aa -aa bb cc aa bb cc aa bb cc aa -aa bbcc aa bb cc aa bb cc aa -aa bb cc aa -aa bb cc aa -aa bb cc aa -aa -aa bb cc aa -aa bb cc aa -aa bb cc aa bb cc aa + aa bb cc aa bb cc aa bb cc aa bb cc aa bb cc aa + aa bb cc aa bb cc aa bb cc aa + aa bb cc aa bb cc aa bb cc aa + aa bb cc aa + aa bb cc aa + aa bb cc aa + aa + aa bb cc aa + aa bb cc aa + aa bb cc aa bb cc aa -- cgit v1.2.3