aboutsummaryrefslogblamecommitdiff
path: root/nerv/io/seq_buffer.lua
blob: 5c60f645d65ecc8eb4b874591433fcb4469fcadb (plain) (tree)
1
2
3
4
5
6
7
8
9
10







                                                                    

                                                                 



















                                                                               

                                                                                     
                                                         


                                                                           




                                                   

                                              



                                              








                                                    




                         












                                   































































                                                                                 
                                  
                                                                                         

                   


                                               

































                                                                         
                                      
                                       
 





















                                                                
--- Implements a sequence-level chopped and shuffled buffer which
-- shall be used for cyclic (recurrent) NNs.
--
-- Whole sequences are read from the configured readers into an in-memory
-- buffer, optionally shuffled at the sequence level, and then laid out over
-- `batch_size` parallel slots. Each sequence is cut into consecutive chunks of
-- at most `chunk_size` frames, one chunk per returned mini-batch.
-- @author Qi Liu <liuq901@163.com>

--- The class for a sequence-level chopped and shuffled buffer which
-- shall be used for cyclic (recurrent) NNs.
-- @type nerv.SeqBuffer

-- Declared through `nerv.class` with `nerv.DataBuffer` given as the parent
-- class name.
local SeqBuffer = nerv.class('nerv.SeqBuffer', 'nerv.DataBuffer')

--- The constructor.
-- @param global_conf a table describing the computation state and providing
-- some global settings
--
-- The following fields in `global_conf` will be used:
--
-- * `use_cpu`: whether `get_data()` should return the chunks/"mini-batches"
--   as matrices stored in main memory (otherwise they are converted with
--   `cumat_type.new_from_host`)
-- * `mmat_type`: the class used for creating matrices in CPU computation
--   (batch matrices are always assembled with this class first)
-- * `cumat_type` (if `use_cpu = false`): the class used for creating matrices
--   in GPU computation
--
-- @param buffer_conf a table providing settings dedicated to the buffer.
-- Available fields include:
--
-- * `readers`: an array of tables, each carrying a `nerv.DataReader` instance
--   in its `reader` field (i.e. `{{reader = r1}, {reader = r2}, ...}`); all
--   readers must yield the same number of rows for a given sequence
-- * `batch_size`: the number of rows for each batch matrix, i.e. the number
--   of sequences laid out in parallel (required, >= 1)
-- * `chunk_size`: the length of the BPTT context (maximum number of batch
--   matrices to provide upon each invocation of `get_data()`; required, >= 1)
-- * `buffer_size`: the number of frames to be buffered and shuffled at once
--   (shuffling happens at the sequence level, not at the frame level;
--   required, >= 1)
-- * `randomize`: shuffle the buffer after it is filled if true
-- * `nn_act_default`: the default value to fill into the "holes" (non-data
--   frames); defaults to 0

function SeqBuffer:__init(global_conf, buffer_conf)
    self.gconf = global_conf

    -- Fail fast on missing/invalid sizes: a nil size otherwise crashes much
    -- later with an obscure arithmetic/for-limit error, and `chunk_size = 0`
    -- would make `saturate()` loop forever because the chunk offset never
    -- advances.
    for _, field in ipairs({'batch_size', 'chunk_size', 'buffer_size'}) do
        local value = buffer_conf[field]
        if type(value) ~= 'number' or value < 1 then
            nerv.error(string.format('buffer_conf.%s must be a number >= 1, got %s',
                                     field, tostring(value)))
        end
    end
    if type(buffer_conf.readers) ~= 'table' then
        nerv.error('buffer_conf.readers must be an array of {reader = ...} tables')
    end

    self.batch_size = buffer_conf.batch_size
    self.chunk_size = buffer_conf.chunk_size
    self.buffer_size = buffer_conf.buffer_size
    self.randomize = buffer_conf.randomize
    self.readers = {}
    for i, v in ipairs(buffer_conf.readers) do
        -- `table.insert(t, nil)` is a silent no-op, so a missing `reader`
        -- field would otherwise drop that reader's inputs without notice.
        if v.reader == nil then
            nerv.error(string.format('buffer_conf.readers[%d] has no reader field', i))
        end
        table.insert(self.readers, v.reader)
    end
    self.nn_act_default = buffer_conf.nn_act_default
    if self.nn_act_default == nil then
        self.nn_act_default = 0
    end

    self.mat_type = self.gconf.mmat_type
    -- FIFO of pending mini-batches, occupying indices head..tail.
    self.queue = {}
    self.head = 1
    self.tail = 0
    -- Sequence buffer state: `offset` is the next position in `index` to
    -- hand out; `complete` becomes true once any reader runs dry.
    self.offset = 1
    self.buffer = {}
    self.length = {}
    self.index = {}
    self.complete = false
end

--- Allocate an empty mini-batch record.
-- @return a table with three fields:
--   `data`: per-id arrays of step matrices, populated lazily by `saturate`;
--   `new_seq`: the batch slots whose sequence starts in this mini-batch;
--   `seq_length`: the per-slot number of valid steps, all initialised to 0.
function SeqBuffer:new_mini_batch()
    local lengths = {}
    for slot = 1, self.batch_size do
        lengths[slot] = 0
    end
    return {data = {}, new_seq = {}, seq_length = lengths}
end

--- Permute the array part of `a` in place (Fisher-Yates shuffle).
-- Walks from the last position down to the second one, swapping each slot
-- with a position chosen by `math.random` among the slots at or before it.
-- @param a the array to permute; it is modified in place
local function random_shuffle(a)
    local k = #a
    while k > 1 do
        local pick = math.random(k)
        a[k], a[pick] = a[pick], a[k]
        k = k - 1
    end
