aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/examples/lmptb/lmptb/lmfeeder.lua3
-rw-r--r--nerv/examples/lmptb/lmptb/lmseqreader.lua3
-rw-r--r--nerv/examples/lmptb/lmptb/lmutil.lua27
-rw-r--r--nerv/examples/lmptb/lmptb/lmvocab.lua12
4 files changed, 31 insertions, 14 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmfeeder.lua b/nerv/examples/lmptb/lmptb/lmfeeder.lua
index 34631bf..e140f38 100644
--- a/nerv/examples/lmptb/lmptb/lmfeeder.lua
+++ b/nerv/examples/lmptb/lmptb/lmfeeder.lua
@@ -1,4 +1,5 @@
require 'lmptb.lmvocab'
+require 'lmptb.lmutil'
local Feeder = nerv.class("nerv.LMFeeder")
@@ -39,7 +40,7 @@ function Feeder:refresh_stream(id)
local st = self.streams[id]
if (st.store[st.head] ~= nil) then return end
if (self.fh == nil) then return end
- local list = self.vocab:read_line(self.fh)
+ local list = nerv.LMUtil.read_line(self.fh)
if (list == nil) then --file has end
printf("%s file expires, closing.\n", self.log_pre)
self.fh:close()
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua
index ed791d2..b603911 100644
--- a/nerv/examples/lmptb/lmptb/lmseqreader.lua
+++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua
@@ -1,4 +1,5 @@
require 'lmptb.lmvocab'
+require 'lmptb.lmutil'
--require 'tnn.init'
local LMReader = nerv.class("nerv.LMSeqReader")
@@ -58,7 +59,7 @@ function LMReader:refresh_stream(id)
local st = self.streams[id]
if (st.store[st.head] ~= nil) then return end
if (self.fh == nil) then return end
- local list = self.vocab:read_line(self.fh)
+ local list = nerv.LMUtil.read_line(self.fh)
if (list == nil) then --file has end
printf("%s file expires, closing.\n", self.log_pre)
self.fh:close()
diff --git a/nerv/examples/lmptb/lmptb/lmutil.lua b/nerv/examples/lmptb/lmptb/lmutil.lua
index 71e8e17..27b4b10 100644
--- a/nerv/examples/lmptb/lmptb/lmutil.lua
+++ b/nerv/examples/lmptb/lmptb/lmutil.lua
@@ -1,11 +1,38 @@
local Util = nerv.class("nerv.LMUtil")
+local mysplit = function(inputstr, sep)
+ if sep == nil then
+ sep = "%s"
+ end
+ local t={} ; i=1
+ for str in string.gmatch(inputstr, "([^"..sep.."]+)") do
+ t[i] = str
+ i = i + 1
+ end
+ return t
+end
+
--function rounds a number to the given number of decimal places.
function Util.round(num, idp)
local mult = 10^(idp or 0)
return math.floor(num * mult + 0.5) / mult
end
+--fh: file_handle
+--Returns: a list of tokens(string) in the line, if there is no "</s>" at the end, the function will at it, if nothing to read, returns nil
+function Util.read_line(fh)
+ local l_str, list
+
+ repeat
+ l_str = fh:read("*line")
+ if (l_str == nil) then return nil end
+ list = mysplit(l_str)
+ until #list >= 1
+
+ return list
+end
+
+
--list: table, list of string(word)
--vocab: nerv.LMVocab
--ty: nerv.CuMatrix
diff --git a/nerv/examples/lmptb/lmptb/lmvocab.lua b/nerv/examples/lmptb/lmptb/lmvocab.lua
index 3d256c0..2ad0e7e 100644
--- a/nerv/examples/lmptb/lmptb/lmvocab.lua
+++ b/nerv/examples/lmptb/lmptb/lmvocab.lua
@@ -101,18 +101,6 @@ function Vocab:get_word_id(key)
return self.map_id(key)
end
---fh: file_handle
---Returns: a list of tokens(string) in the line, if there is no "</s>" at the end, the function will at it, if nothing to read, returns nil
-function Vocab:read_line(fh)
- local l_str = fh:read("*line")
- if (l_str == nil) then return nil end
- local list = mysplit(l_str)
- if (list[(#list)] ~= self.sen_end_token) then
- list[#list + 1] = self.sen_end_token
- end
- return list
-end
-
--fn: string
--Add all words in fn to the vocab
function Vocab:build_file(fn)