aboutsummaryrefslogtreecommitdiff
path: root/nerv/io/seq_buffer.lua
blob: 029e7b83ae7a9c966b881e336facbf9399f84db9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
-- SeqBuffer: declared via nerv.class as 'nerv.SeqBuffer' with base
-- 'nerv.DataBuffer'. It packs sequences pulled from a set of readers into
-- mini-batches of `batch_size` parallel lanes, cutting each sequence into
-- chunks of at most `chunk_size` time steps (see saturate/get_data below).
local SeqBuffer = nerv.class('nerv.SeqBuffer', 'nerv.DataBuffer')

--- Build a SeqBuffer.
-- @param global_conf global configuration table; its `mmat_type` is the
--   matrix constructor used for per-step buffers, and get_data later
--   consults `use_cpu` and `cumat_type`.
-- @param buffer_conf options table:
--   `batch_size`     number of parallel lanes per mini-batch,
--   `chunk_size`     maximum number of time steps per mini-batch,
--   `readers`        list of entries whose `reader` field exposes get_data(),
--   `nn_act_default` optional padding value (only nil falls back to 0).
function SeqBuffer:__init(global_conf, buffer_conf)
    self.gconf = global_conf
    self.batch_size = buffer_conf.batch_size
    self.chunk_size = buffer_conf.chunk_size

    -- unwrap each reader entry; appending at #t + 1 mirrors table.insert
    local readers = {}
    self.readers = readers
    for _, entry in ipairs(buffer_conf.readers) do
        readers[#readers + 1] = entry.reader
    end

    -- explicit nil test (not `or 0`) so a deliberate `false` is kept as-is
    local default_act = buffer_conf.nn_act_default
    if default_act == nil then
        default_act = 0
    end
    self.nn_act_default = default_act

    self.mat_type = global_conf.mmat_type

    -- queue of pending mini-batches, live entries are queue[head .. tail]
    self.queue = {}
    self.head, self.tail = 1, 0
end

--- Allocate an empty mini-batch record.
-- `data` will map each reader field id to a list of per-step matrices,
-- `new_seq` lists the lanes in which a sequence begins, and
-- `seq_length[lane]` counts the valid steps of each lane (all start at 0).
-- @return the fresh record
function SeqBuffer:new_mini_batch()
    local lengths = {}
    for lane = 1, self.batch_size do
        lengths[lane] = 0
    end
    return { data = {}, new_seq = {}, seq_length = lengths }
end

-- Pull the next sequence from every reader and merge their fields into one
-- table keyed by field id. Returns that table and the row count shared by
-- all fields, or nil as soon as any reader returns nil (exhausted).
local function read_sequence(readers)
    local data = {}
    local drow = nil
    for i = 1, #readers do
        local tmp = readers[i]:get_data()
        if tmp == nil then
            return nil
        end
        for id, d in pairs(tmp) do
            if drow == nil then
                drow = d:nrow()
            elseif d:nrow() ~= drow then
                nerv.error('readers provides with inconsistent rows of data')
            end
            data[id] = d
        end
    end
    -- no readers, or readers yielding no fields: nothing to size chunks by
    if drow == nil then
        nerv.error('readers provide no data fields')
    end
    return data, drow
end

--- Ensure lane `batch` holds data in the head mini-batch of the queue.
-- If the lane is already occupied there (a sequence read earlier is still
-- being emitted), nothing is read. Otherwise the next non-empty sequence is
-- pulled from the readers and cut into chunks of at most `chunk_size` rows.
-- Chunk k goes into lane `batch` of mini-batch `self.head + k - 1`, and fresh
-- mini-batches are appended at the queue tail as needed. Zero-length
-- sequences are skipped: otherwise this would report success without
-- placing any data in the queue.
-- @param batch lane index, 1 .. batch_size
-- @return true if the lane has data at the head, false once a reader is
--   exhausted
function SeqBuffer:saturate(batch)
    local head_mb = self.queue[self.head]
    if head_mb ~= nil and head_mb.seq_length[batch] ~= 0 then
        return true
    end

    -- a non-positive chunk size would never advance the chunking loop below
    local chunk_size = self.chunk_size
    if type(chunk_size) ~= 'number' or chunk_size < 1 then
        nerv.error('chunk_size must be a positive number')
    end

    local data, drow
    repeat
        data, drow = read_sequence(self.readers)
        if data == nil then
            return false
        end
    until drow > 0

    local offset = 0
    local head = self.head
    while offset < drow do
        local last = math.min(offset + chunk_size, drow)
        if head > self.tail then
            self.tail = self.tail + 1
            self.queue[self.tail] = self:new_mini_batch()
        end
        local mb = self.queue[head]
        mb.seq_length[batch] = last - offset
        if offset == 0 then
            -- mark the lane as starting a new sequence in this mini-batch
            table.insert(mb.new_seq, batch)
        end
        for id, d in pairs(data) do
            local steps = mb.data[id]
            if steps == nil then
                steps = {}
                mb.data[id] = steps
            end
            for i = offset + 1, last do
                local step = i - offset
                if steps[step] == nil then
                    -- lanes never written keep the padding value
                    steps[step] = self.mat_type(self.batch_size, d:ncol())
                    steps[step]:fill(self.nn_act_default)
                end
                -- NOTE(review): presumably copies source row i-1 into
                -- destination row batch-1 (0-based); confirm copy_from args
                steps[step]:copy_from(d, i - 1, i, batch - 1)
            end
        end
        head = head + 1
        offset = last
    end
    return true
end

--- Produce the next mini-batch, or nil when no lane could be filled.
-- Every lane is saturated first, in lane order, so readers are consumed
-- in a fixed order. The head mini-batch is then removed from the queue.
-- Unless `gconf.use_cpu` is set, each per-step matrix is replaced by the
-- result of `gconf.cumat_type.new_from_host` (presumably a device copy).
-- @return record with `data`, `new_seq` and `seq_length` fields, or nil
function SeqBuffer:get_data()
    local any_lane_ready = false
    for lane = 1, self.batch_size do
        -- saturate() must run for every lane, so evaluate it before the `or`
        local ready = self:saturate(lane)
        any_lane_ready = any_lane_ready or ready
    end
    if not any_lane_ready then
        return nil
    end

    local head = self.head
    local batch = self.queue[head]
    self.queue[head] = nil
    self.head = head + 1

    if self.gconf.use_cpu then
        return batch
    end
    for _, steps in pairs(batch.data) do
        for step = 1, #steps do
            steps[step] = self.gconf.cumat_type.new_from_host(steps[step])
        end
    end
    return batch
end