aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lmptb/lmutil.lua
blob: 73cf0416fffaf43bb9e3676e1d9810aa3a74df68 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
local Util = nerv.class("nerv.LMUtil")

--list: table, list of string(word)
--vocab: nerv.LMVocab
--ty: matrix type, called as ty(nrow, ncol) (e.g. a nerv.CuMatrix class)
--Returns: a new matrix of type 'ty' (the type depends on 'ty', it is not fixed)
--Create a matrix of size #list * vocab:size() whose row i is the one-hot vector
--of list[i]. Entries equal to vocab.null_token become zero rows.
function Util.create_onehot(list, vocab, ty)
    -- A matrix built with exactly these dimensions always passes set_onehot's
    -- size check, so reuse it instead of duplicating the fill loop.
    return Util.set_onehot(ty(#list, vocab:size()), list, vocab)
end

--m: matrix, must be of size #list * vocab:size()
--list: table, list of string(word)
--vocab: nerv.LMVocab
--Returns: m itself, modified in place
--Zero m, then set row i - 1 to the one-hot vector of list[i]: the column is
--the word's vocab id minus 1, because vocab ids start at 1 while matrix
--indices start at 0. Entries equal to vocab.null_token stay zero rows.
--Raises (via nerv.error) if the size of m does not match list and vocab.
function Util.set_onehot(m, list, vocab)
    if (m:nrow() ~= #list or m:ncol() ~= vocab:size()) then
        nerv.error("size of matrix mismatch with list and vocab")
    end
    m:fill(0)
    for i = 1, #list, 1 do
        --index in matrix starts at 0
        if (list[i] ~= vocab.null_token) then
            m[i - 1][vocab:get_word_str(list[i]).id - 1] = 1
        end
    end
    return m
end

--m: matrix with #list rows and exactly one column (presumably a
--   nerv.MMatrixInt -- confirm against callers)
--list: table, list of string(word)
--vocab: nerv.LMVocab
--Returns: m itself, modified in place
--Write the vocab id of list[i] into m[i - 1][0]. Ids start at 1, not 0;
--entries equal to vocab.null_token get 0.
--Raises (via nerv.error) if m has the wrong number of rows or columns.
function Util.set_id(m, list, vocab)
    if (m:nrow() ~= #list or m:ncol() ~= 1) then
        nerv.error("nrow of matrix mismatch with list or its col not one")
    end
    for i = 1, #list, 1 do
        --index in matrix starts at 0
        if (list[i] ~= vocab.null_token) then
            m[i - 1][0] = vocab:get_word_str(list[i]).id
        else
            m[i - 1][0] = 0
        end
    end
    return m
end

--sec: number of seconds to wait
--Busy-wait (spins the CPU, does not sleep) until os.time() exceeds the start
--time plus sec. os.time() typically has whole-second resolution, so the
--actual delay is roughly between sec and sec + 1 seconds.
function Util.wait(sec)
    local start = os.time()
    repeat until os.time() > start + sec
end

local Result = nerv.class("nerv.LMResult")

-- math.log10 is deprecated in Lua 5.2 and absent from default Lua 5.3+ builds.
-- Where it is missing, the two-argument math.log (5.2+) computes the same value.
local log10 = math.log10 or function(x) return math.log(x, 10) end

--global_conf: table
--vocab: nerv.LMVocab
function Result:__init(global_conf, vocab)
    self.gconf = global_conf
    self.vocab = vocab
end

--cla: string, name of a statistics class
--(Re)initialize the accumulators of class cla to zero. This must be called
--before add/ppl_net/ppl_all/status are used with cla.
function Result:init(cla)
    self[cla] = {logp_all = 0, logp_unk = 0, cn_w = 0, cn_unk = 0, cn_sen = 0}
end

--cla: string, a class previously set up with init
--w: string, the word being scored
--prob: float, the probability assigned to w
--Accumulate log10(prob) into class cla. Unknown words (vocab:is_unk_str) are
--also accumulated separately. The sentence-end token counts as a sentence;
--every other word counts as a word.
function Result:add(cla, w, prob)
    local c = self[cla]
    local lp = log10(prob)
    c.logp_all = c.logp_all + lp
    if self.vocab:is_unk_str(w) then
        c.logp_unk = c.logp_unk + lp
        c.cn_unk = c.cn_unk + 1
    end
    if w == self.vocab.sen_end_token then
        c.cn_sen = c.cn_sen + 1
    else
        c.cn_w = c.cn_w + 1
    end
end

--cla: string
--Returns: perplexity of class cla, excluding unknown words from both the
--log-probability sum and the token count (sentence ends are counted).
--The result is NaN/inf if no such tokens have been added.
function Result:ppl_net(cla)
    local c = self[cla]
    return 10 ^ (-(c.logp_all - c.logp_unk) / (c.cn_w - c.cn_unk + c.cn_sen))
end

--cla: string
--Returns: perplexity of class cla over all added tokens, unknown words
--included (sentence ends are counted). The result is NaN/inf if nothing
--has been added.
function Result:ppl_all(cla)
    local c = self[cla]
    return 10 ^ (-(c.logp_all) / (c.cn_w + c.cn_sen))
end

--cla: string
--Returns: a one-line summary of class cla: sentence count, word count,
--ppl_net, ppl_all (labelled PPL_OOV) and the total log10 probability.
function Result:status(cla)
    local c = self[cla]
    return string.format(
        "LMResult status of %s: <SEN_CN %s> <W_CN %s> <PPL_NET %s> <PPL_OOV %s> <LOGP %s>",
        cla, c.cn_sen, c.cn_w, self:ppl_net(cla), self:ppl_all(cla), c.logp_all)
end

local Timer = nerv.class("nerv.Timer")

--A named-item stopwatch accumulating wall-clock seconds per item.
--Timestamps come from os.time(), which typically has whole-second resolution.
function Timer:__init()
    self.rec = {}   -- item -> accumulated seconds
    self.last = {}  -- item -> os.time() recorded by the latest tic(item)
end

--item: key naming the section being timed
--Mark the current time as the start of item.
function Timer:tic(item)
    local now = os.time()
    self.last[item] = now
end

--item: key previously passed to tic
--Add the seconds elapsed since the latest tic(item) to rec[item], starting
--from 0 on first use. Raises (via nerv.error) if tic(item) was never called.
--last[item] is not cleared, so each toc adds the full time since that tic.
function Timer:toc(item)
    local started = self.last[item]
    if started == nil then
        nerv.error("item not there")
    end
    self.rec[item] = (self.rec[item] or 0) + os.difftime(os.time(), started)
end

--Reset every accumulated total to 0. Item keys and tic times are kept.
function Timer:flush()
    local totals = self.rec
    for name in pairs(totals) do
        totals[name] = 0
    end
end