1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
|
-- Declare the SeqBuffer class, registered under the name 'nerv.SeqBuffer'.
-- NOTE(review): the second argument presumably names the parent class
-- ('nerv.DataBuffer') -- confirm against the definition of nerv.class.
local SeqBuffer = nerv.class('nerv.SeqBuffer', 'nerv.DataBuffer')
--- Initialise an empty sequence buffer.
-- @param global_conf global configuration table; `mmat_type` is read from it
--        here and the table itself is kept in `self.gconf`
-- @param buffer_conf buffer configuration providing `batch_size`,
--        `chunk_size`, `readers` and, optionally, `nn_act_default`
function SeqBuffer:__init(global_conf, buffer_conf)
    -- Value used to pre-fill freshly allocated matrices. Only a missing
    -- (nil) setting falls back to 0; any other value is kept as given.
    local fill_value = buffer_conf.nn_act_default
    if fill_value == nil then
        fill_value = 0
    end

    self.gconf = global_conf
    self.batch_size = buffer_conf.batch_size
    self.chunk_size = buffer_conf.chunk_size
    self.readers = buffer_conf.readers
    self.nn_act_default = fill_value
    self.mat_type = global_conf.mmat_type

    -- Pending mini-batches live in queue[head .. tail]; head > tail means
    -- the queue is currently empty.
    self.queue = {}
    self.head = 1
    self.tail = 0
end
--- Build an empty mini-batch record.
-- @return a table with empty `data` and `new_seq` tables and a `seq_length`
--         array holding one zero for each of the `batch_size` slots
function SeqBuffer:new_mini_batch()
    local lengths = {}
    for slot = 1, self.batch_size do
        lengths[slot] = 0
    end
    return {
        data = {},
        new_seq = {},
        seq_length = lengths,
    }
end
--- Make sure batch slot `batch` has data in the mini-batch at the queue head.
--
-- If the head mini-batch already records a non-zero length for this slot,
-- nothing is read. Otherwise one `get_data()` result is pulled from every
-- reader, the results are merged by id, and the rows are cut into pieces of
-- at most `chunk_size` rows which are written to consecutive queued
-- mini-batches starting at the head (new ones are appended as needed).
--
-- @param batch 1-based index of the batch slot to fill
-- @return true if the slot has data; false if a reader returned nil or if
--         the readers supplied no matrices at all
function SeqBuffer:saturate(batch)
    local front = self.queue[self.head]
    if front ~= nil and front.seq_length[batch] ~= 0 then
        return true
    end

    -- Collect one result per reader; every matrix must have the same number
    -- of rows.
    local data = {}
    local drow = nil
    for i = 1, #self.readers do
        local tmp = self.readers[i]:get_data()
        if tmp == nil then
            return false
        end
        for id, d in pairs(tmp) do
            if drow == nil then
                drow = d:nrow()
            elseif d:nrow() ~= drow then
                nerv.error('readers provides with inconsistent rows of data')
            end
            data[id] = d
        end
    end

    -- With no readers configured, or with every reader returning an empty
    -- table, there is no row count to chunk by. Report "no data" rather than
    -- falling through to `offset < drow`, which would raise a
    -- number-vs-nil comparison error.
    if drow == nil then
        return false
    end

    local offset = 0
    local head = self.head
    while offset < drow do
        local last = math.min(offset + self.chunk_size, drow)
        -- Grow the queue when this piece goes past the last queued entry.
        if head > self.tail then
            self.tail = self.tail + 1
            self.queue[self.tail] = self:new_mini_batch()
        end

        local entry = self.queue[head]
        entry.seq_length[batch] = last - offset
        -- Only the first piece marks the slot as starting a new sequence.
        if offset == 0 then
            table.insert(entry.new_seq, batch)
        end

        local mini_batch = entry.data
        for id, d in pairs(data) do
            if mini_batch[id] == nil then
                mini_batch[id] = {}
            end
            local steps = mini_batch[id]
            for i = offset + 1, last do
                local chunk = i - offset
                -- One batch_size x ncol matrix per position inside the
                -- piece, pre-filled with the default value for slots that
                -- never get written.
                if steps[chunk] == nil then
                    steps[chunk] = self.mat_type(self.batch_size, d:ncol())
                    steps[chunk]:fill(self.nn_act_default)
                end
                -- NOTE(review): presumably copies source row i (0-based
                -- range [i - 1, i)) into destination row batch - 1 --
                -- confirm against the matrix copy_from API.
                steps[chunk]:copy_from(d, i - 1, i, batch - 1)
            end
        end

        head = head + 1
        offset = last
    end
    return true
end
--- Pop the next mini-batch from the buffer.
--
-- Every batch slot is given a chance to refill first. When
-- `gconf.use_cpu` is not set, each matrix in the returned mini-batch is
-- replaced by the result of `gconf.cumat_type.new_from_host`
-- (presumably a device-side copy -- confirm against the matrix API).
--
-- @return the mini-batch at the queue head, or nil when no slot has data
function SeqBuffer:get_data()
    local any_filled = false
    for slot = 1, self.batch_size do
        -- saturate() is on the left of `or` so it runs for every slot, even
        -- after an earlier slot already reported data.
        any_filled = self:saturate(slot) or any_filled
    end
    if not any_filled then
        return nil
    end

    -- Detach the head entry and advance the queue.
    local head = self.head
    local batch = self.queue[head]
    self.queue[head] = nil
    self.head = head + 1

    if self.gconf.use_cpu then
        return batch
    end

    for _, steps in pairs(batch.data) do
        for t = 1, #steps do
            steps[t] = self.gconf.cumat_type.new_from_host(steps[t])
        end
    end
    return batch
end
|