aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lmptb/lmseqreader.lua
blob: 127292940851d78de0cc7c071f649bb2e2ae138a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
require 'lmptb.lmvocab'
require 'lmptb.lmutil'
--require 'tnn.init'

--Declare the class under the name "nerv.LMSeqReader"; the methods below are attached to this local handle.
local LMReader = nerv.class("nerv.LMSeqReader")

--Local alias for nerv.printf (used by refresh_stream).
local printf = nerv.printf

--Constructor of nerv.LMSeqReader.
--global_conf: table, global configuration (open_file reads its mmat_type field)
--batch_size: int, number of parallel streams
--chunk_size: int, number of time steps filled by one get_batch call
--vocab: nerv.LMVocab
--r_conf: table or nil, optional reader switches; a switch is enabled only
--        when its value is exactly `true` (se_mode, compressed_label, same_io)
function LMReader:__init(global_conf, batch_size, chunk_size, vocab, r_conf)
    local opts = r_conf
    if opts == nil then
        opts = {}
    end

    self.gconf = global_conf
    self.fh = nil --file handle to read, nil means currently no file
    self.batch_size = batch_size
    self.chunk_size = chunk_size
    self.log_pre = "[LOG]LMSeqReader:"
    self.vocab = vocab
    self.streams = nil --created by open_file

    --sentence end mode, when a sentence end is met, the stream after will be null
    self.se_mode = (opts.se_mode == true)
    self.compressed_label = (opts.compressed_label == true)
    --same_io can be used to train P(wi|w1..(i-1),(i+1)..n)
    self.same_io = (opts.same_io == true)
end

--fn: string, path of the text file to read
--Opens fn and initializes all streams, the reading statistics and the
--host-side backup matrices.
--Reports through nerv.error when a file is already open, or when fn cannot
--be opened (previously a failed io.open left self.fh nil, so the reader
--silently behaved as if the file were empty).
function LMReader:open_file(fn)
    if (self.fh ~= nil) then
        nerv.error("%s error: in open_file(fn is %s), file handle not nil.", self.log_pre, fn)
    end
    nerv.printf("%s opening file %s...\n", self.log_pre, fn)
    nerv.printf("%s batch_size:%d chunk_size:%d\n", self.log_pre, self.batch_size, self.chunk_size)
    nerv.printf("%s se_mode:%s same_io:%s\n", self.log_pre, tostring(self.se_mode), tostring(self.same_io))

    --io.open returns nil plus an error message on failure
    local fh, err = io.open(fn, "r")
    if fh == nil then
        nerv.error("%s error: in open_file, cannot open file %s: %s", self.log_pre, fn, tostring(err))
    end
    self.fh = fh

    --one queue per batch row: tokens live in store[head..tail]
    self.streams = {}
    for i = 1, self.batch_size, 1 do
        self.streams[i] = {["store"] = {}, ["head"] = 1, ["tail"] = 0}
    end
    self.stat = {} --stat collected during file reading
    self.stat.al_sen_start = true --check whether it's always sentence_start at the beginning of a minibatch
    self.bak_inputs_m = {} --backup MMatrix for temporary storage (then copied to TNN CuMatrix)
    for j = 1, self.chunk_size, 1 do
        self.bak_inputs_m[j] = {}
        self.bak_inputs_m[j][1] = self.gconf.mmat_type(self.batch_size, 1)
        if self.compressed_label == true then
            --compressed labels are stored as one value per batch row
            self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, 1)
        end
        --self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, self.vocab:size()) --since MMatrix does not yet have fill, this m[j][2] is not used
    end
end
   
--id: int, index of the stream to refill
--Refresh stream id: when the stream is empty and a file is open, read one
--line from the file, check that it is cntklm-style (begins and ends with the
--sentence end token, none in between) and append its tokens to the stream.
--When nerv.LMUtil.read_line returns nil the file is closed and self.fh is
--set to nil. Malformed lines are reported through nerv.error.
function LMReader:refresh_stream(id)
    local st = self.streams[id]
    if (st == nil) then
        nerv.error("stream %d does not exist.", id)
    end
    if (st.store[st.head] ~= nil) then return end --stream still has tokens
    if (self.fh == nil) then return end --no file to read from
    local list = nerv.LMUtil.read_line(self.fh)
    if (list == nil) then --file has end
        printf("%s file expires, closing.\n", self.log_pre)
        self.fh:close()
        self.fh = nil
        return
    end

    --some sanity check
    local sen_end = self.vocab.sen_end_token
    if (list[1] ~= sen_end or list[#list] ~= sen_end) then --check for cntklm style input
        nerv.error("%s sentence not begin or end with </s> : %s", self.log_pre, table.tostring(list));
    end
    --a line holding a single </s> passes the check above but can never
    --provide an input/label pair, so report it here instead of letting
    --get_batch fail later with an unrelated message
    if (#list < 2) then
        nerv.error("%s sentence has less than two tokens : %s", self.log_pre, table.tostring(list))
    end
    for i = 2, #list - 1, 1 do
        if (list[i] == sen_end) then
            nerv.error("%s Got </s> in the middle of a line(%s) in file", self.log_pre, table.tostring(list))
        end
    end

    --append the tokens at the tail of the stream
    for i = 1, #list, 1 do
        st.tail = st.tail + 1
        st.store[st.tail] = list[i]
    end
end

--Fill one chunk (chunk_size time steps x batch_size streams) from the streams.
--feeds: a table that will be filled by the reader
--  must already contain: inputs_m, flags_now, flagsPack_now (their contents are overwritten)
--  receives fresh tables: inputs_s (input tokens) and labels_s (label tokens), indexed [time step][stream]
--Returns: bool, true when at least one position got an input/label pair,
--         false otherwise (the statistics are then printed)
function LMReader:get_batch(feeds)
    if (feeds == nil or type(feeds) ~= "table") then
        nerv.error("feeds is not a table")
    end

    --string views of the batch, rebuilt on every call
    feeds["inputs_s"] = {}
    feeds["labels_s"] = {}
    local inputs_s = feeds.inputs_s
    local labels_s = feeds.labels_s
    for i = 1, self.chunk_size, 1 do
        inputs_s[i] = {}
        labels_s[i] = {} 
    end

    local inputs_m = feeds.inputs_m --port 1 : word_id, port 2 : label
    local flags = feeds.flags_now
    local flagsPack = feeds.flagsPack_now

    local got_new = false --becomes true once any position holds a real input/label pair
    --clear the label port; in non-compressed mode single entries are set to 1 below
    for j = 1, self.chunk_size, 1 do
        inputs_m[j][2]:fill(0)
    end
    for i = 1, self.batch_size, 1 do
        local st = self.streams[i]
        local end_stream = false --used for se_mode, indicating that this stream is ended
        for j = 1, self.chunk_size, 1 do
            flags[j][i] = 0
            if end_stream == true then
                --the stream ended earlier in this chunk: pad the remaining steps with null tokens
                if self.se_mode == false then
                    nerv.error("lmseqreader:getbatch: error, end_stream is true while se_mode is false")
                end
                inputs_s[j][i] = self.vocab.null_token
                self.bak_inputs_m[j][1][i - 1][0] = 0 --matrix entries are addressed from 0 here, hence i - 1
                if self.compressed_label == true then
                    self.bak_inputs_m[j][2][i - 1][0] = 0
                end
                labels_s[j][i] = self.vocab.null_token
            else
                self:refresh_stream(i) --reads a new line only when the stream is empty
                if st.store[st.head] ~= nil then
                    if self.same_io == false then 
                        --normal mode: the current token is the input
                        inputs_s[j][i] = st.store[st.head]
                        --presumably vocab ids start at 1 and are shifted to start at 0 -- TODO confirm
                        self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
                    else
                        --same_io mode: the next token is used as input (and, below, as label)
                        --NOTE(review): st.store[st.head + 1] is not checked for nil before the vocab lookup here
                        inputs_s[j][i] = st.store[st.head + 1]
                        self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head + 1]).id - 1                      
                    end
                else
                    --nothing left in the stream (refresh_stream could not provide a new line)
                    inputs_s[j][i] = self.vocab.null_token
                    self.bak_inputs_m[j][1][i - 1][0] = 0
                end
                if st.store[st.head + 1] ~= nil then
                    --the label is the token following the current one
                    labels_s[j][i] = st.store[st.head + 1]
                    if self.compressed_label == true then
                        --compressed label: store the shifted word id itself
                        self.bak_inputs_m[j][2][i - 1][0] = self.vocab:get_word_str(st.store[st.head + 1]).id - 1
                    else
                        --non-compressed label: set the entry at column (id - 1) of row (i - 1) to 1
                        inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1
                    end
                else
                    if inputs_s[j][i] ~= self.vocab.null_token then
                        nerv.error("reader error : input not null but label is null_token")
                    end
                    --NOTE(review): unlike the end_stream branch, bak_inputs_m[j][2] is not reset here in
                    --compressed_label mode, so a value from an earlier call may be copied below;
                    --flags[j][i] stays 0 for this position -- confirm consumers ignore such entries
                    labels_s[j][i] = self.vocab.null_token
                end
                if inputs_s[j][i] ~= self.vocab.null_token then
                    if labels_s[j][i] == self.vocab.null_token then
                        nerv.error("reader error : label is null while input is not null")
                    end
                    flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_NORM) --has both input and label
                    got_new = true
                    --refresh_stream requires lines to begin with sen_end_token, so meeting it at the head marks a sentence start
                    if st.store[st.head] == self.vocab.sen_end_token then
                        flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START)
                    end
                    --consume the current token
                    st.store[st.head] = nil
                    st.head = st.head + 1
                    if labels_s[j][i] == self.vocab.sen_end_token then
                        flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_END)
                        --also consume the closing token, which empties the stream so the next refresh_stream reads a new line
                        st.store[st.head] = nil --sentence end is passed
                        st.head = st.head + 1
                        if self.se_mode == true then
                            end_stream = true --meet sentence end, this stream ends now
                        end
                    end
               end 
            end
        end
    end
    
    for j = 1, self.chunk_size, 1 do
        --flagsPack[j] is the bitwise OR of the flags of all streams at time step j
        flagsPack[j] = 0
        for i = 1, self.batch_size, 1 do
            flagsPack[j] = bit.bor(flagsPack[j], flags[j][i])
        end
        --copy the backup matrices into the matrices handed over in feeds
        inputs_m[j][1]:copy_fromh(self.bak_inputs_m[j][1])
        if self.compressed_label == true then
            inputs_m[j][2]:copy_fromh(self.bak_inputs_m[j][2])
        end
    end

    --check for self.al_sen_start
    --cleared when a stream has a non-zero flag at the first time step without SEQ_START
    for i = 1, self.batch_size do
        if bit.band(flags[1][i], nerv.TNN.FC.SEQ_START) == 0 and flags[1][i] > 0 then
            self.stat.al_sen_start = false
        end
    end

    if got_new == false then
        --no position in this chunk received data
        nerv.info("lmseqreader file ends, printing stats...")
        nerv.printf("al_sen_start:%s\n", tostring(self.stat.al_sen_start))
        return false
    else
        return true
    end
end

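--[[
Stale usage sketch, kept for reference only: it refers to nerv.LMFeeder and calls get_batch() without a feeds table, which does not match the nerv.LMSeqReader methods defined in this file.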
do
    local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/some-text"
    --local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt"
    local vocab = nerv.LMVocab()
    vocab:build_file(test_fn)
    local batch_size = 3
    local feeder = nerv.LMFeeder({}, batch_size, vocab)
    feeder:open_file(test_fn)
    while (1) do
        local list = feeder:get_batch()
        if (list == nil) then break end
        for i = 1, batch_size, 1 do
            printf("%s(%d) ", list[i], vocab:get_word_str(list[i]).id) 
        end
        printf("\n")
    end
end
]]--