aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lmptb/lmvocab.lua
blob: 38bb18e997425be464ff6b7ab09b7794b313f5f5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
require 'lmptb.lmutil'

-- Register the vocabulary class with nerv under the name "nerv.LMVocab";
-- methods below are attached to this class table.
local Vocab = nerv.class "nerv.LMVocab"

--- Split a string into the non-empty runs between separator characters.
-- `sep` is spliced into a Lua pattern character class (`[^sep]`), so it is a
-- SET of separator characters, not a multi-character delimiter; pattern-magic
-- characters such as `]`, `^` or `%` must be escaped by the caller.
-- Consecutive separators produce no empty fields.
-- @param inputstr string, the text to split
-- @param sep string|nil, separator class body; defaults to "%s" (whitespace)
-- @return table, sequence of the extracted substrings (possibly empty)
local mysplit = function(inputstr, sep)
    if sep == nil then
        sep = "%s"
    end
    -- Keep all state local: a bare `i = 1` here would leak a global.
    local t = {}
    for str in string.gmatch(inputstr, "([^" .. sep .. "]+)") do
        t[#t + 1] = str
    end
    return t
end

--- Constructor: store the configuration, define the special tokens, create
-- the empty lookup tables, then register the sentence-end and unknown tokens.
-- On a fresh instance `</s>` therefore receives id 1 and `<unk>` id 2.
-- @param global_conf table, the global configuration (kept as self.gconf)
function Vocab:__init(global_conf)
    self.gconf = global_conf
    self.log_pre = "[LOG]LMVocab:"

    -- special tokens
    self.sen_end_token = "</s>"
    self.unk_token = "<unk>"
    self.null_token = "<null>" -- end-of-stream marker (used by the feeder)

    -- both tables point at the same word_entry objects
    self.map_str = {} -- word string -> word_entry
    self.map_id = {}  -- integer id  -> word_entry

    -- insertion order fixes the ids of the special tokens
    for _, special in ipairs({self.sen_end_token, self.unk_token}) do
        self:add_word(special)
    end
end

--- Build a fresh word_entry record (does not register it anywhere).
-- @param id int, the id to assign to the word
-- @param w_str string, the word itself
-- @return table with fields `id`, `str` and `cnt` (cnt starts at 0)
function Vocab:new_word_entry(id, w_str)
    return {id = id, str = w_str, cnt = 0}
end

--- Number of words in the vocabulary.
-- Relies on map_id being a hole-free sequence; add_word always appends at
-- size() + 1, which keeps it that way.
-- @return int
function Vocab:size()
    local by_id = self.map_id
    return #by_id
end

--- Register a word, giving it the next free id (size() + 1).
-- A word that is already present is left untouched, so its id never changes.
-- @param w_str string, the word to add
function Vocab:add_word(w_str)
    if self.map_str[w_str] == nil then
        local next_id = self:size() + 1
        local entry = self:new_word_entry(next_id, w_str)
        self.map_id[next_id] = entry
        self.map_str[w_str] = entry
    end
end

--- Fetch the word_entry of the unknown-word token (self.unk_token).
-- Reports "unk entry not found." through nerv.error when it is missing
-- (presumably nerv.error raises; not visible here).
-- @return table, the unk word_entry
function Vocab:get_unk_entry()
    local entry = self.map_str[self.unk_token]
    if entry == nil then
        nerv.error("unk entry not found.")
    end
    return entry
end

--- Fetch the word_entry of the sentence-end token (self.sen_end_token).
-- Reports "sen end token not found" through nerv.error when it is missing
-- (presumably nerv.error raises; not visible here).
-- @return table, the sentence-end word_entry
function Vocab:get_sen_entry()
    local entry = self.map_str[self.sen_end_token]
    if entry == nil then
        nerv.error("sen end token not found")
    end
    return entry
end

--- Decide whether a word string counts as unknown.
-- A word is unknown when it is the unk token itself or is absent from the
-- vocabulary. Passing the null (end-of-stream) token is a caller error.
-- @param w string, the word to test
-- @return boolean
function Vocab:is_unk_str(w)
    -- Compare the actual argument; the end-of-stream marker must never be
    -- silently classified as an unknown word.
    if w == self.null_token then
        nerv.error("Vocab:is_unk_str is called by the null token")
    end
    return w == self.unk_token or self.map_str[w] == nil
end

--- Look up a word_entry by its string.
-- Words not in the vocabulary map to the unk entry. The null (end-of-stream)
-- token is rejected first, so it cannot be silently turned into `<unk>`.
-- @param key string, the word to look up
-- @return table, the matching word_entry, or the unk entry if absent
function Vocab:get_word_str(key)
    -- Must precede the unknown-word fallback; otherwise the guard is only
    -- reachable when "<null>" happens to be a vocabulary word.
    if key == self.null_token then
        nerv.error("Vocab:get_word_str is called by the null token")
    end
    local entry = self.map_str[key]
    if entry == nil then
        return self:get_unk_entry()
    end
    return entry
end

--- Look up a word_entry by its integer id.
-- Reports a missing id through nerv.error (presumably raises).
-- @param key int, the word id (ids run 1..size())
-- @return table, the word_entry stored under that id
function Vocab:get_word_id(key)
    local entry = self.map_id[key]
    if entry == nil then
        nerv.error("id key %d does not exist.", key)
    end
    return entry
end

--- Add every word found in a text file to the vocabulary.
-- Repeatedly calls nerv.LMUtil.read_line on the open file until it returns
-- nil, adding each token of every returned list (add_word ignores words that
-- are already present). Logs the resulting vocab size when done.
-- @param fn string, path of the text file to read
function Vocab:build_file(fn)
    nerv.printf("%s Vocab building on file %s...\n", self.log_pre, fn)
    local file, err = io.open(fn, "r")
    if file == nil then
        -- Report the OS reason now instead of failing obscurely inside
        -- read_line with a nil file handle.
        nerv.error("%s cannot open file %s: %s", self.log_pre, fn, tostring(err))
    end
    while true do
        local list = nerv.LMUtil.read_line(file)
        if list == nil then
            break
        end
        for i = 1, #list do
            self:add_word(list[i])
        end
    end
    file:close()
    nerv.printf("%s Building finished, vocab size now is %d.\n", self.log_pre, self:size())
end

--[[test
do
    local test_fn = "/home/slhome/txh18/workspace/nerv-project/some-text"
    local vocab = nerv.LMVocab()
    vocab:build_file(test_fn)
end
]]--