diff options
Diffstat (limited to 'nerv/examples/lmptb/lmptb/lmutil.lua')
-rw-r--r-- | nerv/examples/lmptb/lmptb/lmutil.lua | 29 |
1 files changed, 28 insertions, 1 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmutil.lua b/nerv/examples/lmptb/lmptb/lmutil.lua index 71e8e17..6d66d6e 100644 --- a/nerv/examples/lmptb/lmptb/lmutil.lua +++ b/nerv/examples/lmptb/lmptb/lmutil.lua @@ -1,11 +1,38 @@ local Util = nerv.class("nerv.LMUtil") +local mysplit = function(inputstr, sep) + if sep == nil then + sep = "%s" + end + local t={} ; i=1 + for str in string.gmatch(inputstr, "([^"..sep.."]+)") do + t[i] = str + i = i + 1 + end + return t +end + --function rounds a number to the given number of decimal places. function Util.round(num, idp) local mult = 10^(idp or 0) return math.floor(num * mult + 0.5) / mult end +--fh: file_handle +--Returns: a list of tokens(string) in the line, if there is no "</s>" at the end, the function will at it, if nothing to read, returns nil +function Util.read_line(fh) + local l_str, list + + repeat + l_str = fh:read("*line") + if (l_str == nil) then return nil end + list = mysplit(l_str) + until #list >= 1 + + return list +end + + --list: table, list of string(word) --vocab: nerv.LMVocab --ty: nerv.CuMatrix @@ -114,7 +141,7 @@ function Result:logp_sample(cla) end function Result:status(cla) - return "LMResult status of " .. cla .. ": " .. "<SEN_CN " .. self[cla].cn_sen .. "> <W_CN " .. self[cla].cn_w .. "> <PPL_NET " .. self:ppl_net(cla) .. "> <PPL_OOV " .. self:ppl_all(cla) .. "> <LOGP " .. self[cla].logp_all .. ">" + return "LMResult status of " .. cla .. ": " .. "<SEN_CN " .. self[cla].cn_sen .. "> <W_CN " .. self[cla].cn_w .. "> <UNK_CN " .. self[cla].cn_unk .. "> <PPL_NET " .. self:ppl_net(cla) .. "> <PPL_OOV " .. self:ppl_all(cla) .. "> <LOGP " .. self[cla].logp_all .. ">" end local Timer = nerv.class("nerv.Timer") |