aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lmptb/lmutil.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/lmptb/lmutil.lua')
-rw-r--r-- nerv/examples/lmptb/lmptb/lmutil.lua 29
1 files changed, 28 insertions, 1 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmutil.lua b/nerv/examples/lmptb/lmptb/lmutil.lua
index 71e8e17..6d66d6e 100644
--- a/nerv/examples/lmptb/lmptb/lmutil.lua
+++ b/nerv/examples/lmptb/lmptb/lmutil.lua
@@ -1,11 +1,38 @@
--Util holds the value returned by nerv.class for the name "nerv.LMUtil"; the Util.* helpers below are attached to it.
--NOTE(review): nerv.class is defined outside this excerpt -- its exact semantics are not visible here.
local Util = nerv.class("nerv.LMUtil")
+--Splits inputstr into the maximal runs of characters that are not separators.
+--inputstr: string to split
+--sep: separator set; it is spliced verbatim into a Lua pattern character class
+--     ("[^" .. sep .. "]"), so classes such as "%s" work and magic characters
+--     like "]" or "^" must be escaped with "%" by the caller. Defaults to "%s".
+--Returns: a list (table indexed from 1) of the non-empty pieces, in order;
+--         the list is empty when inputstr contains nothing but separators
+local mysplit = function(inputstr, sep)
+    if sep == nil then
+        sep = "%s"
+    end
+    local t = {}
+    --append through #t + 1: no counter variable is needed, so nothing can
+    --leak into (or overwrite) a global such as `i`
+    for str in string.gmatch(inputstr, "([^" .. sep .. "]+)") do
+        t[#t + 1] = str
+    end
+    return t
+end
+
--Rounds num to idp decimal places; idp falls back to 0 when it is omitted.
--Works by scaling num up by 10^idp, adding 0.5, flooring, and scaling back down.
function Util.round(num, idp)
    local places = idp or 0
    local scale = 10 ^ places
    local shifted = num * scale + 0.5
    return math.floor(shifted) / scale
end
+--fh: file handle to read from
+--Returns: the list of whitespace-separated tokens (strings) of the next line
+--         that contains at least one token; lines without any token are
+--         skipped. Returns nil once fh has nothing left to read.
+--NOTE(review): this function does not append "</s>" to the token list; callers
+--that need a sentence-end marker have to add it themselves -- TODO confirm.
+function Util.read_line(fh)
+    while true do
+        local line = fh:read("*line")
+        if line == nil then
+            return nil
+        end
+        local tokens = mysplit(line)
+        if #tokens >= 1 then
+            return tokens
+        end
+    end
+end
+
+
--list: table, list of string(word)
--vocab: nerv.LMVocab
--ty: nerv.CuMatrix
@@ -114,7 +141,7 @@ function Result:logp_sample(cla)
end
--Builds a one-line status string for the result entry stored under self[cla].
--cla: key into self; it is also echoed in the message and passed on to self:ppl_net and self:ppl_all
--Returns: a string that concatenates self[cla].cn_sen, self[cla].cn_w, self[cla].cn_unk (the <UNK_CN ...> field is what this change adds),
--         self:ppl_net(cla), self:ppl_all(cla) and self[cla].logp_all, each wrapped in a "<TAG value>" marker
--NOTE(review): the names suggest sentence / word / unknown-word counts, perplexities and a log-probability -- confirm where Result fills them in
function Result:status(cla)
- return "LMResult status of " .. cla .. ": " .. "<SEN_CN " .. self[cla].cn_sen .. "> <W_CN " .. self[cla].cn_w .. "> <PPL_NET " .. self:ppl_net(cla) .. "> <PPL_OOV " .. self:ppl_all(cla) .. "> <LOGP " .. self[cla].logp_all .. ">"
+ return "LMResult status of " .. cla .. ": " .. "<SEN_CN " .. self[cla].cn_sen .. "> <W_CN " .. self[cla].cn_w .. "> <UNK_CN " .. self[cla].cn_unk .. "> <PPL_NET " .. self:ppl_net(cla) .. "> <PPL_OOV " .. self:ppl_all(cla) .. "> <LOGP " .. self[cla].logp_all .. ">"
end
local Timer = nerv.class("nerv.Timer")