-- Source: nerv/examples/ptb/reader.lua
-- (blob 70c0c978b4dc499e39dfae4eae50d2a1816a21c4)
-- Reader: loads a vocabulary file and a text file of space-separated words,
-- converts each line into a list of word ids, and hands the lines out one at a
-- time through Reader:get_data(). The class is registered under the name
-- 'nerv.Reader' via nerv.class.
local Reader = nerv.class('nerv.Reader')

-- Constructor.
--   vocab_file: path of the vocabulary file (one word per line).
--   input_file: path of the text file holding the sequences.
-- After construction self.vocab, self.size and self.seq are populated and the
-- read cursor points at the first sequence.
function Reader:__init(vocab_file, input_file)
    -- Cursor into self.seq; get_data() advances it by one per call.
    self.offset = 1
    -- The vocabulary must be loaded first: get_seq() looks words up in it.
    self:get_vocab(vocab_file)
    self:get_seq(input_file)
end

-- Load the vocabulary from `vocab_file`.
-- Each line of the file is taken verbatim as one word; words are numbered in
-- file order starting from 0.
--   vocab_file: path of the vocabulary file.
-- Sets:
--   self.vocab: table mapping word -> id.
--   self.size:  number of lines read (a word that appears twice keeps the id
--               of its last occurrence but is still counted twice).
-- Raises an error carrying the io.open message if the file cannot be opened.
function Reader:get_vocab(vocab_file)
    -- assert() turns io.open's (nil, message) result into a readable error
    -- instead of a later "attempt to index a nil value".
    local f = assert(io.open(vocab_file, 'r'))
    local id = 0
    self.vocab = {}
    for word in f:lines() do
        self.vocab[word] = id
        id = id + 1
    end
    -- Release the handle explicitly rather than waiting for the GC.
    f:close()
    self.size = id
end

-- Split string `s` on every occurrence of the literal delimiter `t`.
--   s: string to split.
--   t: delimiter, matched literally.
-- Returns an array of the pieces between delimiters. Adjacent delimiters
-- produce empty-string pieces, e.g. split('a  b', ' ') -> {'a', '', 'b'}.
function Reader:split(s, t)
    -- `t` is embedded in a Lua pattern below, so escape every punctuation
    -- character; otherwise delimiters such as '.', '-' or '%' would be read
    -- as pattern operators. (Assigning to one local keeps only gsub's first
    -- return value.)
    local sep = t:gsub('%p', '%%%0')
    local ret = {}
    -- Appending the raw delimiter to `s` guarantees the last piece is also
    -- terminated, so the lazy capture picks it up.
    for x in (s .. t):gmatch('(.-)' .. sep) do
        table.insert(ret, x)
    end
    return ret
end

-- Load word-id sequences from `input_file`, one sequence per line.
-- Every line is split on single spaces; empty pieces are ignored and each
-- remaining word is replaced by its id from self.vocab. Words that are not in
-- the vocabulary are skipped.
--   input_file: path of the text file.
--   max_lines:  optional maximum number of lines to read. Defaults to 26, the
--               limit that used to be hard-coded here; pass math.huge to read
--               the whole file.
-- Sets:
--   self.seq: array of sequences, each an array of word ids. A line with no
--             known words yields an empty sequence.
-- Requires self.vocab to have been filled by get_vocab().
-- Raises an error carrying the io.open message if the file cannot be opened.
function Reader:get_seq(input_file, max_lines)
    local limit = max_lines or 26
    local f = assert(io.open(input_file, 'r'))
    self.seq = {}
    local count = 0
    while count < limit do
        local line = f:read()
        if line == nil then
            -- End of file reached before the limit.
            break
        end
        count = count + 1
        local ids = {}
        for _, word in ipairs(self:split(line, ' ')) do
            if word ~= '' then
                local id = self.vocab[word]
                -- Out-of-vocabulary words contribute nothing to the sequence.
                if id ~= nil then
                    ids[#ids + 1] = id
                end
            end
        end
        self.seq[#self.seq + 1] = ids
    end
    f:close()
end

-- Return the next sequence as an (input, label) pair and advance the cursor.
-- For a sequence of N ids the result holds two nerv.MMatrixFloat objects
-- created with dimensions (N - 1, 1):
--   input: ids 1 .. N-1 of the sequence,
--   label: ids 2 .. N   of the sequence (each input's successor).
-- Returns nil once every sequence in self.seq has been handed out.
-- NOTE(review): a sequence with fewer than two ids requests a matrix with
-- zero or negative rows; how nerv.MMatrixFloat reacts is not visible here.
function Reader:get_data()
    local cursor = self.offset
    if cursor > #self.seq then
        return nil
    end
    local ids = self.seq[cursor]
    local steps = #ids - 1
    local input = nerv.MMatrixFloat(steps, 1)
    local label = nerv.MMatrixFloat(steps, 1)
    -- The matrices are addressed from 0 here, the Lua table from 1.
    for row = 0, steps - 1 do
        input[row][0] = ids[row + 1]
        label[row][0] = ids[row + 2]
    end
    self.offset = cursor + 1
    return {
        input = input,
        label = label,
    }
end