aboutsummaryrefslogblamecommitdiff
path: root/nerv/io/frm_buffer.lua
blob: 45f73a0d18dd8b9b04ce7f43e83660ea37642641 (plain) (tree)
1
2
3
4
5
6
7
8
9







                                                                            
                                                                 
 
























                                                                               
                                                   
                            
                                            
                                                           
                                                                      
                                          
                                      
                                             




                                                    

                                  







                                                                          




                                                                          








                                                                               
                                                  
       
                                                    





                                                        
                                                                       
                                       








                                                                
                             












                                                                               
                                   
               
                                                    






















                                                                                


                                                                         




                                                  
                                                                                




                                                    
                                                                       
                             
                                       

   


                                      
                             
                                      
                                                     
                            
                                                            

                                                                 


                                      
                                                             
       
                                                                     

                                                             
                                                                   


                                                                

                                              
                                                                             
                                  
                                                                                      
                
                                                                                          
               
                                  

           
                                             

              
--- Implements a frame-level chopped and shuffled buffer which shall be used
-- for acyclic feed forward NNs (`chunk_size = 1`).
-- @author Ted Yin <ted.sybil@gmail.com>

--- Frame-level chopped and shuffled buffer class, intended for acyclic
-- feed forward NNs (frame-level, i.e. `chunk_size = 1`).
-- @type nerv.FrmBuffer

local FrmBuffer = nerv.class(
    "nerv.FrmBuffer",
    "nerv.DataBuffer"
)

--- The constructor.
-- @param global_conf a table describing the computation state and providing
-- with some global settings
--
-- The following fields in `global_conf` will be used:
--
-- * `use_cpu`: whether to provide with the chunks/"mini-batches" stored in the
--   main memory on invocation of `get_data()`
-- * `mmat_type`: the class used for creating matrices in CPU computation
-- * `cumat_type` (if `use_cpu = false`): the class used for creating matrices
--   in GPU computation
--
-- @param buffer_conf a table providing with settings dedicated for the buffer.
-- Available fields include:
--
-- * `readers`: an array of reader specs; each spec is a table with a `reader`
--   field (the `nerv.DataReader` instance used to read data) and a `data`
--   field mapping every data id to its width (number of columns)
-- * `batch_size`: the number of rows for each batch matrix
-- * `buffer_size`: the number of frames to be buffered and shuffled at once;
--   it is rounded down to a multiple of `batch_size`
-- * `randomize`: shuffle the buffer after filled if true
-- * `consume`: drop the last frames which cannot make up a full `batch_size`
--   matrix if false
-- * `use_gpu`: the buffer space will be allocated on the device (graphics
--   card) if true; combining it with `global_conf.use_cpu = true` raises an
--   error (not implemented)

function FrmBuffer:__init(global_conf, buffer_conf)
    self.gconf = global_conf
    self.batch_size = buffer_conf.batch_size
    -- keep the buffer an exact multiple of the batch size
    self.buffer_size = math.floor(buffer_conf.buffer_size /
                                    self.batch_size) * self.batch_size
    self.randomize = buffer_conf.randomize
    self.consume = buffer_conf.consume
    -- read the matrix classes from the parameter, never from a global
    local mmat_type = global_conf.mmat_type
    local cumat_type = global_conf.cumat_type
    if global_conf.use_cpu then
        self.output_mat_type = mmat_type
    else
        self.output_mat_type = cumat_type
    end
    if buffer_conf.use_gpu then
        self.mat_type = cumat_type
        if global_conf.use_cpu then
            -- gpu buffer -> cpu training
            nerv.error("not implemented")
        else
            -- gpu buffer -> gpu training
            self.copy_rows_from_by_idx = cumat_type.copy_rows_fromd_by_idx
            self.copy_from = cumat_type.copy_fromd
        end
        -- the permutation is generated on the host, then uploaded with
        -- `new_from_host` so it matches the device-side buffer
        self.perm_gen = function (x)
            return cumat_type.new_from_host(nerv.MMatrixFloat.perm_gen(x))
        end
    else
        self.mat_type = mmat_type
        if global_conf.use_cpu then
            -- cpu buffer -> cpu training
            self.copy_rows_from_by_idx = mmat_type.copy_rows_fromh_by_idx
            self.copy_from = mmat_type.copy_fromh
        else
            -- cpu buffer -> gpu training
            self.copy_rows_from_by_idx = cumat_type.copy_rows_fromh_by_idx
            self.copy_from = cumat_type.copy_fromh
        end
        self.perm_gen = nerv.MMatrixFloat.perm_gen
    end
    -- reader output is copied into buffer matrices via the buffer class
    self.copy_from_reader = self.mat_type.copy_fromh
    self.head = 0
    self.tail = 0
    self.readers = {}
    for _, reader_spec in ipairs(buffer_conf.readers) do
        -- one `buffer_size x width` matrix per data id of this reader
        local buffs = {}
        for id, width in pairs(reader_spec.data) do
            buffs[id] = {data = self.mat_type(self.buffer_size, width),
                         leftover = nil,
                         width = width}
        end
        table.insert(self.readers, {buffs = buffs,
                                    reader = reader_spec.reader,
                                    tail = 0,
                                    has_leftover = false})
    end
