diff options
Diffstat (limited to 'nerv/examples/lmptb/lmptb/lmutil.lua')
-rw-r--r-- | nerv/examples/lmptb/lmptb/lmutil.lua | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmutil.lua b/nerv/examples/lmptb/lmptb/lmutil.lua index c15c637..73cf041 100644 --- a/nerv/examples/lmptb/lmptb/lmutil.lua +++ b/nerv/examples/lmptb/lmptb/lmutil.lua @@ -17,6 +17,45 @@ function Util.create_onehot(list, vocab, ty) return m end +--m: matrix +--list: table, list of string(word) +--vocab: nerv.LMVocab +--Returns: nerv.CuMatrixFloat +--Set the matrix, whose size should be size #list * vocab:size() to be one_hot according to the list. null_word will become a zero vector. +function Util.set_onehot(m, list, vocab) + if (m:nrow() ~= #list or m:ncol() ~= vocab:size()) then + nerv.error("size of matrix mismatch with list and vocab") + end + m:fill(0) + for i = 1, #list, 1 do + --index in matrix starts at 0 + if (list[i] ~= vocab.null_token) then + m[i - 1][vocab:get_word_str(list[i]).id - 1] = 1 + end + end + return m +end + +--m: matrix +--list: table, list of string(word) +--vocab: nerv.LMVocab +--Returns: nerv.MMatrixInt +--Set the matrix to be ids of the words, id starting at 1, not 0 +function Util.set_id(m, list, vocab) + if (m:nrow() ~= #list or m:ncol() ~= 1) then + nerv.error("nrow of matrix mismatch with list or its col not one") + end + for i = 1, #list, 1 do + --index in matrix starts at 0 + if (list[i] ~= vocab.null_token) then + m[i - 1][0] = vocab:get_word_str(list[i]).id + else + m[i - 1][0] = 0 + end + end + return m +end + function Util.wait(sec) local start = os.time() repeat until os.time() > start + sec @@ -66,3 +105,29 @@ end function Result:status(cla) return "LMResult status of " .. cla .. ": " .. "<SEN_CN " .. self[cla].cn_sen .. "> <W_CN " .. self[cla].cn_w .. "> <PPL_NET " .. self:ppl_net(cla) .. "> <PPL_OOV " .. self:ppl_all(cla) .. "> <LOGP " .. self[cla].logp_all .. ">" end + +local Timer = nerv.class("nerv.Timer") +function Timer:__init() + self.last = {} + self.rec = {} +end + +function Timer:tic(item) + self.last[item] = os.time() +end + +function Timer:toc(item) + if (self.last[item] == nil) then + nerv.error("item not there") + end + if (self.rec[item] == nil) then + self.rec[item] = 0 + end + self.rec[item] = self.rec[item] + os.difftime(os.time(), self.last[item]) +end + +function Timer:flush() + for key, value in pairs(self.rec) do + self.rec[key] = 0 + end +end |