1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
|
local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer")

--- Create a buffer that pools rows from one or more readers and serves
-- them as mini-batches of `global_conf.batch_size` rows.
-- @tparam table global_conf global configuration; this constructor reads
--   `batch_size` and `mmat_type` and keeps the table as `self.gconf`
-- @tparam table buffer_conf buffer configuration:
--   `buffer_size` (rows; rounded down to a multiple of the batch size),
--   `randomize` (optional, defaults to false) and `readers`, a list of
--   specs each holding a `reader` and a `data` table mapping id -> width
-- Raises (via nerv.error) when `buffer_size` rounds down to fewer rows than
-- one batch, since such a buffer could never yield any data.
function SGDBuffer:__init(global_conf, buffer_conf)
    local batch_size = global_conf.batch_size
    self.gconf = global_conf
    -- keep the buffer an exact multiple of the batch size so that a full
    -- buffer is consumed by whole batches
    self.buffer_size = math.floor(buffer_conf.buffer_size / batch_size)
                       * batch_size
    if self.buffer_size < batch_size then
        -- a zero-row buffer would make every refill report "no batch"
        -- without ever polling the readers
        nerv.error("buffer size (%s) must be at least the batch size (%s)",
                   buffer_conf.buffer_size, batch_size)
    end
    self.randomize = buffer_conf.randomize
    if self.randomize == nil then
        self.randomize = false
    end
    -- head: next row to serve; tail: number of valid rows in the buffer
    self.head = 0
    self.tail = 0
    self.readers = {}
    for _, reader_spec in ipairs(buffer_conf.readers) do
        -- one host matrix of buffer_size rows per data id of this reader
        local buffs = {}
        for id, width in pairs(reader_spec.data) do
            buffs[id] = {data = global_conf.mmat_type(self.buffer_size, width),
                         leftover = {},
                         width = width}
        end
        table.insert(self.readers, {buffs = buffs,
                                    reader = reader_spec.reader,
                                    tail = 0,
                                    has_leftover = false})
    end
end
--- Refill the buffer from every reader, starting again at row 0.
-- Rows that did not fit during the previous refill (a reader's leftover)
-- are copied in first. The reader is then polled until its region holds
-- `buffer_size` rows or `get_data` returns nil. Rows of the last chunk that
-- overflow the buffer are stashed as the new leftover. `self.tail` becomes
-- the smallest fill level across all readers, and a shuffled index over
-- those rows is stored in `self.rand_map`.
-- Raises (via nerv.error) when a leftover exceeds the buffer, when a chunk's
-- matrices disagree on row count, when a chunk lacks a buffered id, or when
-- a chunk contains no matrices at all.
-- @treturn boolean true if at least one full batch is now buffered
function SGDBuffer:saturate()
    local buffer_size = self.buffer_size
    self.head = 0
    self.tail = buffer_size
    for _, reader in ipairs(self.readers) do
        reader.tail = 0
        if reader.has_leftover then
            -- move the rows that overflowed last time to the buffer front
            local lrow
            for _, buff in pairs(reader.buffs) do
                lrow = buff.leftover:nrow()
                if lrow > buffer_size then
                    nerv.error("buffer size is too small to contain leftovers")
                end
                buff.data:copy_from(buff.leftover, 0, lrow)
                buff.leftover = nil
            end
            nerv.utils.printf("leftover: %d\n", lrow)
            reader.tail = lrow
            reader.has_leftover = false
        end
        while reader.tail < buffer_size do
            local data = reader.reader:get_data()
            if data == nil then
                break -- presumably the reader has no more data
            end
            -- all matrices of one chunk must share a row count
            local drow = nil
            for _, d in pairs(data) do
                if drow == nil then
                    drow = d:nrow()
                elseif d:nrow() ~= drow then
                    nerv.error("reader provides with inconsistent rows of data")
                end
            end
            -- validate up front so both copy paths below see every id
            for id in pairs(reader.buffs) do
                if data[id] == nil then
                    nerv.error("reader does not provide data for %s", id)
                end
            end
            if drow == nil then
                -- an empty chunk would otherwise fail on `drow > remain`
                nerv.error("reader provides empty data")
            end
            local remain = buffer_size - reader.tail
            if drow > remain then
                -- stash the rows past `remain` for the next refill
                for id, buff in pairs(reader.buffs) do
                    buff.leftover = self.gconf.mmat_type(drow - remain,
                                                         buff.width)
                    buff.leftover:copy_from(data[id], remain, drow)
                end
                drow = remain
                reader.has_leftover = true
            end
            for id, buff in pairs(reader.buffs) do
                buff.data:copy_from(data[id], 0, drow, reader.tail)
            end
            reader.tail = reader.tail + drow
        end
        -- only rows present in every reader's buffer can be served
        self.tail = math.min(self.tail, reader.tail)
    end
    self.rand_map = nerv.MMatrixInt.perm_gen(self.tail) -- generate shuffled index
    collectgarbage("collect")
    return self.tail >= self.gconf.batch_size
end
--- Hand out the next mini-batch from the host-side buffer.
-- When every buffered row has been served, the buffer is refilled through
-- `saturate` first. Rows are taken in order, or through `self.rand_map`
-- when `self.randomize` is set.
-- @treturn table|nil map from data id to a freshly allocated
--   `gconf.cumat_type` matrix of `batch_size` rows, or nil when the
--   buffered rows cannot form a complete batch
function SGDBuffer:get_data()
    local gconf = self.gconf
    local bsize = gconf.batch_size
    -- refill only once the buffer has been fully consumed
    if self.head >= self.tail and not self:saturate() then
        return nil -- the remaining data cannot build a batch
    end
    -- read head after a possible refill, which resets it to 0
    local start = self.head
    local stop = start + bsize
    if stop > self.tail then
        return nil -- the remaining data cannot build a batch
    end
    local batches = {}
    for _, rd in ipairs(self.readers) do
        for key, buf in pairs(rd.buffs) do
            local mat = gconf.cumat_type(bsize, buf.width)
            if not self.randomize then
                mat:copy_fromh(buf.data, start, stop)
            else
                mat:copy_rows_fromh_by_idx(buf.data, self.rand_map, start)
            end
            batches[key] = mat
        end
    end
    self.head = stop
    return batches
end
|