aboutsummaryrefslogblamecommitdiff
path: root/nerv/io/sgd_buffer.lua
blob: d78f6d1d56ec16d5bd948a504afecef7b33e69c4 (plain) (tree)
1
2
3
4
5
6
7
8
9



                                                                 
                                            
                                                           
                                                                      
                                          
                                      
                                             




                                                    

                                  







                                                                          




                                                                          








                                                                               
                                                  
       
                                                    





                                                        
                                                                       
                                       






















                                                                               
                                   
               
                                                    






















                                                                                


                                                                         




                                                  
                                                                                




                                                    
                                                                       
                             
                                       


                             
                                      
                                                     
                            
                                                            

                                                                 


                                      
                                                             
       
                                                                     

                                                             
                                                                   


                                              
                                                                             
                                  
                                                                                      
                
                                                                                          
               


                           
                                             

              
--- Buffer that collects rows from one or more readers into fixed-size
-- storage and hands them back as (optionally shuffled) mini-batches.
local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer")

--- Construct the buffer and allocate its per-reader storage.
-- @param global_conf global configuration; fields read here: `use_cpu`,
--   `cumat_type`, `mmat_type`.
-- @param buffer_conf buffer configuration; fields read here: `batch_size`,
--   `buffer_size`, `randomize`, `consume`, `use_gpu`, and `readers`, a list
--   of specs of the form `{reader = <reader object>, data = {[id] = width}}`.
function SGDBuffer:__init(global_conf, buffer_conf)
    self.gconf = global_conf
    self.batch_size = buffer_conf.batch_size
    -- round the capacity down to a whole number of batches
    self.buffer_size = math.floor(buffer_conf.buffer_size /
                                    self.batch_size) * self.batch_size
    self.randomize = buffer_conf.randomize
    self.consume = buffer_conf.consume
    local cumat_type = global_conf.cumat_type
    local mmat_type = global_conf.mmat_type
    -- batches returned by get_data() use the matrix type of the trainer
    if global_conf.use_cpu then
        self.output_mat_type = mmat_type
    else
        self.output_mat_type = cumat_type
    end
    -- choose buffer storage type and the buffer -> batch copy routines
    if buffer_conf.use_gpu then
        self.mat_type = cumat_type
        if global_conf.use_cpu then
            -- gpu buffer -> cpu training
            nerv.error("not implemented")
        else
            -- gpu buffer -> gpu training
            self.copy_rows_from_by_idx = cumat_type.copy_rows_fromd_by_idx
            self.copy_from = cumat_type.copy_fromd
        end
        -- the permutation is generated on the host, then uploaded so the
        -- device-side row gather can use it
        self.perm_gen = function (x)
            return cumat_type.new_from_host(nerv.MMatrixFloat.perm_gen(x))
        end
    else
        self.mat_type = mmat_type
        if global_conf.use_cpu then
            -- cpu buffer -> cpu training
            self.copy_rows_from_by_idx = mmat_type.copy_rows_fromh_by_idx
            self.copy_from = mmat_type.copy_fromh
        else
            -- cpu buffer -> gpu training
            self.copy_rows_from_by_idx = cumat_type.copy_rows_fromh_by_idx
            self.copy_from = cumat_type.copy_fromh
        end
        self.perm_gen = nerv.MMatrixFloat.perm_gen
    end
    -- reader output is copied into buffer storage with the host-source copy
    self.copy_from_reader = self.mat_type.copy_fromh
    -- [head, tail) is the range of buffered rows not yet handed out
    self.head = 0
    self.tail = 0
    self.readers = {}
    for _, reader_spec in ipairs(buffer_conf.readers) do
        local buffs = {}
        for id, width in pairs(reader_spec.data) do
            -- `leftover` holds rows that overflowed the previous fill
            buffs[id] = {data = self.mat_type(self.buffer_size, width),
                        leftover = nil,
                        width = width}
        end
        table.insert(self.readers, {buffs = buffs,
                                    reader = reader_spec.reader,
                                    tail = 0,
                                    has_leftover = false})
    end
end

