aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lmptb/lmseqreader.lua
blob: cc805a417e9ef93abe511e3437d4b7e95cd03e8f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
require 'lmptb.lmvocab'

--Class object obtained from nerv.class under the name "nerv.LMSeqReader"; the methods below are attached to it.
local LMReader = nerv.class("nerv.LMSeqReader")

--Local alias for nerv.printf, used for the log messages in this file.
local printf = nerv.printf

--Set up a reader with no file attached; open_file must be called before reading.
--global_conf: table, kept as self.gconf (open_file reads its mmat_type field)
--batch_size: int, number of streams (index i) kept by the reader
--chunk_size: int, number of positions (index j) filled per stream in each get_batch call
--vocab: nerv.LMVocab
function LMReader:__init(global_conf, batch_size, chunk_size, vocab)
    --prefix put in front of this reader's log and error messages
    self.log_pre = "[LOG]LMSeqReader:"

    --configuration handed in by the caller
    self.gconf = global_conf
    self.vocab = vocab
    self.batch_size = batch_size
    self.chunk_size = chunk_size

    --state that only exists while a file is open
    self.fh = nil --file handle to read from, nil means no file is currently open
    self.streams = nil --per-stream token buffers, created by open_file
end

--Open a file for reading and (re)initialize all streams and backup matrices.
--fn: string, path of the file to read
--Errors are reported through nerv.error when a file is already open, or when
--fn cannot be opened (previously a failed open left self.fh nil, so the reader
--silently produced no data).
function LMReader:open_file(fn)
    if (self.fh ~= nil) then
        nerv.error("%s error: in open_file(fn is %s), file handle not nil.", self.log_pre, fn)
    end
    printf("%s opening file %s...\n", self.log_pre, fn)
    print("batch_size:", self.batch_size, "chunk_size", self.chunk_size)

    --io.open returns nil plus a message on failure; do not store a nil handle silently
    local fh, err = io.open(fn, "r")
    if (fh == nil) then
        nerv.error("%s error: in open_file, cannot open file %s: %s", self.log_pre, fn, tostring(err))
    end
    self.fh = fh

    --one token buffer per stream: "store" holds tokens at indices head..tail
    self.streams = {}
    for i = 1, self.batch_size, 1 do
        self.streams[i] = {["store"] = {}, ["head"] = 1, ["tail"] = 0}
    end

    self.bak_inputs_m = {} --backup MMatrix for temporary storage (then copied to TNN CuMatrix)
    for j = 1, self.chunk_size, 1 do
        self.bak_inputs_m[j] = {}
        self.bak_inputs_m[j][1] = self.gconf.mmat_type(self.batch_size, 1)
        self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, self.vocab:size()) --since MMatrix does not yet have fill, this m[j][2] is not used
    end
end
   
