aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lmptb/lmvocab.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/lmptb/lmvocab.lua')
-rw-r--r--nerv/examples/lmptb/lmptb/lmvocab.lua141
1 files changed, 141 insertions, 0 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmvocab.lua b/nerv/examples/lmptb/lmptb/lmvocab.lua
new file mode 100644
index 0000000..3d256c0
--- /dev/null
+++ b/nerv/examples/lmptb/lmptb/lmvocab.lua
@@ -0,0 +1,141 @@
+--LMVocab: word vocabulary class, registered with nerv under the name "nerv.LMVocab".
+--Its methods below keep two tables of word entries: one keyed by word string, one keyed by integer id.
+local Vocab = nerv.class("nerv.LMVocab")
+
+--local alias of nerv.printf, used for the log lines printed by build_file
+local printf = nerv.printf
+
+--Split a string into the maximal runs of characters that are not separators.
+--inputstr: string, the text to split
+--sep: string, optional; it is spliced verbatim into a Lua character class ("[^" .. sep .. "]"),
+--     so it may hold several characters or a class such as "%s"; defaults to "%s" (whitespace)
+--Returns: table, a 1-based list of the pieces in order (empty when there is no piece)
+local mysplit = function(inputstr, sep)
+    if sep == nil then
+        sep = "%s"
+    end
+    local t = {}
+    --append through #t: no counter variable is needed, so nothing can leak into the global table
+    for str in string.gmatch(inputstr, "([^" .. sep .. "]+)") do
+        t[#t + 1] = str
+    end
+    return t
+end
+
+--Constructor: stores the global configuration, sets up the special tokens and
+--the two empty lookup tables, then registers "</s>" followed by "<unk>".
+--global_conf: kept as self.gconf
+function Vocab:__init(global_conf)
+    self.gconf = global_conf
+
+    --special tokens
+    self.sen_end_token = "</s>"
+    self.unk_token = "<unk>"
+    self.null_token = "<null>" --indicating end of stream(in feeder)
+
+    self.log_pre = "[LOG]LMVocab:"
+
+    --lookup tables
+    self.map_str = {} --map from str to word_entry
+    self.map_id = {} --map from id to word_entry
+
+    --sentence end is added first, unk second; the null token is not added here
+    for _, token in ipairs({self.sen_end_token, self.unk_token}) do
+        self:add_word(token)
+    end
+end
+
+--Build a fresh word entry whose count starts at zero.
+--id: int
+--w_str: string
+--Returns: table with the fields id, str and cnt (cnt is 0)
+function Vocab:new_word_entry(id, w_str)
+    local entry = {}
+    entry.id = id
+    entry.str = w_str
+    entry.cnt = 0
+    return entry
+end
+
+--Returns: int, the length of the id table (one slot per registered word)
+function Vocab:size()
+    local id_table = self.map_id
+    return #id_table
+end
+
+--w_str: string
+--Register w_str under the next free id (current size + 1) in both tables;
+--a word that is already in the vocab is left untouched
+function Vocab:add_word(w_str)
+    if self.map_str[w_str] == nil then
+        local next_id = self:size() + 1
+        local entry = self:new_word_entry(next_id, w_str)
+        self.map_id[next_id] = entry
+        self.map_str[w_str] = entry
+    end
+end
+
+--Returns: table, the entry stored under the unk token
+--Calls nerv.error when no such entry exists
+function Vocab:get_unk_entry()
+    local entry = self.map_str[self.unk_token]
+    if entry == nil then
+        nerv.error("unk entry not found.")
+    end
+    return entry
+end
+
+--Returns: table, the entry stored under the sentence-end token
+--Calls nerv.error when no such entry exists
+function Vocab:get_sen_entry()
+    local entry = self.map_str[self.sen_end_token]
+    if entry == nil then
+        nerv.error("sen end token not found")
+    end
+    return entry
+end
+
+--w: string
+--Returns: boolean, true when w is the unk token itself or is not in the vocab
+--Calls nerv.error when w is the null token
+function Vocab:is_unk_str(w)
+    --the guard must test the parameter w; a bare "key" here would be an unrelated global
+    if (w == self.null_token) then
+        nerv.error("Vocab:is_unk_str is called by the null token")
+    end
+    return w == self.unk_token or self.map_str[w] == nil
+end
+
+--key: string
+--Returns: table, the word_entry of this key, or the unk entry when key is not in the vocab
+--Calls nerv.error when key is the null token
+function Vocab:get_word_str(key)
+    --reject the null token before the lookup: the constructor never adds it to map_str,
+    --so the unknown-word fallback below would otherwise answer for it silently
+    if (key == self.null_token) then
+        nerv.error("Vocab:get_word_str is called by the null token")
+    end
+    local entry = self.map_str[key]
+    if (entry == nil) then
+        return self:get_unk_entry()
+    end
+    return entry
+end
+
+--key: int
+--Returns: table, the word_entry stored under this id
+--Calls nerv.error when no entry has this id
+function Vocab:get_word_id(key)
+    --map_id is a table and has to be indexed with [], not called
+    local entry = self.map_id[key]
+    if (entry == nil) then
+        --tostring with %s keeps the message intact even when key is nil or not a number
+        nerv.error("id key %s does not exist.", tostring(key))
+    end
+    return entry
+end
+
+--fh: file_handle
+--Returns: a list of tokens(string) from the next line of fh, split by mysplit with its default separator;
+--if the last token is not "</s>" the function will add it; if there is nothing to read, returns nil
+function Vocab:read_line(fh)
+    local line = fh:read("*line")
+    if line == nil then
+        return nil
+    end
+    local tokens = mysplit(line)
+    local n = #tokens
+    --an empty line has n == 0, so the result is then just the sentence-end token
+    if tokens[n] ~= self.sen_end_token then
+        tokens[n + 1] = self.sen_end_token
+    end
+    return tokens
+end
+
+--fn: string, path of the text file to read
+--Add all words in fn to the vocab, line by line (lines are tokenized by read_line)
+--Prints a log line before and after; calls nerv.error when fn cannot be opened
+function Vocab:build_file(fn)
+    printf("%s Vocab building on file %s...\n", self.log_pre, fn)
+    local file, err = io.open(fn, "r")
+    if (file == nil) then
+        --io.open returns nil plus a message on failure; report it here instead of
+        --failing later with an opaque nil-index error inside read_line
+        nerv.error("%s cannot open file %s: %s", self.log_pre, fn, tostring(err))
+    end
+    local list = self:read_line(file)
+    while (list ~= nil) do
+        for i = 1, #list do
+            self:add_word(list[i])
+        end
+        list = self:read_line(file)
+    end
+    file:close()
+    printf("%s Building finished, vocab size now is %d.\n", self.log_pre, self:size())
+end
+
+--[[test
+do
+ local test_fn = "/home/slhome/txh18/workspace/nerv-project/some-text"
+ local vocab = nerv.LMVocab()
+ vocab:build_file(test_fn)
+end
+]]--