aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lmptb/lmutil.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/lmptb/lmutil.lua')
-rw-r--r--nerv/examples/lmptb/lmptb/lmutil.lua65
1 files changed, 65 insertions, 0 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmutil.lua b/nerv/examples/lmptb/lmptb/lmutil.lua
index c15c637..73cf041 100644
--- a/nerv/examples/lmptb/lmptb/lmutil.lua
+++ b/nerv/examples/lmptb/lmptb/lmutil.lua
@@ -17,6 +17,45 @@ function Util.create_onehot(list, vocab, ty)
return m
end
+--m: matrix
+--list: table, list of string(word)
+--vocab: nerv.LMVocab
+--Returns: nerv.CuMatrixFloat
+--Set the matrix, whose size should be size #list * vocab:size() to be one_hot according to the list. null_word will become a zero vector.
+function Util.set_onehot(m, list, vocab)
+ if (m:nrow() ~= #list or m:ncol() ~= vocab:size()) then
+ nerv.error("size of matrix mismatch with list and vocab")
+ end
+ m:fill(0)
+ for i = 1, #list, 1 do
+ --index in matrix starts at 0
+ if (list[i] ~= vocab.null_token) then
+ m[i - 1][vocab:get_word_str(list[i]).id - 1] = 1
+ end
+ end
+ return m
+end
+
+--m: matrix
+--list: table, list of string(word)
+--vocab: nerv.LMVocab
+--Returns: nerv.MMatrixInt
+--Set the matrix to be ids of the words, id starting at 1, not 0
+function Util.set_id(m, list, vocab)
+ if (m:nrow() ~= #list or m:ncol() ~= 1) then
+ nerv.error("nrow of matrix mismatch with list or its col not one")
+ end
+ for i = 1, #list, 1 do
+ --index in matrix starts at 0
+ if (list[i] ~= vocab.null_token) then
+ m[i - 1][0] = vocab:get_word_str(list[i]).id
+ else
+ m[i - 1][0] = 0
+ end
+ end
+ return m
+end
+
function Util.wait(sec)
local start = os.time()
repeat until os.time() > start + sec
@@ -66,3 +105,29 @@ end
function Result:status(cla)
return "LMResult status of " .. cla .. ": " .. "<SEN_CN " .. self[cla].cn_sen .. "> <W_CN " .. self[cla].cn_w .. "> <PPL_NET " .. self:ppl_net(cla) .. "> <PPL_OOV " .. self:ppl_all(cla) .. "> <LOGP " .. self[cla].logp_all .. ">"
end
+
+local Timer = nerv.class("nerv.Timer")
+function Timer:__init()
+ self.last = {}
+ self.rec = {}
+end
+
+function Timer:tic(item)
+ self.last[item] = os.time()
+end
+
+function Timer:toc(item)
+ if (self.last[item] == nil) then
+ nerv.error("item not there")
+ end
+ if (self.rec[item] == nil) then
+ self.rec[item] = 0
+ end
+ self.rec[item] = self.rec[item] + os.difftime(os.time(), self.last[item])
+end
+
+function Timer:flush()
+ for key, value in pairs(self.rec) do
+ self.rec[key] = 0
+ end
+end