aboutsummaryrefslogblamecommitdiff
path: root/nerv/layer/affine.lua
blob: 16250fd3fdb6123b02beffdf8f158fc8d21ef7a3 (plain) (tree)
1
2
3
4
5
6
7


                                                                      

                                                                            

                         




                                                                        




                                   
                                                   
                                
                                 
                                                  

   

                                                   
                                  
                           



                                         
                                                  
                           
                               

   





                                                        
                                         


                          
                            

                                                 
                                         
                     
                              
                                                                                      
                                                                                   
        
                                                                                       
       
                               

   

                                                                    

   

                                                            

   

                                         

   
                                          
                            
                                            
                        
   
 

                                                                          

                         

                                                                
                    









                                                                                             
                                                        





                                                                          

                                                                      
                              
                              
                                                       
                                                                       
                                                                  
                                                                              
                                                     


                                                                       
       
                                         
                                                      
                                               

                                                     



                                                                   

   
                                     
                                                   

                                                                                  
                              







                                                                                         
                        

   



                                             
                             
                              
                                              
       
                                

   
                                             

                                                                


                                                                           
               
                                         

   
                                                                       
                              
                                                                                 
                                                                         
       
                                                          
   

                                 
                                                                  
                              
                                
       
             
   









                                                                                
--- Contains parameter and layer classes related to linear (or affine)
-- transforms.

--- The class for all matrix-based parameters. Each instance holds a single
-- matrix, which is accessed through `self.trans`.
-- @type nerv.MatrixParam

-- Create the `nerv.MatrixParam` class; the second argument names
-- `nerv.Param`, presumably registered as the parent class (confirm against
-- `nerv.class`). The methods below are attached to this local table.
local MatrixParam = nerv.class('nerv.MatrixParam', 'nerv.Param')

--- Hand the wrapped matrix to a storage-location checker. `nerv.ParamRepo`
-- requires this hook in order to verify where the parameter's matrix lives.
-- @param checker callback that receives `self.trans` as its only argument;
-- whatever it returns is ignored
function MatrixParam:check(checker)
    local matrix = self.trans
    checker(matrix)
end

--- Load the wrapped matrix from an open file handle. See `nerv.Param.read`.
-- The loader is the `load` function of `self.gconf.mmat_type`; the result
-- replaces `self.trans`.
-- @param handle the file handle to read from
function MatrixParam:read(handle)
    local matrix_type = self.gconf.mmat_type
    self.trans = matrix_type.load(handle)
end

--- Serialize the wrapped matrix to an open file handle. See
-- `nerv.Param.write`.
-- @param handle the file handle to write to
function MatrixParam:write(handle)
    local matrix = self.trans
    matrix:save(handle)
end

--- Prepare the training-time buffers of this parameter.
-- Allocates `self.correction` via `self.trans:create()` and
-- `self.correction_acc` via `self.correction:create()` (presumably matrices
-- with the same dimensions as `self.trans` — confirm), then zero-fills both.
-- Any buffers left over from a previous call are replaced.
function MatrixParam:train_init()
    local correction = self.trans:create()
    self.correction = correction
    local accumulator = correction:create()
    self.correction_acc = accumulator
    correction:fill(0)
    accumulator:fill(0)
end

--- Build a new parameter whose matrix is produced by `copier`.
-- The new object is constructed from `self.id` and `self.gconf`; apart from
-- that, only `trans` is set. Training buffers such as `correction` are not
-- carried over.
-- NOTE(review): the result is always a plain `nerv.MatrixParam`, even when
-- `self` is an instance of a subclass — confirm this is intended.
-- @param copier function that maps `self.trans` to the matrix the copy holds
-- @return the newly constructed `nerv.MatrixParam`
function MatrixParam:copy(copier)
    local duplicate = nerv.MatrixParam(self.id, self.gconf)
    duplicate.trans = copier(self.trans)
    return duplicate
end

function MatrixParam:_update(alpha, beta)
    if self.