summaryrefslogtreecommitdiff
path: root/lua/reader.lua
blob: 2e51a9cf4d8aa7528a7a821bec9cdc046f5c5d55 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
-- Reader class, created through nerv.class under the name 'nerv.Reader'.
-- The methods defined on it below load a vocabulary file and a text file of
-- sequences, then pack the sequences into mini-batches.
local Reader = nerv.class('nerv.Reader')

--- Builds a reader from a vocabulary file and a sequence file.
-- @param vocab_file path handed to `get_vocab`
-- @param input_file path handed to `get_seq`
function Reader:__init(vocab_file, input_file)
    -- Order matters: get_seq translates words through self.vocab, which
    -- get_vocab is responsible for populating.
    self:get_vocab(vocab_file)

    self:get_seq(input_file)
end

--- Loads the vocabulary: one word per line, numbered from 0 in file order.
--
-- Fills `self.vocab` (word -> id) and sets `self.size` to the number of lines
-- read. A word that appears on several lines keeps the id of its last
-- occurrence, while every line still consumes one id.
--
-- @param vocab_file path of the vocabulary file
-- @raise error carrying the message from `io.open` when the file cannot be
--   opened
function Reader:get_vocab(vocab_file)
    -- assert() turns the (nil, message) failure result of io.open into an
    -- error that names the file instead of a later "index a nil value".
    local f = assert(io.open(vocab_file, 'r'))
    local id = 0
    self.vocab = {}
    for word in f:lines() do
        self.vocab[word] = id
        id = id + 1
    end
    -- Release the handle explicitly rather than waiting for the collector.
    f:close()
    self.size = id
end

--- Splits `s` at every match of the separator `t`.
--
-- `t` is spliced into a Lua pattern unescaped, so pattern magic characters in
-- it keep their special meaning. Empty fields are preserved (two adjacent
-- separators yield an empty string between them).
--
-- @param s string to split
-- @param t separator, interpreted as a Lua pattern
-- @return array of the pieces, in order
function Reader:split(s, t)
    local pieces = {}
    -- Appending one separator to the subject lets the lazy capture also
    -- terminate on the final field.
    local subject = s .. t
    local field_pattern = '(.-)' .. t
    for piece in string.gmatch(subject, field_pattern) do
        pieces[#pieces + 1] = piece
    end
    return pieces
end

--- Loads the sequences: one space-separated sequence of words per line.
--
-- Fills `self.seq` with one array of vocabulary ids per line of the file.
-- Empty tokens (from repeated spaces) are ignored, and words that have no
-- entry in `self.vocab` are left out of the sequence. A line with no known
-- word yields an empty array.
--
-- @param input_file path of the sequence file
-- @raise error carrying the message from `io.open` when the file cannot be
--   opened
function Reader:get_seq(input_file)
    -- assert() turns the (nil, message) failure result of io.open into an
    -- error that names the file instead of a later "index a nil value".
    local f = assert(io.open(input_file, 'r'))
    self.seq = {}
    for line in f:lines() do
        local ids = {}
        for _, word in ipairs(self:split(line, ' ')) do
            if word ~= '' then
                local id = self.vocab[word]
                -- Unknown words are skipped on purpose: appending nil would
                -- not grow the array anyway, so spell the intent out.
                if id ~= nil then
                    ids[#ids + 1] = id
                end
            end
        end
        self.seq[#self.seq + 1] = ids
    end
    -- Release the handle explicitly rather than waiting for the collector.
    f:close()
end

--- Returns the (input, target) pair at one position of a sequence.
-- @param id index into `self.seq`
-- @param pos position of the input token inside that sequence
-- @return the token at `pos`
-- @return the token at `pos + 1`
-- @return true exactly when `pos + 1` equals the sequence length, i.e. the
--   target is the last token of the sequence
function Reader:get_in_out(id, pos)
    local tokens = self.seq[id]
    local next_pos = pos + 1
    local is_last = (next_pos == #tokens)
    return tokens[pos], tokens[next_pos], is_last
end

--- Packs the loaded sequences into a list of fixed-size mini-batches.
--
-- Every batch has `batch_size` rows and `chunk_size` time steps. A row streams
-- one sequence at a time; a sequence that does not finish within
-- `chunk_size` steps continues in the same row of the next batch. Cells that
-- are not written keep the fill value `global_conf.nn_act_default`.
--
-- Sequences with fewer than two tokens are skipped: they contain no
-- (input, target) pair, and `get_in_out` would never report their end.
--
-- @param global_conf table read for `batch_size`, `chunk_size`, `mmat_type`,
--   `cumat_type` and `nn_act_default`
-- @return array of batches; each batch is a table with
--   `input` / `output` (arrays of `chunk_size` matrices, converted with
--   `cumat_type.new_from_host`), and the per-row arrays `seq_start`
--   (a new sequence began in this batch), `seq_end` (the sequence finished
--   in this batch) and `seq_len` (number of steps written for the row)
function Reader:get_all_batch(global_conf)
    local batch_size = global_conf.batch_size
    local chunk_size = global_conf.chunk_size
    local data = {}
    -- pos[i] = {sequence index, token position} for the sequence row i is
    -- currently streaming; nil while the row is idle.
    local pos = {}
    -- Index of the next sequence in self.seq that has not been handed out.
    local offset = 1
    while true do
        local input = {}
        local output = {}
        for i = 1, chunk_size do
            input[i] = global_conf.mmat_type(batch_size, 1)
            input[i]:fill(global_conf.nn_act_default)
            output[i] = global_conf.mmat_type(batch_size, 1)
            output[i]:fill(global_conf.nn_act_default)
        end
        local seq_start = {}
        local seq_end = {}
        local seq_len = {}
        for i = 1, batch_size do
            seq_start[i] = false
            seq_end[i] = false
            seq_len[i] = 0
        end
        local has_new = false
        for i = 1, batch_size do
            -- Idle row: take the next usable sequence. The bound is inclusive
            -- so the last sequence of the file is batched as well.
            while pos[i] == nil and offset <= #self.seq do
                if #self.seq[offset] >= 2 then
                    seq_start[i] = true
                    pos[i] = {offset, 1}
                end
                offset = offset + 1
            end
            if pos[i] ~= nil then
                has_new = true
                for j = 1, chunk_size do
                    local in_id, out_id, final = self:get_in_out(pos[i][1], pos[i][2])
                    -- Row i is addressed as [i - 1][0]; presumably the host
                    -- matrices are 0-indexed -- TODO confirm against nerv.
                    input[j][i - 1][0] = in_id
                    output[j][i - 1][0] = out_id
                    seq_len[i] = j
                    if final then
                        -- Sequence exhausted: free the row for the next batch.
                        seq_end[i] = true
                        pos[i] = nil
                        break
                    end
                    pos[i][2] = pos[i][2] + 1
                end
            end
        end
        -- No row had anything to write: all sequences are consumed.
        if not has_new then
            break
        end
        for i = 1, chunk_size do
            input[i] = global_conf.cumat_type.new_from_host(input[i])
            output[i] = global_conf.cumat_type.new_from_host(output[i])
        end
        table.insert(data, {input = input, output = output, seq_start = seq_start, seq_end = seq_end, seq_len = seq_len})
    end
    return data
end