blob: 8cde1b357e4a0a0b7858b140410fd2c766b676a7 (
plain) (
tree)
|
|
--- Sequence-level chopped buffer intended for cyclic (recurrent) NNs: every
-- sequence obtained from the readers is cut into chunks of at most
-- `chunk_size` frames, and each chunk is placed into one row ("batch slot")
-- of consecutive mini-batch groups.
-- @author Qi Liu <liuq901@163.com>
--- Buffer class that chops sequences into BPTT-sized chunks and interleaves
-- up to `batch_size` sequences per mini-batch group, for use with cyclic NNs.
-- NOTE(review): the previous description also called this buffer "shuffled";
-- no shuffling is visible in this class, presumably the readers take care of
-- it -- confirm.
-- @type nerv.SeqBuffer
local SeqBuffer = nerv.class(
    'nerv.SeqBuffer',  -- name of the class being declared
    'nerv.DataBuffer'  -- base class name
)
--- Raise a nerv error unless `value` is a positive integer.
-- @param name the `buffer_conf` field name, used in the error message
-- @param value the value to check
local function check_positive_integer(name, value)
    if type(value) ~= 'number' or value < 1 or value ~= math.floor(value) then
        nerv.error('buffer_conf.' .. name ..
                   ' must be a positive integer, got ' .. tostring(value))
    end
end

--- The constructor.
-- @param global_conf a table describing the computation state and providing
-- with some global settings
--
-- The following fields in `global_conf` will be used:
--
-- * `use_cpu`: whether to provide with the chunks/"mini-batches" stored in the
-- main memory on invocation of `get_data()`
-- * `mmat_type`: the class used for creating the (host) matrices the buffer
-- fills; they are converted to device matrices with `cumat_type.new_from_host`
-- when `use_cpu` is false
-- * `cumat_type` (if `use_cpu = false`): the class used for creating matrices
-- in GPU computation
--
-- @param buffer_conf a table providing with settings dedicated for the buffer.
-- Available fields include:
--
-- * `readers`: a non-empty array of tables, each holding the
-- `nerv.DataReader` instance to read data from in its `reader` field
-- * `batch_size`: the number of rows for each batch matrix (positive integer)
-- * `chunk_size`: the length of the BPTT context, i.e. the number of batch
-- matrices to provide upon each invocation of `get_data()` (positive integer)
-- * `nn_act_default`: the default value to fill into the "holes" (non-data
-- frames); defaults to 0
function SeqBuffer:__init(global_conf, buffer_conf)
    self.gconf = global_conf
    self.batch_size = buffer_conf.batch_size
    self.chunk_size = buffer_conf.chunk_size
    check_positive_integer('batch_size', self.batch_size)
    -- A non-positive chunk_size would make saturate() loop forever, because
    -- the chopping offset would never advance.
    check_positive_integer('chunk_size', self.chunk_size)
    if type(buffer_conf.readers) ~= 'table' or #buffer_conf.readers == 0 then
        nerv.error('buffer_conf.readers must be a non-empty array')
    end
    self.readers = {}
    for i, v in ipairs(buffer_conf.readers) do
        -- table.insert(t, nil) is a silent no-op, so a missing `reader`
        -- field would otherwise drop this input without any notice.
        if v.reader == nil then
            nerv.error('buffer_conf.readers[' .. i .. '] has no `reader` field')
        end
        table.insert(self.readers, v.reader)
    end
    self.nn_act_default = buffer_conf.nn_act_default
    if self.nn_act_default == nil then
        self.nn_act_default = 0
    end
    -- matrices are always built on the host; get_data() moves them to the
    -- device when use_cpu is false
    self.mat_type = self.gconf.mmat_type
    -- FIFO of mini-batch groups, valid indices are self.head .. self.tail
    self.queue = {}
    self.head = 1
    self.tail = 0
end
--- Create an empty mini-batch group record.
-- @return a table with `data` (empty; filled per data id by `saturate`),
-- `new_seq` (empty list of batch slots starting a new sequence) and
-- `seq_length` (one zero per batch slot)
function SeqBuffer:new_mini_batch()
    local lengths = {}
    for slot = 1, self.batch_size do
        lengths[slot] = 0
    end
    return {data = {}, new_seq = {}, seq_length = lengths}