end

--- Refill the buffer from every reader.
-- For each reader, rows that overflowed during the previous call (the
-- "leftover") are first copied to the front of its buffers. Then chunks are
-- pulled from the reader until its buffers are full or `get_data()` returns
-- `nil`. Rows of the last chunk that do not fit are kept as the new leftover.
-- Afterwards `self.tail` is the smallest fill level among all readers,
-- `self.head` is reset to 0 and `self.rand_map` is regenerated via
-- `self.perm_gen(self.tail)`.
-- @return true if at least `batch_size` frames are available
function FrmBuffer:saturate()
    local buffer_size = self.buffer_size
    self.head = 0
    self.tail = buffer_size
    for _, reader in ipairs(self.readers) do
        reader.tail = 0
        if reader.has_leftover then
            -- move rows that overflowed last time to the buffer front
            local lrow
            for id, buff in pairs(reader.buffs) do
                lrow = buff.leftover:nrow()
                if lrow > buffer_size then
                    nerv.error("buffer size is too small to contain leftovers")
                end
                buff.data:copy_from(buff.leftover, 0, lrow)
                buff.leftover = nil
            end
            nerv.info("buffer leftover: %d\n", lrow)
            reader.tail = lrow
            reader.has_leftover = false
        end
        while reader.tail < buffer_size do
            local data = reader.reader:get_data()
            if data == nil then
                break -- reader has no more data
            end
            -- all matrices of one chunk must have the same number of rows
            local drow = nil
            for id, d in pairs(data) do
                if drow == nil then
                    drow = d:nrow()
                elseif d:nrow() ~= drow then
                    nerv.error("reader provides with inconsistent rows of data")
                end
            end
            -- every buffered id must be present, whether or not the chunk
            -- overflows (previously only checked on the overflow path)
            for id in pairs(reader.buffs) do
                if data[id] == nil then
                    nerv.error("reader does not provide data for %s", id)
                end
            end
            local remain = buffer_size - reader.tail
            if drow > remain then
                -- stash the rows that do not fit for the next call
                for id, buff in pairs(reader.buffs) do
                    buff.leftover = self.mat_type(drow - remain,
                                                  buff.width)
                    self.copy_from_reader(buff.leftover, data[id], remain, drow)
                end
                drow = remain
                reader.has_leftover = true
            end
            for id, buff in pairs(reader.buffs) do
                self.copy_from_reader(buff.data, data[id], 0, drow, reader.tail)
            end
            reader.tail = reader.tail + drow
        end
        -- only frames present for every reader are usable
        self.tail = math.min(self.tail, reader.tail)
    end
    self.rand_map = self.perm_gen(self.tail) -- generate shuffled index
    collectgarbage("collect")
    return self.tail >= self.batch_size
end

--- Get a batch group from the buffer.
-- See `nerv.DataBuffer` for reference

function FrmBuffer:get_data()
    local batch_size = self.batch_size
    if self.head >= self.tail then -- buffer is empty
        local t = os.clock()
        if (not self:saturate()) and (not self.consume) then
            re