aboutsummaryrefslogtreecommitdiff
path: root/nerv/io/sgd_buffer.lua
blob: 65d6da125f9da5a8675525ea5e7298116b3bcdb8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer")

--- Construct a buffer that pre-loads rows from one or more readers and
-- serves them to `get_data` in batches of `global_conf.batch_size` rows.
-- @tparam table global_conf global configuration; this constructor reads
--   `batch_size`, `cumat_type` and `mmat_type`
-- @tparam table buffer_conf buffer configuration:
--   `buffer_size` - capacity in rows, rounded down to a multiple of the
--     batch size (must be at least one batch);
--   `randomize` - optional, serve rows through a shuffled index
--     (defaults to false);
--   `use_gpu` - keep the buffer in `cumat_type` matrices instead of
--     `mmat_type` ones;
--   `readers` - list of `{reader = <object with get_data()>,
--     data = {<id> = <width>, ...}}`
function SGDBuffer:__init(global_conf, buffer_conf)
    self.gconf = global_conf
    local batch_size = global_conf.batch_size
    -- Round the capacity down to whole batches so a full buffer is
    -- drained exactly by successive get_data calls.
    local n_batches = math.floor(buffer_conf.buffer_size / batch_size)
    self.buffer_size = n_batches * batch_size
    if self.buffer_size < batch_size then
        -- A buffer smaller than one batch can never satisfy
        -- `tail >= batch_size`, so get_data would never return data.
        nerv.error("buffer_size (%s) must be at least batch_size (%s)",
                   tostring(buffer_conf.buffer_size), tostring(batch_size))
    end
    self.randomize = buffer_conf.randomize
    if self.randomize == nil then
        self.randomize = false
    end
    local cumat_type = global_conf.cumat_type
    -- Batches handed out by get_data are always `cumat_type` matrices; the
    -- copy helpers below are chosen to match where the buffer itself lives
    -- (the `d`/`h` suffixes presumably mean device/host source - confirm).
    if buffer_conf.use_gpu then
        self.mat_type = cumat_type
        self.copy_rows_from_by_idx = cumat_type.copy_rows_fromd_by_idx
        self.copy_from = cumat_type.copy_fromd
        self.copy_from_reader = cumat_type.copy_fromh
        self.perm_gen = function (x)
            return cumat_type.new_from_host(nerv.MMatrixFloat.perm_gen(x))
        end
    else
        self.mat_type = global_conf.mmat_type
        self.copy_rows_from_by_idx = cumat_type.copy_rows_fromh_by_idx
        self.copy_from = cumat_type.copy_fromh
        self.perm_gen = nerv.MMatrixFloat.perm_gen
        self.copy_from_reader = self.mat_type.copy_from
    end
    -- [head, tail) is the range of buffered rows not yet served.
    self.head = 0
    self.tail = 0
    self.readers = {}
    for _, reader_spec in ipairs(buffer_conf.readers) do
        local buffs = {}
        for id, width in pairs(reader_spec.data) do
            -- `leftover` holds rows of a chunk that did not fit during the
            -- previous refill (see saturate).
            buffs[id] = {data = self.mat_type(self.buffer_size, width),
                         leftover = nil,
                         width = width}
        end
        table.insert(self.readers, {buffs = buffs,
                                    reader = reader_spec.reader,
                                    tail = 0,
                                    has_leftover = false})
    end
end

