aboutsummaryrefslogblamecommitdiff
path: root/nerv/examples/lmptb/lmptb/lmutil.lua
blob: 6d66d6e8d2796c69af1689d06ede64458cc23e50 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13
14

                                      











                                                            





                                                                 














                                                                                                                                           
















                                                                                               
























                                                                                                                                          
                                                




                                                                          
                                                        
            
                           




            













































                                                                                    




                                           
                           
                                                                                                                                                                                                                                                                            
   







                                      
                                








                                    
                                                                  



                                        
                           

       
-- Namespace for the LM helper functions below (registered as nerv.LMUtil).
local Util = nerv.class "nerv.LMUtil"

--- Split a string into the non-empty runs of characters not matched by `sep`.
-- @param inputstr string to split
-- @param sep optional Lua-pattern fragment placed inside "[^...]"; it is
--        interpreted as pattern syntax (e.g. "%s"), not as a literal string.
--        Defaults to "%s" (any whitespace).
-- @return table, sequence of the pieces in order (empty if there are none)
local mysplit = function(inputstr, sep)
    if sep == nil then
        sep = "%s"
    end
    local t = {}
    -- The counter must be local: the previous `i=1` silently created/clobbered
    -- a global `i` on every call.
    local n = 0
    for str in string.gmatch(inputstr, "([^" .. sep .. "]+)") do
        n = n + 1
        t[n] = str
    end
    return t
end

--- Round `num` to `idp` decimal places.
-- Uses floor(x + 0.5), so exact halves go toward +infinity (e.g. -2.5 -> -2),
-- subject to floating-point representation of the scaled value.
-- @param num number to round
-- @param idp number of decimal places; nil means 0
-- @return the rounded number
function Util.round(num, idp)
    local places = idp or 0
    local scale = 10 ^ places
    return math.floor(num * scale + 0.5) / scale
end

--- Read the next non-blank line from a file handle and split it into tokens.
-- Lines containing no tokens (empty or whitespace-only) are skipped.
-- Tokens are returned exactly as they appear in the line; this function does
-- NOT append a sentence-end token such as "</s>".
-- @param fh file handle opened for reading
-- @return table, list of whitespace-separated token strings (at least one
--         element), or nil once the end of the file is reached
function Util.read_line(fh)
    local l_str, list

    repeat
        l_str = fh:read("*line")
        if l_str == nil then
            return nil
        end
        list = mysplit(l_str)
    until #list >= 1

    return list
end