end

--- Read one sequence from every reader and merge the results by data id.
-- @param readers the array of reader objects to query, in order
-- @return the merged `{id = matrix}` table and its shared row count, or nil
--   as soon as one reader returns nil (no more data)
local function read_sequence(readers)
    local merged, rows = {}, nil
    for r = 1, #readers do
        local part = readers[r]:get_data()
        if part == nil then
            return nil
        end
        for id, mat in pairs(part) do
            local nrow = mat:nrow()
            if rows == nil then
                rows = nrow
            elseif nrow ~= rows then
                nerv.error('readers provides with inconsistent rows of data')
            end
            merged[id] = mat
        end
    end
    return merged, rows
end

--- Replace the sequence buffer with freshly read sequences.
-- Sequences are read until at least `buffer_size` frames are held or the
-- readers are exhausted (which latches `self.complete`). The visiting order
-- `self.index` is the identity, shuffled when `randomize` is set, and the
-- read cursor `self.offset` is rewound to 1. The elapsed CPU time
-- (`os.clock`) is logged.
-- @return true if at least one sequence was buffered, false otherwise
--   (always false once the readers have been exhausted)
function SeqBuffer:fill_buffer()
    if self.complete then
        return false
    end
    local started = os.clock()
    local seqs, lens = {}, {}
    self.buffer, self.length = seqs, lens
    local frames = 0
    while frames < self.buffer_size do
        local seq, rows = read_sequence(self.readers)
        if seq == nil then
            self.complete = true
            break
        end
        frames = frames + rows
        seqs[#seqs + 1] = seq
        lens[#lens + 1] = rows
    end

    local order = {}
    for k = 1, #seqs do
        order[k] = k
    end
    self.index = order
    if self.randomize then
        random_shuffle(order)
    end
    self.offset = 1
    collectgarbage('collect')
    nerv.info('%.3fs to fill the buffer', os.clock() - started)
    return #seqs > 0
end

--- Hand out the next buffered sequence, refilling the buffer when drained.
-- @return the sequence's `{id = matrix}` table and its row count, or nil
--   when the buffer is empty and cannot be refilled
function SeqBuffer:get_buffered_data()
    local drained = self.offset > #self.buffer
    if drained and not self:fill_buffer() then
        return nil
    end
    local pos = self.offset
    self.offset = pos + 1
    local slot = self.index[pos]
    return self.buffer[slot], self.length[slot]
end

--- Make sure batch slot `batch` has data in the mini-batch at `self.head`.
-- If the slot is already occupied there, nothing is done. Otherwise the next
-- buffered sequence is laid out in that slot across consecutive queued
-- mini-batches, starting at `self.head`, `chunk_size` frames per mini-batch.
-- New mini-batches are appended at `self.tail` as needed. Step matrices are
-- created on demand (`batch_size` x ncol, pre-filled with `nn_act_default`).
-- The slot is recorded in `new_seq` of the mini-batch holding the sequence's
-- first chunk.
-- @param batch the 1-based batch slot to fill
-- @return true if the slot holds data at `self.head`, false if no more
--   sequences are available
function SeqBuffer:saturate(batch)
    local current = self.queue[self.head]
    if current ~= nil and current.seq_length[batch] ~= 0 then
        return true
    end
    -- Skip zero-row sequences: they would occupy no mini-batch, so reporting
    -- success for them could make `get_data` dequeue a nil mini-batch and
    -- push `self.head` past the end of the queue.
    local data, drow
    repeat
        data, drow = self:get_buffered_data()
        if data == nil then
            return false
        end
    until drow > 0
    local offset = 0
    local head = self.head
    while offset < drow do
        local last = math.min(offset + self.chunk_size, drow)
        if head > self.tail then
            self.tail = self.tail + 1
            self.queue[self.tail] = self:new_mini_batch()
        end
        local mb = self.queue[head]
        mb.seq_length[batch] = last - offset
        if offset == 0 then
            table.insert(mb.new_seq, batch)
        end
        for id, d in pairs(data) do
            local steps = mb.data[id]
            if steps == nil then
                steps = {}
                mb.data[id] = steps
            end
            for i = offset + 1, last do
                local step = i - offset
                if steps[step] == nil then
                    steps[step] = self.mat_type(self.batch_size, d:ncol())
                    steps[step]:fill(self.nn_act_default)
                end
                -- NOTE(review): presumably copies source rows [i-1, i) of `d`
                -- into row `batch - 1` (0-based) of this step matrix; confirm
                -- against the nerv matrix `copy_from` documentation.
                steps[step]:copy_from(d, i - 1, i, batch - 1)
            end
        end
        head = head + 1
        offset = last
    end
    return true
end

--- Get a batch group from the buffer.
-- See `nerv.DataBuffer` for reference.
--
-- Every batch slot is topped up first; then the mini-batch at the front of
-- the queue is removed and returned. Unless `gconf.use_cpu` is set, each
-- step matrix is replaced by `gconf.cumat_type.new_from_host(matrix)`.
-- @return a table with `data`, `new_seq` and `seq_length` fields, or nil
--   when no slot could be filled
function SeqBuffer:get_data()
    local any = false
    for slot = 1, self.batch_size do
        -- saturate() must run for every slot, so it is evaluated before `or`
        any = self:saturate(slot) or any
    end
    if not any then
        return nil
    end
    local head = self.head
    local res = self.queue[head]
    self.queue[head] = nil
    self.head = head + 1
    if self.gconf.use_cpu then
        return res
    end
    for _, steps in pairs(res.data) do
        for t = 1, #steps do
            steps[t] = self.gconf.cumat_type.new_from_host(steps[t])
        end
    end
    return res
end