--- Implements a sequence-level chopped and shuffled buffer intended for
-- cyclic (recurrent) NNs.
-- @author Qi Liu <liuq901@163.com>
--- A data buffer that shuffles whole sequences (never individual frames) and
-- chops each sequence into pieces laid out along one row of consecutive
-- mini-batches, intended for cyclic (recurrent) NNs.
-- @type nerv.SeqBuffer
local SeqBuffer = nerv.class(
    'nerv.SeqBuffer',  -- name of this class
    'nerv.DataBuffer'  -- base class
)
--- The constructor.
-- @param global_conf a table describing the computation state and providing
-- some global settings
--
-- The following fields in `global_conf` will be used:
--
-- * `use_cpu`: whether to provide the chunks/"mini-batches" stored in the
-- main memory on invocation of `get_data()`
-- * `mmat_type`: the class used for creating the host matrices in which the
-- chunks are assembled (used even when `use_cpu = false`)
-- * `cumat_type` (if `use_cpu = false`): the class used in `get_data()` to
-- create GPU copies of the assembled chunks
--
-- @param buffer_conf a table providing settings dedicated to the buffer.
-- Available fields include:
--
-- * `readers`: an array of tables, each holding a `nerv.DataReader` instance
-- in its `reader` field (other fields of the entries are not read here)
-- * `batch_size`: the number of rows of each batch matrix, i.e. the number of
-- sequences laid out side by side
-- * `chunk_size`: the length of the BPTT context (number of batch
-- matrices to provide upon each invocation of `get_data()`)
-- * `buffer_size`: the number of frames to be buffered and shuffled at once
-- (shuffling is done at the sequence level, not at the frame level)
-- * `randomize`: shuffle the buffer after it is filled if true
-- * `nn_act_default`: the value to fill into the "holes" (non-data
-- frames); defaults to 0
function SeqBuffer:__init(global_conf, buffer_conf)
    self.gconf = global_conf
    self.batch_size = buffer_conf.batch_size
    self.chunk_size = buffer_conf.chunk_size
    self.buffer_size = buffer_conf.buffer_size
    self.randomize = buffer_conf.randomize
    self.readers = {}
    for i, v in ipairs(buffer_conf.readers) do
        -- Fail fast: a missing `reader` used to be dropped silently, which
        -- only surfaced later as an arithmetic-on-nil error in fill_buffer().
        if v.reader == nil then
            nerv.error('every entry of buffer_conf.readers must have a `reader` field')
        end
        self.readers[i] = v.reader
    end
    self.nn_act_default = buffer_conf.nn_act_default
    if self.nn_act_default == nil then
        self.nn_act_default = 0
    end
    -- chunks are always assembled in host matrices; get_data() converts them
    -- to GPU matrices afterwards when use_cpu is false
    self.mat_type = self.gconf.mmat_type
    -- queue of mini-batches under construction, occupying indices head..tail
    self.queue = {}
    self.head = 1
    self.tail = 0
    -- shuffle buffer: sequences, their frame counts, and the visiting order;
    -- `offset` is the next position in `index` to hand out
    self.offset = 1
    self.buffer = {}
    self.length = {}
    self.index = {}
    -- set once a reader reports the end of data
    self.complete = false
end
--- Create an empty mini-batch record.
-- `data` maps each data id to its list of chunk matrices (filled later),
-- `new_seq` lists the batch slots that start a new sequence, and
-- `seq_length[b]` counts the valid frames of slot `b` (all start at 0).
-- @return the new mini-batch table
function SeqBuffer:new_mini_batch()
    local lengths = {}
    for slot = 1, self.batch_size do
        lengths[slot] = 0
    end
    return {data = {}, new_seq = {}, seq_length = lengths}
end
--- Shuffle the sequence `a` in place (Fisher-Yates), drawing indices with
-- `math.random`, so results are reproducible under `math.randomseed`.
-- @param a the array to permute
local function random_shuffle(a)
    local i = #a
    while i > 1 do
        local j = math.random(i)
        local tmp = a[i]
        a[i] = a[j]
        a[j] = tmp
        i = i - 1
    end
