1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
|
-- SGDBuffer: a shuffling data buffer for mini-batch SGD training.
-- Aggregates rows from one or more readers into a large (CPU or GPU) matrix
-- buffer and hands out fixed-size batches, optionally in randomized order.
local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer")
--- Construct an SGDBuffer.
-- @param global_conf table with `mmat_type`, `cumat_type` and `use_cpu`
--        (controls whether training consumes CPU or GPU matrices)
-- @param buffer_conf table with:
--        * batch_size  — rows per mini-batch
--        * buffer_size — requested capacity; rounded DOWN to a multiple of
--                        batch_size
--        * randomize   — if true, batches are drawn via a shuffled index map
--        * consume     — if true, emit a final short batch instead of
--                        discarding the tail
--        * use_gpu     — keep the staging buffer itself on the GPU
--        * readers     — list of {reader = r, data = {id = width, ...}}
function SGDBuffer:__init(global_conf, buffer_conf)
    self.gconf = global_conf
    self.batch_size = buffer_conf.batch_size
    -- round capacity down so the buffer always holds whole batches
    self.buffer_size = math.floor(buffer_conf.buffer_size /
                                    self.batch_size) * self.batch_size
    self.randomize = buffer_conf.randomize
    self.consume = buffer_conf.consume
    local cumat_type = global_conf.cumat_type
    -- output_mat_type is what get_data() hands to the trainer
    if self.gconf.use_cpu then
        self.output_mat_type = self.gconf.mmat_type
    else
        self.output_mat_type = self.gconf.cumat_type
    end
    -- Select the copy primitives according to where the staging buffer
    -- lives (host or device) and where training happens.
    if buffer_conf.use_gpu then
        self.mat_type = cumat_type
        if self.gconf.use_cpu then
            -- gpu buffer -> cpu training
            nerv.error("not implemeted")
        else
            -- gpu buffer -> gpu training
            self.copy_rows_from_by_idx = cumat_type.copy_rows_fromd_by_idx
            self.copy_from = cumat_type.copy_fromd
        end
        -- permutation is generated on host, then uploaded to the device
        self.perm_gen = function (x)
            return cumat_type.new_from_host(nerv.MMatrixFloat.perm_gen(x))
        end
    else
        self.mat_type = global_conf.mmat_type
        if self.gconf.use_cpu then
            -- cpu buffer -> cpu training
            -- (was the undeclared global `gconf`; use the actual parameter)
            self.copy_rows_from_by_idx = global_conf.mmat_type.copy_rows_fromh_by_idx
            self.copy_from = global_conf.mmat_type.copy_fromh
        else
            -- cpu buffer -> gpu training
            self.copy_rows_from_by_idx = cumat_type.copy_rows_fromh_by_idx
            self.copy_from = cumat_type.copy_fromh
        end
        self.perm_gen = nerv.MMatrixFloat.perm_gen
    end
    -- readers always deliver host matrices, so filling uses a fromh copy
    self.copy_from_reader = self.mat_type.copy_fromh
    self.head = 0   -- next row to serve
    self.tail = 0   -- one past the last valid row
    self.readers = {}
    -- one staging matrix per (reader, data id) pair, plus room to stash
    -- rows that overflow the buffer (leftover)
    for i, reader_spec in ipairs(buffer_conf.readers) do
        local buffs = {}
        for id, width in pairs(reader_spec.data) do
            buffs[id] = {data = self.mat_type(self.buffer_size, width),
                        leftover = nil,
                        width = width}
        end
        table.insert(self.readers, {buffs = buffs,
                                    reader = reader_spec.reader,
                                    tail = 0,
                                    has_leftover = false})
    end