--- Allocate and fill a one-hot matrix for a list of words.
-- The matrix is created as ty(#list, vocab:size()) and zero-filled. For each
-- word list[i] that is not vocab.null_token, row i-1 gets a 1 in column
-- (vocab:get_word_str(list[i]).id - 1); null tokens leave an all-zero row.
-- Matrix indices are 0-based while word ids start at 1.
-- The fill is delegated to Util.set_onehot so the two functions cannot drift apart.
-- @param list table, list of word strings
-- @param vocab nerv.LMVocab
-- @param ty matrix type/constructor, called as ty(nrow, ncol)
-- @return the new matrix, of type `ty`
function Util.create_onehot(list, vocab, ty)
    local m = ty(#list, vocab:size())
    return Util.set_onehot(m, list, vocab)
end

--- Overwrite an existing matrix with the one-hot encoding of a word list.
-- `m` must be #list rows by vocab:size() columns, otherwise nerv.error is raised.
-- The matrix is zero-filled first. For each word list[i] that is not
-- vocab.null_token, row i-1 gets a 1 in column (word id - 1); matrix indices
-- are 0-based while word ids start at 1. Null tokens leave an all-zero row.
-- @param m matrix to fill (modified in place)
-- @param list table, list of word strings
-- @param vocab nerv.LMVocab
-- @return m
function Util.set_onehot(m, list, vocab)
    local n = #list
    if m:nrow() ~= n or m:ncol() ~= vocab:size() then
        nerv.error("size of matrix mismatch with list and vocab")
    end
    m:fill(0)
    for row = 0, n - 1 do
        local word = list[row + 1]
        if word ~= vocab.null_token then
            local col = vocab:get_word_str(word).id - 1
            m[row][col] = 1
        end
    end
    return m
end

--- Write the vocabulary id of each word into a single-column matrix.
-- `m` must have #list rows and exactly 1 column, otherwise nerv.error is raised.
-- Row i-1 (0-based) receives the id of list[i]; ids start at 1, and
-- vocab.null_token is written as 0.
-- @param m matrix to fill (modified in place)
-- @param list table, list of word strings
-- @param vocab nerv.LMVocab
-- @return m
function Util.set_id(m, list, vocab)
    local n = #list
    if m:nrow() ~= n or m:ncol() ~= 1 then
        nerv.error("nrow of matrix mismatch with list or its col not one")
    end
    for row = 0, n - 1 do
        local word = list[row + 1]
        local id = 0
        if word ~= vocab.null_token then
            id = vocab:get_word_str(word).id
        end
        m[row][0] = id
    end
    return m
end

--- Busy-wait until more than `sec` seconds of calendar time have passed.
-- This spins the CPU; it does not sleep. os.time() typically has one-second
-- resolution, so the real delay is between `sec` and about `sec + 1` seconds.
-- os.difftime is used instead of subtracting os.time() values, because the Lua
-- manual only guarantees os.time() results are meaningful as arguments to
-- os.date and os.difftime.
-- @param sec number of seconds to wait
function Util.wait(sec)
    local start = os.time()
    repeat until os.difftime(os.time(), start) > sec
end

-- Accumulator of per-bucket log-probability and token-count statistics
-- (registered as nerv.LMResult).
local Result = nerv.class "nerv.LMResult"

--- Construct a result accumulator.
-- @param global_conf table, global configuration (kept as self.gconf)
-- @param vocab nerv.LMVocab used to recognise unknown and sentence-end words
function Result:__init(global_conf, vocab)
    self.vocab = vocab
    self.gconf = global_conf
end

--- Create (or reset) the statistics bucket named `cla` with all counters at zero.
-- @param cla string, bucket name; the bucket is stored as self[cla]
function Result:init(cla)
    self[cla] = {
        logp_all = 0,
        logp_unk = 0,
        cn_w = 0,
        cn_unk = 0,
        cn_sen = 0,
    }
end

--- Record one scored word in bucket `cla`.
-- log10(prob) is added to logp_all. If the vocab reports `w` as unknown, it is
-- also added to logp_unk and cn_unk is incremented. The sentence-end token
-- increments cn_sen; every other word increments cn_w.
-- @param cla string, bucket previously created with Result:init
-- @param w string, the word
-- @param prob number, probability assigned to `w` (0 gives a -inf log)
function Result:add(cla, w, prob)
    local stat = self[cla]
    local lp = math.log10(prob)
    stat.logp_all = stat.logp_all + lp
    if self.vocab:is_unk_str(w) then
        stat.logp_unk = stat.logp_unk + lp
        stat.cn_unk = stat.cn_unk + 1
    end
    if w ~= self.vocab.sen_end_token then
        stat.cn_w = stat.cn_w + 1
    else
        stat.cn_sen = stat.cn_sen + 1
    end
end

--- Base-10 perplexity of bucket `cla`, excluding unknown words.
-- Computes 10^(-(logp_all - logp_unk) / (cn_w - cn_unk + cn_sen)).
-- @param cla string, bucket name
-- @return number
function Result:ppl_net(cla)
    local s = self[cla]
    local logp_known = s.logp_all - s.logp_unk
    local n_tokens = s.cn_w - s.cn_unk + s.cn_sen
    return math.pow(10, -logp_known / n_tokens)
end

--- Base-10 perplexity of bucket `cla` over all words, unknown words included.
-- Computes 10^(-logp_all / (cn_w + cn_sen)).
-- @param cla string, bucket name
-- @return number
function Result:ppl_all(cla)
    local s = self[cla]
    local n_tokens = s.cn_w + s.cn_sen
    return math.pow(10, -s.logp_all / n_tokens)
end

--- Average log10 probability per token (words plus sentence ends) in bucket `cla`.
-- @param cla string, bucket name
-- @return number, logp_all / (cn_w + cn_sen)
function Result:logp_sample(cla)
    local s = self[cla]
    local n_tokens = s.cn_w + s.cn_sen
    return s.logp_all / n_tokens
end

--- One-line human-readable summary of bucket `cla`.
-- Reports the sentence, word and unknown counts, ppl_net (PPL_NET),
-- ppl_all (PPL_OOV) and the total log10 probability (LOGP).
-- @param cla string, bucket name
-- @return string
function Result:status(cla)
    local s = self[cla]
    return string.format(
        "LMResult status of %s: <SEN_CN %s> <W_CN %s> <UNK_CN %s> <PPL_NET %s> <PPL_OOV %s> <LOGP %s>",
        cla, s.cn_sen, s.cn_w, s.cn_unk,
        self:ppl_net(cla), self:ppl_all(cla), s.logp_all)
end

-- Named accumulating stopwatch (registered as nerv.Timer).
local Timer = nerv.class "nerv.Timer"

--- Create a timer with no running or accumulated items.
-- self.last maps item -> os.clock() value at the latest tic;
-- self.rec maps item -> accumulated seconds.
function Timer:__init()
    self.rec = {}
    self.last = {}
end

--- Start (or restart) timing `item`.
-- Stores the current os.clock() reading, which is the processor time used by
-- the program, not wall-clock time.
-- @param item key identifying what is being timed
function Timer:tic(item)
    local now = os.clock()
    self.last[item] = now
end

--- Add the os.clock() time elapsed since the last tic(item) to self.rec[item].
-- Raises via nerv.error if tic(item) was never called. The start time is not
-- cleared, so a second toc without a new tic measures from the same start.
-- @param item key identifying what is being timed
function Timer:toc(item)
    local started = self.last[item]
    if started == nil then
        nerv.error("item not there")
    end
    local total = self.rec[item]
    if total == nil then
        total = 0
    end
    self.rec[item] = total + os.clock() - started
end

--- Discard all accumulated times while keeping the same self.rec table.
-- Start times recorded by tic (self.last) are left untouched.
function Timer:flush()
    local rec = self.rec
    -- Assigning nil to existing keys during a pairs traversal is permitted.
    for key in pairs(rec) do
        rec[key] = nil
    end
end