--- Refill every reader's buffers and rewind the read cursor.
-- For each reader:
--   - Rows left over from the previous refill are copied to the front of
--     its buffers first.
--   - The reader is then polled with `get_data()` until its buffers hold
--     `buffer_size` rows or it returns nil.
--   - Rows of a chunk that do not fit are stashed as the reader's new
--     leftover for the next refill.
-- `self.tail` becomes the smallest fill level over all readers. A shuffled
-- index over those `tail` rows is stored in `self.rand_map`.
-- Every chunk must contain an entry for each data id the reader was
-- configured with, and all entries in a chunk must have the same number
-- of rows; violations are raised via `nerv.error`.
-- @treturn boolean true when at least `batch_size` rows are buffered
function SGDBuffer:saturate()
    local buffer_size = self.buffer_size
    self.head = 0
    self.tail = buffer_size
    for _, reader in ipairs(self.readers) do
        reader.tail = 0
        if reader.has_leftover then
            local lrow = 0
            for _, buff in pairs(reader.buffs) do
                lrow = buff.leftover:nrow()
                if lrow > buffer_size then
                    nerv.error("buffer size is too small to contain leftovers")
                end
                buff.data:copy_from(buff.leftover, 0, lrow)
                buff.leftover = nil
            end
            nerv.info("buffer leftover: %d", lrow)
            reader.tail = lrow
            reader.has_leftover = false
        end
        while reader.tail < buffer_size do
            local data = reader.reader:get_data()
            if data == nil then
                break -- reader exhausted
            end
            -- Validate the chunk before copying anything. A missing id
            -- would otherwise reach copy_from_reader as nil. It could also
            -- be detected only after some ids already had a leftover
            -- allocated.
            for id in pairs(reader.buffs) do
                if data[id] == nil then
                    nerv.error("reader does not provide data for %s", id)
                end
            end
            local drow = nil
            for _, d in pairs(data) do
                if drow == nil then
                    drow = d:nrow()
                elseif d:nrow() ~= drow then
                    nerv.error("reader provides with inconsistent rows of data")
                end
            end
            -- A chunk with no matrices contributes no rows.
            drow = drow or 0
            local remain = buffer_size - reader.tail
            if drow > remain then
                -- Keep the drow - remain rows that do not fit for the
                -- next refill.
                for id, buff in pairs(reader.buffs) do
                    buff.leftover = self.mat_type(drow - remain, buff.width)
                    self.copy_from_reader(buff.leftover, data[id], remain, drow)
                end
                drow = remain
                reader.has_leftover = true
            end
            -- Append the (possibly truncated) chunk after the rows already
            -- buffered for this reader.
            for id, buff in pairs(reader.buffs) do
                self.copy_from_reader(buff.data, data[id], 0, drow, reader.tail)
            end
            reader.tail = reader.tail + drow
        end
        self.tail = math.min(self.tail, reader.tail)
    end
    self.rand_map = self.perm_gen(self.tail) -- generate shuffled index
    collectgarbage("collect")
    return self.tail >= self.gconf.batch_size
end

--- Return the next batch, refilling the buffer once it is drained.
-- One `gconf.cumat_type` matrix of `batch_size` rows is allocated per data
-- id across all readers. Rows are copied in order starting at `head`. When
-- `randomize` is set, they are instead gathered through the shuffled index
-- `rand_map`, starting at offset `head`.
-- @treturn table|nil map from data id to batch matrix, or nil once the
--   remaining buffered data cannot fill a whole batch
function SGDBuffer:get_data()
    local gconf = self.gconf
    local batch_size = gconf.batch_size
    if self.head >= self.tail and not self:saturate() then
        return nil -- the refilled buffer holds less than one batch
    end
    -- buffer_size is a multiple of batch_size. A partial batch can
    -- therefore only remain when a reader ran out of data.
    if self.head + batch_size > self.tail then
        return nil -- the remaining data cannot build a batch
    end
    local head = self.head
    local res = {}
    for _, reader in ipairs(self.readers) do
        for id, buff in pairs(reader.buffs) do
            local batch = gconf.cumat_type(batch_size, buff.width)
            if self.randomize then
                self.copy_rows_from_by_idx(batch, buff.data, self.rand_map, head)
            else
                self.copy_from(batch, buff.data, head, head + batch_size)
            end
            res[id] = batch
        end
    end
    self.head = head + batch_size
    return res
end