-- nerv/examples/lmptb/lmptb/lmutil.lua
-- Utility classes for the lmptb example: nerv.LMUtil, nerv.LMResult and nerv.Timer.
local Util = nerv.class("nerv.LMUtil")

--Split inputstr into a list of tokens, using the pattern sep (default "%s", i.e. whitespace) as the delimiter.
local mysplit = function(inputstr, sep)
    if sep == nil then
        sep = "%s"
    end
    local t = {}
    local i = 1
    for str in string.gmatch(inputstr, "([^" .. sep .. "]+)") do
        t[i] = str
        i = i + 1
    end
    return t
end

--Round num to the given number of decimal places (idp defaults to 0).
function Util.round(num, idp)
    local mult = 10^(idp or 0)
    return math.floor(num * mult + 0.5) / mult
end
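
--A minimal usage sketch of round (assuming nerv.class registers the table under nerv.LMUtil):
--  print(nerv.LMUtil.round(3.14159, 2))  --> 3.14
--  print(nerv.LMUtil.round(2.5))         --> 3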

--fh: file handle
--Returns: a list of tokens (strings) from the next non-empty line, or nil if there is nothing left to read; empty lines are skipped
function Util.read_line(fh)
    local l_str, list

    repeat
        l_str = fh:read("*line")
        if l_str == nil then return nil end
        list = mysplit(l_str)
    until #list >= 1

    return list
end
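
--A usage sketch of read_line; the file name here is hypothetical:
--  local fh = assert(io.open("ptb.train.txt", "r"))
--  local tokens = nerv.LMUtil.read_line(fh)
--  while tokens ~= nil do
--      print(table.concat(tokens, " "))
--      tokens = nerv.LMUtil.read_line(fh)
--  end
--  fh:close()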


--list: table, list of strings (words)
--vocab: nerv.LMVocab
--ty: matrix constructor, e.g. nerv.CuMatrixFloat
--Returns: a matrix of type ty
--Create a matrix of type ty and size #list * vocab:size(), where row i is the one-hot vector of list[i]. The null_token gets a zero row.
function Util.create_onehot(list, vocab, ty)
    local m = ty(#list, vocab:size())
    m:fill(0)
    for i = 1, #list do
        --matrix indices start at 0
        if list[i] ~= vocab.null_token then
            m[i - 1][vocab:get_word_str(list[i]).id - 1] = 1
        end
    end
    return m
end
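
--A usage sketch of create_onehot (the word list is hypothetical and vocab is assumed to be an already-built nerv.LMVocab):
--  local m = nerv.LMUtil.create_onehot({"the", "cat", "sat"}, vocab, nerv.CuMatrixFloat)
--  --m is a 3 * vocab:size() matrix with at most a single 1 per row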

--m: matrix of size #list * vocab:size()
--list: table, list of strings (words)
--vocab: nerv.LMVocab
--Returns: m
--Set m to the one-hot encoding of list, as in create_onehot. The null_token gets a zero row.
function Util.set_onehot(m, list, vocab)
    if m:nrow() ~= #list or m:ncol() ~= vocab:size() then
        nerv.error("size of matrix does not match list and vocab")
    end
    m:fill(0)
    for i = 1, #list do
        --matrix indices start at 0
        if list[i] ~= vocab.null_token then
            m[i - 1][vocab:get_word_str(list[i]).id - 1] = 1
        end
    end
    return m
end

--m: single-column matrix with #list rows (e.g. nerv.MMatrixInt)
--list: table, list of strings (words)
--vocab: nerv.LMVocab
--Returns: m
--Set row i of m to the id of list[i] (ids start at 1, not 0); the null_token gets id 0
function Util.set_id(m, list, vocab)
    if m:nrow() ~= #list or m:ncol() ~= 1 then
        nerv.error("nrow of matrix does not match list or ncol is not one")
    end
    for i = 1, #list do
        --matrix indices start at 0
        if list[i] ~= vocab.null_token then
            m[i - 1][0] = vocab:get_word_str(list[i]).id
        else
            m[i - 1][0] = 0
        end
    end
    return m
end
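
--A usage sketch of set_id (assuming nerv.MMatrixInt takes (nrow, ncol) like the other matrix constructors):
--  local ids = nerv.MMatrixInt(#list, 1)
--  nerv.LMUtil.set_id(ids, list, vocab)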

--Busy-wait for roughly sec seconds (os.time() has one-second resolution)
function Util.wait(sec)
    local start = os.time()
    repeat until os.time() > start + sec
end

local Result = nerv.class("nerv.LMResult")

--global_conf: table
--vocab: nerv.LMVocab
function Result:__init(global_conf, vocab)
    self.gconf = global_conf
    self.vocab = vocab
end

--cla: string
--Initialize the accumulators for class cla: total log10 probability (logp_all), log10 probability of unknown words (logp_unk), word count (cn_w), unknown-word count (cn_unk) and sentence-end count (cn_sen)
function Result:init(cla)
    self[cla] = {logp_all = 0, logp_unk = 0, cn_w = 0, cn_unk = 0, cn_sen = 0}
end

--cla: string
--w: string, the word
--prob: float, the probability of w (or its log10 if log10ed is true)
--log10ed: bool, whether prob is already a log10 probability
function Result:add(cla, w, prob, log10ed)
    local lp
    if log10ed == true then
        lp = prob
    else
        lp = math.log10(prob)
    end

    self[cla].logp_all = self[cla].logp_all + lp
    if self.vocab:is_unk_str(w) then
        self[cla].logp_unk = self[cla].logp_unk + lp
        self[cla].cn_unk = self[cla].cn_unk + 1
    end
    if w == self.vocab.sen_end_token then
        self[cla].cn_sen = self[cla].cn_sen + 1
    else
        self[cla].cn_w = self[cla].cn_w + 1
    end
end

--Net perplexity of class cla: unknown words are excluded from both the log-probability mass and the token count
function Result:ppl_net(cla)
    local c = self[cla]
    return math.pow(10, -(c.logp_all - c.logp_unk) / (c.cn_w - c.cn_unk + c.cn_sen))
end

--Overall perplexity of class cla, unknown words included
function Result:ppl_all(cla)
    local c = self[cla]
    return math.pow(10, -c.logp_all / (c.cn_w + c.cn_sen))
end

--Average log10 probability per token (words plus sentence ends) of class cla
function Result:logp_sample(cla)
    local c = self[cla]
    return c.logp_all / (c.cn_w + c.cn_sen)
end

function Result:status(cla)
    return "LMResult status of " .. cla .. ": " .. "<SEN_CN " .. self[cla].cn_sen .. "> <W_CN " .. self[cla].cn_w .. "> <UNK_CN " .. self[cla].cn_unk .. "> <PPL_NET " .. self:ppl_net(cla) .. "> <PPL_OOV " .. self:ppl_all(cla) .. "> <LOGP " .. self[cla].logp_all .. ">"
end
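
--A usage sketch of LMResult; gconf and vocab are assumed to be an existing global config table and nerv.LMVocab, and the probabilities are made-up numbers:
--  local result = nerv.LMResult(gconf, vocab)
--  result:init("rnn")
--  result:add("rnn", "the", 0.05, false)
--  result:add("rnn", vocab.sen_end_token, 0.1, false)
--  print(result:status("rnn"))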

local Timer = nerv.class("nerv.Timer")
function Timer:__init()
    self.last = {}
    self.rec = {}
end

--Record the starting CPU time for item
function Timer:tic(item)
    self.last[item] = os.clock()
end

--Accumulate into self.rec[item] the CPU time elapsed since the last tic(item)
function Timer:toc(item)
    if self.last[item] == nil then
        nerv.error("Timer:toc called for an item without a matching tic")
    end
    if self.rec[item] == nil then
        self.rec[item] = 0
    end
    self.rec[item] = self.rec[item] + os.clock() - self.last[item]
end

--Clear all accumulated timings
function Timer:flush()
    for key, _ in pairs(self.rec) do
        self.rec[key] = nil
    end
end
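
--A usage sketch of Timer:
--  local timer = nerv.Timer()
--  timer:tic("forward")
--  --... do some work ...
--  timer:toc("forward")
--  for item, sec in pairs(timer.rec) do
--      print(string.format("%s: %.3f s", item, sec))
--  end
--  timer:flush()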