--Refill stream id with one line from the file when the stream is empty.
--id: int, index into self.streams
--Does nothing if the stream still has a token at its head or if no file is open.
--When the file has no more lines it is closed and self.fh is set back to nil.
--Every line must be cntklm-style: it begins and ends with the sentence-end
--token, has no sentence-end token in between, and has at least two tokens.
function LMReader:refresh_stream(id)
    local st = self.streams[id]
    if (st == nil) then
        nerv.error("stream %d does not exist.", id)
    end
    if (st.store[st.head] ~= nil) then return end --still has unread tokens
    if (self.fh == nil) then return end --no file to read from
    local list = self.vocab:read_line(self.fh)
    if (list == nil) then --file has ended
        printf("%s file expires, closing.\n", self.log_pre)
        self.fh:close()
        self.fh = nil
        return
    end

    --some sanity check
    local sen_end = self.vocab.sen_end_token
    if (list[1] ~= sen_end or list[#list] ~= sen_end) then --check for cntklm style input
        nerv.error("%s sentence not begin or end with </s> : %s", self.log_pre, table.tostring(list));
    end
    --a lone </s> passes the check above but gives get_batch an input with no label;
    --reject it here where the offending line can still be reported
    if (#list < 2) then
        nerv.error("%s sentence has fewer than two tokens : %s", self.log_pre, table.tostring(list))
    end
    for i = 2, #list - 1, 1 do
        if (list[i] == sen_end) then
            nerv.error("%s Got </s> in the middle of a line(%s) in file", self.log_pre, table.tostring(list))
        end
    end

    --append the whole line to the tail of the stream buffer
    for i = 1, #list, 1 do
        st.tail = st.tail + 1
        st.store[st.tail] = list[i]
    end
end

--Fill feeds with the next chunk of tokens taken from every stream.
--feeds: table; it must already contain inputs_m, flags_now and flagsPack_now,
--       each indexed by chunk position j. inputs_s and labels_s are replaced
--       by freshly created tables on every call.
--Returns: bool, true if at least one non-null input token was produced
function LMReader:get_batch(feeds)
    if type(feeds) ~= "table" then
        nerv.error("feeds is not a table")
    end

    local n_batch = self.batch_size
    local n_chunk = self.chunk_size

    --string views of the batch: inputs_s[j][i] / labels_s[j][i]
    local inputs_s, labels_s = {}, {}
    feeds["inputs_s"] = inputs_s
    feeds["labels_s"] = labels_s
    for j = 1, n_chunk do
        inputs_s[j] = {}
        labels_s[j] = {}
    end

    local inputs_m = feeds.inputs_m --port 1 : word_id, port 2 : label
    local flags = feeds.flags_now
    local flagsPack = feeds.flagsPack_now

    --clear the label matrices before marking the labels of this chunk
    for j = 1, n_chunk do
        inputs_m[j][2]:fill(0)
    end

    local got_new = false
    local vocab = self.vocab
    local bak = self.bak_inputs_m
    for i = 1, n_batch do
        local st = self.streams[i]
        for j = 1, n_chunk do
            flags[j][i] = 0
            self:refresh_stream(i)
            local null_token = vocab.null_token

            --input: the token at the head of the stream, or the null token if there is none
            local cur = st.store[st.head]
            local in_tok
            if cur == nil then
                in_tok = null_token
                inputs_s[j][i] = in_tok
                bak[j][1][i - 1][0] = 0
            else
                in_tok = cur
                inputs_s[j][i] = in_tok
                --matrix row is i - 1 and the stored id is id - 1
                bak[j][1][i - 1][0] = vocab:get_word_str(cur).id - 1
            end

            --label: the token right after the head, set to 1 in the label matrix
            local nxt = st.store[st.head + 1]
            local lab_tok
            if nxt == nil then
                if in_tok ~= null_token then
                    nerv.error("reader error : input not null but label is null_token")
                end
                lab_tok = null_token
                labels_s[j][i] = lab_tok
            else
                lab_tok = nxt
                labels_s[j][i] = lab_tok
                inputs_m[j][2][i - 1][vocab:get_word_str(nxt).id - 1] = 1
            end

            if in_tok ~= null_token then
                if lab_tok == null_token then
                    nerv.error("reader error : label is null while input is not null")
                end
                flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_NORM)
                got_new = true
                --consume the input token
                st.store[st.head] = nil
                st.head = st.head + 1
                local sen_end = vocab.sen_end_token
                if lab_tok == sen_end then
                    flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_END)
                    --sentence end is passed: drop it as well so it is not used as an input
                    st.store[st.head] = nil
                    st.head = st.head + 1
                end
                if in_tok == sen_end then
                    flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START)
                end
            end
        end
    end

    --OR together the flags of all streams per position, then copy the word ids
    --from the backup matrix into the matrix supplied by the caller
    for j = 1, n_chunk do
        flagsPack[j] = 0
        for i = 1, n_batch do
            flagsPack[j] = bit.bor(flagsPack[j], flags[j][i])
        end
        inputs_m[j][1]:copy_fromh(bak[j][1])
    end

    return got_new
end

--[[
do
    local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/some-text"
    --local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt"
    local vocab = nerv.LMVocab()
    vocab:build_file(test_fn)
    local batch_size = 3
    local feeder = nerv.LMFeeder({}, batch_size, vocab)
    feeder:open_file(test_fn)
    while (1) do
        local list = feeder:get_batch()
        if (list == nil) then break end
        for i = 1, batch_size, 1 do
            printf("%s(%d) ", list[i], vocab:get_word_str(list[i]).id) 
        end
        printf("\n")
    end
end
]]--