--- Refill the buffer from every reader.
-- Leftover rows from the previous fill are placed at the front first. Then
-- each reader's `get_data()` is polled until its share of the buffer is full
-- or it returns nil. Rows of a chunk that do not fit are kept as leftovers
-- for the next fill. The usable size `self.tail` is the minimum row count
-- obtained over all readers, and `self.head` is reset to 0.
-- @return true when at least `batch_size` rows are available
function SGDBuffer:saturate()
    local buffer_size = self.buffer_size
    self.head = 0
    self.tail = buffer_size
    for _, reader in ipairs(self.readers) do
        reader.tail = 0
        if reader.has_leftover then
            local lrow
            for _, buff in pairs(reader.buffs) do
                lrow = buff.leftover:nrow()
                if lrow > buffer_size then
                    nerv.error("buffer size is too small to contain leftovers")
                end
                buff.data:copy_from(buff.leftover, 0, lrow)
                buff.leftover = nil
            end
            nerv.info("buffer leftover: %d\n", lrow)
            reader.tail = lrow
            reader.has_leftover = false
        end
        while reader.tail < buffer_size do
            local data = reader.reader:get_data()
            if data == nil then
                break -- reader is exhausted
            end
            -- all streams in one chunk must have the same number of rows
            local drow = nil
            for _, d in pairs(data) do
                if drow == nil then
                    drow = d:nrow()
                elseif d:nrow() ~= drow then
                    nerv.error("reader provides with inconsistent rows of data")
                end
            end
            -- every buffered stream must be present; checked up front so
            -- both copy paths below fail with a clear message
            for id in pairs(reader.buffs) do
                if data[id] == nil then
                    nerv.error("reader does not provide data for %s", id)
                end
            end
            local remain = buffer_size - reader.tail
            if drow > remain then
                -- stash the rows that do not fit for the next fill
                for id, buff in pairs(reader.buffs) do
                    buff.leftover = self.mat_type(drow - remain,
                                                  buff.width)
                    self.copy_from_reader(buff.leftover, data[id], remain, drow)
                end
                drow = remain
                reader.has_leftover = true
            end
            -- NOTE(review): argument order appears to be
            -- (dst, src, src_begin, src_end, dst_begin) -- confirm in nerv
            for id, buff in pairs(reader.buffs) do
                self.copy_from_reader(buff.data, data[id], 0, drow, reader.tail)
            end
            reader.tail = reader.tail + drow
        end
        self.tail = math.min(self.tail, reader.tail)
    end
    -- get_data() only reads the shuffled index when randomizing, and never
    -- when the buffer is empty, so skip generating (and uploading) it otherwise
    if self.randomize and self.tail > 0 then
        self.rand_map = self.perm_gen(self.tail)
    end
    collectgarbage("collect")
    return self.tail >= self.batch_size
end

--- Return the next mini-batch, refilling the buffer when it is exhausted.
-- @return a table mapping each data id to a freshly allocated matrix of
--   `self.output_mat_type`, or nil when no further batch can be produced.
--   Without `consume`, only full batches of `batch_size` rows are returned.
--   With `consume`, the final batch may be smaller.
function SGDBuffer:get_data()
    local batch_size = self.batch_size
    if self.head >= self.tail then -- buffer is empty
        local t = os.clock() -- note: os.clock reports CPU time
        if (not self:saturate()) and (not self.consume) then
            return nil -- the remaining data cannot build a batch
        end
        if self.tail == self.head then
            return nil -- nothing left
        end
        nerv.info("%.3fs to fill the buffer", os.clock() - t)
    end
    if self.head + batch_size > self.tail and (not self.consume) then
        return nil -- the remaining data cannot build a batch
    end
    -- must be local; a bare assignment would create a global on every call
    local actual_batch_size = math.min(batch_size, self.tail - self.head)
    local res = {}
    for _, reader in ipairs(self.readers) do
        for id, buff in pairs(reader.buffs) do
            local batch = self.output_mat_type(actual_batch_size, buff.width)
            if self.randomize then
                -- gather rows via the shuffled index, starting at self.head
                self.copy_rows_from_by_idx(batch, buff.data, self.rand_map, self.head)
            else
                self.copy_from(batch, buff.data, self.head, self.head + actual_batch_size)
            end
            res[id] = batch
        end
    end
    self.head = self.head + actual_batch_size
    return res
end