end
--- Refill the shuffle buffer with whole sequences read from the readers.
-- One sequence is one `get_data()` result from every reader, merged into a
-- single table keyed by data id; all of its matrices must have the same
-- number of rows. Sequences are read until at least `buffer_size` frames are
-- buffered or a reader returns nil; in the latter case `self.complete` is
-- set, the partially gathered sequence is dropped, and every later call
-- returns false immediately. `self.index` is reset to the visiting order of
-- the buffered sequences (shuffled when `randomize` is set).
-- @return true if at least one sequence was buffered
function SeqBuffer:fill_buffer()
    if self.complete then
        return false
    end
    local t = os.clock()
    self.buffer = {}
    self.length = {}
    local size = 0
    while size < self.buffer_size do
        local drow = nil
        local data = {}
        for i = 1, #self.readers do
            local tmp = self.readers[i]:get_data()
            if tmp == nil then
                self.complete = true
                break
            end
            for id, d in pairs(tmp) do
                if drow == nil then
                    drow = d:nrow()
                elseif d:nrow() ~= drow then
                    nerv.error('readers provides with inconsistent rows of data')
                end
                data[id] = d
            end
        end
        if self.complete then
            break
        end
        -- No reader produced a matrix (no readers, or only empty results);
        -- report it clearly instead of failing on `size + drow` below.
        if drow == nil then
            nerv.error('readers provided no data matrices for a sequence')
        end
        size = size + drow
        table.insert(self.buffer, data)
        table.insert(self.length, drow)
    end
    self.index = {}
    for i = 1, #self.buffer do
        self.index[i] = i
    end
    if self.randomize then
        random_shuffle(self.index)
    end
    self.offset = 1
    -- the previous buffer's matrices are unreferenced now; reclaim them
    collectgarbage('collect')
    nerv.info('%.3fs to fill the buffer', os.clock() - t)
    return #self.buffer > 0
end
--- Hand out the next sequence from the shuffle buffer, refilling the buffer
-- once every buffered sequence has been handed out.
-- @return the sequence's table of matrices and its number of frames, or nil
-- when no more data is available
function SeqBuffer:get_buffered_data()
    local drained = self.offset > #self.buffer
    if drained and not self:fill_buffer() then
        return nil
    end
    local pos = self.offset
    self.offset = pos + 1
    local seq_id = self.index[pos]
    return self.buffer[seq_id], self.length[seq_id]
end
--- Ensure batch slot `batch` of the mini-batch at the queue head holds data.
-- If that slot is empty, the next non-empty buffered sequence is cut into
-- pieces of at most `chunk_size` frames. Piece k is written into row `batch`
-- of the k-th queued mini-batch counting from the head. Queued mini-batches
-- are created on demand. A piece of length L fills matrices 1..L of its
-- mini-batch; freshly created matrices start out filled with
-- `nn_act_default`.
-- @param batch the 1-based batch slot (row index) to fill
-- @return true if the slot holds data in the head mini-batch, false when no
-- more data is available
function SeqBuffer:saturate(batch)
    local head_mb = self.queue[self.head]
    if head_mb ~= nil and head_mb.seq_length[batch] ~= 0 then
        return true
    end
    -- Skip zero-length sequences. They would leave the slot empty while
    -- reporting success. If every slot got one, the head mini-batch would
    -- never be created and get_data() would pop nil.
    local data, drow
    repeat
        data, drow = self:get_buffered_data()
    until data == nil or drow > 0
    if data == nil then
        return false
    end
    local offset = 0
    local head = self.head
    while offset < drow do
        local last = math.min(offset + self.chunk_size, drow)
        -- extend the queue when this piece goes past its current end
        if head > self.tail then
            self.tail = self.tail + 1
            self.queue[self.tail] = self:new_mini_batch()
        end
        local mb = self.queue[head]
        mb.seq_length[batch] = last - offset
        if offset == 0 then
            -- the sequence starts in this mini-batch
            table.insert(mb.new_seq, batch)
        end
        for id, d in pairs(data) do
            local chunks = mb.data[id]
            if chunks == nil then
                chunks = {}
                mb.data[id] = chunks
            end
            for i = offset + 1, last do
                local k = i - offset
                if chunks[k] == nil then
                    chunks[k] = self.mat_type(self.batch_size, d:ncol())
                    chunks[k]:fill(self.nn_act_default)
                end
                -- presumably copies row i - 1 of `d` into row batch - 1
                -- (0-based) of the k-th matrix -- NOTE(review): confirm the
                -- copy_from argument order
                chunks[k]:copy_from(d, i - 1, i, batch - 1)
            end
        end
        head = head + 1
        offset = last
    end
    return true
end
--- Get a batch group from the buffer.
-- See `nerv.DataBuffer` for reference.
--
-- Every batch slot is topped up first. The mini-batch at the queue head is
-- then removed and returned. Its `data[id]` holds the chunk matrices of
-- `batch_size` rows each. `seq_length[b]` is the number of valid frames in
-- slot `b`. `new_seq` lists the slots whose sequence starts here. When
-- `use_cpu` is false, the matrices are replaced by GPU copies built with
-- `cumat_type.new_from_host`.
-- @return the mini-batch table, or nil when all data has been consumed
function SeqBuffer:get_data()
    local any_filled = false
    for slot = 1, self.batch_size do
        -- every slot must be visited, so saturate() is called unconditionally
        local filled = self:saturate(slot)
        any_filled = any_filled or filled
    end
    if not any_filled then
        return nil
    end
    local head = self.head
    local res = self.queue[head]
    self.queue[head] = nil
    self.head = head + 1
    if not self.gconf.use_cpu then
        for _, chunks in pairs(res.data) do
            for k = 1, #chunks do
                chunks[k] = self.gconf.cumat_type.new_from_host(chunks[k])
            end
        end
    end
    return res
end