end
--- Read one sequence from every reader and merge the results.
-- @param readers the array of reader objects (each providing `get_data()`)
-- @return a table mapping data ids to matrices plus their common row count,
-- or `nil` as soon as any reader returns `nil`
local function read_sequence(readers)
    local data = {}
    local drow = nil
    for i = 1, #readers do
        local tmp = readers[i]:get_data()
        if tmp == nil then
            -- NOTE(review): matrices already obtained from earlier readers
            -- are dropped here; this assumes all readers run out together.
            return nil
        end
        for id, d in pairs(tmp) do
            if drow == nil then
                drow = d:nrow()
            elseif d:nrow() ~= drow then
                nerv.error('readers provides with inconsistent rows of data')
            end
            data[id] = d
        end
    end
    if drow == nil then
        -- no readers, or readers returned no matrices: nothing to chop
        nerv.error('readers provide no data matrices')
    end
    return data, drow
end

--- Ensure batch slot `batch` has data in the mini-batch group at the queue head.
-- If the slot is empty there, the next non-empty sequence is read and chopped
-- into chunks of at most `chunk_size` frames. Each chunk is written into row
-- `batch` of the matrices of consecutive queued groups, and new groups are
-- appended to the queue as needed. Sequences with zero frames are skipped.
-- @param batch 1-based index of the batch slot (matrix row) to fill
-- @return `true` if the head group holds data for the slot, `false` when the
-- readers are exhausted
function SeqBuffer:saturate(batch)
    local front = self.queue[self.head]
    if front ~= nil and front.seq_length[batch] ~= 0 then
        return true
    end
    -- Skip empty sequences. With zero rows nothing would be queued, yet the
    -- slot would be reported as filled, so get_data() could pop a missing
    -- (nil) group and leave head/tail out of sync.
    local data, drow
    repeat
        data, drow = read_sequence(self.readers)
        if data == nil then
            return false
        end
    until drow > 0
    local offset = 0
    local head = self.head
    while offset < drow do
        local last = math.min(offset + self.chunk_size, drow)
        if head > self.tail then
            self.tail = self.tail + 1
            self.queue[self.tail] = self:new_mini_batch()
        end
        local entry = self.queue[head]
        -- number of valid frames this slot contributes to this group
        entry.seq_length[batch] = last - offset
        if offset == 0 then
            -- first chunk of the sequence: mark the slot as starting anew
            table.insert(entry.new_seq, batch)
        end
        local mini_batch = entry.data
        for id, d in pairs(data) do
            if mini_batch[id] == nil then
                mini_batch[id] = {}
            end
            local tmp = mini_batch[id]
            for i = offset + 1, last do
                local chunk = i - offset
                if tmp[chunk] == nil then
                    -- fresh matrix: pre-fill so rows of slots that have no
                    -- data at this step keep the default value
                    tmp[chunk] = self.mat_type(self.batch_size, d:ncol())
                    tmp[chunk]:fill(self.nn_act_default)
                end
                -- presumably copies row i - 1 (0-based) of `d` into row
                -- batch - 1 of the step matrix -- confirm copy_from signature
                tmp[chunk]:copy_from(d, i - 1, i, batch - 1)
            end
        end
        head = head + 1
        offset = last
    end
    return true
end
--- Get a batch group from the buffer.
-- See `nerv.DataBuffer` for reference.
-- Every batch slot is topped up first, then the group at the queue head is
-- removed and returned.
-- @return `nil` when no slot could be filled; otherwise a table with fields
-- `data` (data id -> array of step matrices, converted with
-- `gconf.cumat_type.new_from_host` unless `gconf.use_cpu`), `new_seq` (slots
-- whose sequence starts in this group) and `seq_length` (valid frames per slot)
function SeqBuffer:get_data()
    local any_filled = false
    for slot = 1, self.batch_size do
        -- saturate() is the left operand so it runs for every slot
        any_filled = self:saturate(slot) or any_filled
    end
    if not any_filled then
        return nil
    end
    local head = self.head
    local group = self.queue[head]
    self.queue[head] = nil
    self.head = head + 1
    if not self.gconf.use_cpu then
        for _, mats in pairs(group.data) do
            for k = 1, #mats do
                mats[k] = self.gconf.cumat_type.new_from_host(mats[k])
            end
        end
    end
    return group
end
|