diff options
-rw-r--r-- | nerv/examples/lmptb/lmptb/lmfeeder.lua | 3 | ||||
-rw-r--r-- | nerv/examples/lmptb/lmptb/lmseqreader.lua | 3 | ||||
-rw-r--r-- | nerv/examples/lmptb/lmptb/lmutil.lua | 27 | ||||
-rw-r--r-- | nerv/examples/lmptb/lmptb/lmvocab.lua | 12 |
4 files changed, 31 insertions, 14 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmfeeder.lua b/nerv/examples/lmptb/lmptb/lmfeeder.lua index 34631bf..e140f38 100644 --- a/nerv/examples/lmptb/lmptb/lmfeeder.lua +++ b/nerv/examples/lmptb/lmptb/lmfeeder.lua @@ -1,4 +1,5 @@ require 'lmptb.lmvocab' +require 'lmptb.lmutil' local Feeder = nerv.class("nerv.LMFeeder") @@ -39,7 +40,7 @@ function Feeder:refresh_stream(id) local st = self.streams[id] if (st.store[st.head] ~= nil) then return end if (self.fh == nil) then return end - local list = self.vocab:read_line(self.fh) + local list = nerv.LMUtil.read_line(self.fh) if (list == nil) then --file has end printf("%s file expires, closing.\n", self.log_pre) self.fh:close() diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua index ed791d2..b603911 100644 --- a/nerv/examples/lmptb/lmptb/lmseqreader.lua +++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua @@ -1,4 +1,5 @@ require 'lmptb.lmvocab' +require 'lmptb.lmutil' --require 'tnn.init' local LMReader = nerv.class("nerv.LMSeqReader") @@ -58,7 +59,7 @@ function LMReader:refresh_stream(id) local st = self.streams[id] if (st.store[st.head] ~= nil) then return end if (self.fh == nil) then return end - local list = self.vocab:read_line(self.fh) + local list = nerv.LMUtil.read_line(self.fh) if (list == nil) then --file has end printf("%s file expires, closing.\n", self.log_pre) self.fh:close() diff --git a/nerv/examples/lmptb/lmptb/lmutil.lua b/nerv/examples/lmptb/lmptb/lmutil.lua index 71e8e17..27b4b10 100644 --- a/nerv/examples/lmptb/lmptb/lmutil.lua +++ b/nerv/examples/lmptb/lmptb/lmutil.lua @@ -1,11 +1,38 @@ local Util = nerv.class("nerv.LMUtil") +local mysplit = function(inputstr, sep) + if sep == nil then + sep = "%s" + end + local t={} ; i=1 + for str in string.gmatch(inputstr, "([^"..sep.."]+)") do + t[i] = str + i = i + 1 + end + return t +end + --function rounds a number to the given number of decimal places. function Util.round(num, idp) local mult = 10^(idp or 0) return math.floor(num * mult + 0.5) / mult end +--fh: file_handle +--Returns: a list of tokens(string) in the line, if there is no "</s>" at the end, the function will at it, if nothing to read, returns nil +function Util.read_line(fh) + local l_str, list + + repeat + l_str = fh:read("*line") + if (l_str == nil) then return nil end + list = mysplit(l_str) + until #list >= 1 + + return list +end + + --list: table, list of string(word) --vocab: nerv.LMVocab --ty: nerv.CuMatrix diff --git a/nerv/examples/lmptb/lmptb/lmvocab.lua b/nerv/examples/lmptb/lmptb/lmvocab.lua index 3d256c0..2ad0e7e 100644 --- a/nerv/examples/lmptb/lmptb/lmvocab.lua +++ b/nerv/examples/lmptb/lmptb/lmvocab.lua @@ -101,18 +101,6 @@ function Vocab:get_word_id(key) return self.map_id(key) end ---fh: file_handle ---Returns: a list of tokens(string) in the line, if there is no "</s>" at the end, the function will at it, if nothing to read, returns nil -function Vocab:read_line(fh) - local l_str = fh:read("*line") - if (l_str == nil) then return nil end - local list = mysplit(l_str) - if (list[(#list)] ~= self.sen_end_token) then - list[#list + 1] = self.sen_end_token - end - return list -end - --fn: string --Add all words in fn to the vocab function Vocab:build_file(fn) |