end
--- Refill the staging buffer from all readers.
-- Drains each reader until its portion of the buffer is full (or the reader
-- is exhausted), splicing in any leftover rows stashed by the previous
-- refill. Rows that do not fit are saved as the next refill's leftover.
-- Sets self.head = 0, self.tail = min over readers of rows obtained, and
-- regenerates self.rand_map (shuffled row index) over [0, tail).
-- @return true iff at least one full batch of rows is now buffered
function SGDBuffer:saturate()
    local buffer_size = self.buffer_size
    self.head = 0
    self.tail = buffer_size
    for i, reader in ipairs(self.readers) do
        reader.tail = 0
        -- first, replay rows that overflowed during the previous refill
        if reader.has_leftover then
            local lrow
            for id, buff in pairs(reader.buffs) do
                lrow = buff.leftover:nrow()
                if lrow > buffer_size then
                    nerv.error("buffer size is too small to contain leftovers")
                end
                buff.data:copy_from(buff.leftover, 0, lrow)
                buff.leftover = nil
            end
            nerv.info("buffer leftover: %d\n", lrow)
            reader.tail = lrow
            reader.has_leftover = false
        end
        -- pull chunks from the reader until the buffer is full or EOF
        while reader.tail < buffer_size do
            local data = reader.reader:get_data()
            if data == nil then
                break
            end
            -- all data ids from one get_data() call must agree on row count
            local drow = nil
            for id, d in pairs(data) do
                if drow == nil then
                    drow = d:nrow()
                elseif d:nrow() ~= drow then
                    nerv.error("reader provides with inconsistent rows of data")
                end
            end
            local remain = buffer_size - reader.tail
            if drow > remain then
                -- chunk overflows the buffer: keep the tail rows [remain, drow)
                -- as leftover for the next saturate() call
                for id, buff in pairs(reader.buffs) do
                    local d = data[id]
                    if d == nil then
                        nerv.error("reader does not provide data for %s", id)
                    end
                    buff.leftover = self.mat_type(drow - remain,
                                                    buff.width)
                    self.copy_from_reader(buff.leftover, d, remain, drow)
                end
                drow = remain
                reader.has_leftover = true
            end
            -- append rows [0, drow) of the chunk at offset reader.tail
            for id, buff in pairs(reader.buffs) do
                self.copy_from_reader(buff.data, data[id], 0, drow, reader.tail)
            end
            reader.tail = reader.tail + drow
        end
        -- usable rows = what EVERY reader managed to supply
        self.tail = math.min(self.tail, reader.tail)
    end
    self.rand_map = self.perm_gen(self.tail) -- generate shuffled index
    collectgarbage("collect")
    return self.tail >= self.batch_size
end
--- Produce the next mini-batch.
-- Refills the buffer via saturate() when drained. Honors self.consume:
-- when false, a trailing partial batch is discarded; when true, a short
-- final batch is emitted.
-- @return table mapping data id -> matrix of actual_batch_size rows,
--         or nil when no (full) batch can be built
function SGDBuffer:get_data()
    local batch_size = self.batch_size
    if self.head >= self.tail then -- buffer is empty
        local t = os.clock()
        if (not self:saturate()) and (not self.consume) then
            return nil -- the remaining data cannot build a batch
        end
        if self.tail == self.head then
            return nil -- nothing left
        end
        nerv.info("%.3fs to fill the buffer", os.clock() - t)
    end
    if self.head + batch_size > self.tail and (not self.consume) then
        return nil -- the remaining data cannot build a batch
    end
    -- may be short of batch_size only in consume mode
    -- (was an accidental global; declare it local)
    local actual_batch_size = math.min(batch_size, self.tail - self.head)
    local res = {}
    for i, reader in ipairs(self.readers) do
        for id, buff in pairs(reader.buffs) do
            local batch = self.output_mat_type(actual_batch_size, buff.width)
            if self.randomize then
                -- gather rows through the shuffled index map
                self.copy_rows_from_by_idx(batch, buff.data, self.rand_map, self.head)
            else
                -- sequential slice [head, head + actual_batch_size)
                self.copy_from(batch, buff.data, self.head, self.head + actual_batch_size)
            end
            res[id] = batch
        end
    end
    self.head = self.head + actual_batch_size
    return res
end
|