aboutsummaryrefslogblamecommitdiff
path: root/nerv/layer/init.lua
blob: 054784b2b2b5109d6769ebc9325dd0bb06daf844 (plain) (tree)
1
2
3
4
5
6
7
8







                                                                               


                                      




                                                                            
                                      
                
                            

   



                                                                              
                         


                    



                                                                         
                             


                    



                                                                 
                           


                                       



                                                                



                                       

                  
 
                         


            



                                                                            

                                      


























                                                                                
                                                  











                                                          

   

                                                
                               


                                       










                                                                            


                                       







                                                                           


                                       











                                                                           

                                       
 




                                                                             
                                             

                                      








                                                           
 


                                                                          



                                       



                                                                            



                                       


                                                               
                        


                                    



                                          



                                    



                                                                         
                               
                                                         

   















                                                                                     

                                   
       








                                                                                                          
           

                                          
           




                                                                                                       
       
                                    

                                                                                             




                                                      

                                                         
                          


            
                         

                           
                        




                              
                           
                            
                        

                           
                       
                             
                            
                         
                        






                                                                             
--- Implements the concept of groups of parameters (`nerv.Param`) and
-- computation nodes (`nerv.Layer`).

--- The class describing a group of parameters (an internal state) that can be
-- bound to layers. This class also implements the *chunk* interface (see
-- `nerv.ChunkFile`), which means instances of `nerv.Param` can be exported to
-- chunk files as chunks.
-- @type nerv.Param

-- Create the class table through `nerv.class` under the name 'nerv.Param'.
local Param = nerv.class('nerv.Param')

--- Construct a parameter group.
-- Both arguments are stored on the new instance, as `self.id` and
-- `self.gconf` respectively.
-- @param id the identifier for the group of parameters
-- @param global_conf a table describing the computation state and providing
-- some global settings

function Param:__init(id, global_conf)
    self.id, self.gconf = id, global_conf
end

--- Fetch the metadata of the parameter group (the value kept in
-- `self.info`). This function is part of the *chunk* interface.
-- @return a table containing all metadata

function Param:get_info()
    local metadata = self.info
    return metadata
end

--- Set the metadata of the parameter group. This function implements the
-- *chunk* interface. The given value is stored as `self.info`, replacing
-- whatever was stored there before.
-- @param info a table containing all metadata

function Param:set_info(info)
    self.info = info
end

--- Read from the given file handle. This function implements the
-- *chunk* interface. The base class provides no implementation: the body only
-- calls `nerv.error_method_not_implemented()`.
-- @param handle the file handle

function Param:read(handle)
    -- NOTE(review): concrete parameter classes presumably override this
    -- method -- confirm against the subclasses.
    nerv.error_method_not_implemented()
end

--- Write to the given file handle. This function implements the
-- *chunk* interface. The base class provides no implementation: the body only
-- calls `nerv.error_method_not_implemented()`.
-- @param handle the file handle

function Param:write(handle)
    -- NOTE(review): concrete parameter classes presumably override this
    -- method -- confirm against the subclasses.
    nerv.error_method_not_implemented()
end

--- Generate zero. Defined with `.` rather than `:`, so it takes no `self`
-- and can be passed around as a plain function value.
-- NOTE(review): presumably used as a generator callback when initialising
-- parameter values -- confirm at the call sites.
-- @return the number `0`

function Param.gen_zero()
    return 0
end

--- The class describing a single computation node which calculates from the
-- input ports to the output ports, which in turn could be the input of other
-- nodes.
-- @type nerv.Layer

-- Create the class table through `nerv.class` under the name 'nerv.Layer'.
local Layer = nerv.class('nerv.Layer')

--- The base constructor. Every inheriting class should invoke it so that the
-- following predefined fields of `self` are set up:
--
-- * `id`: the identifier of the layer
-- * `gconf`: a table describing the computation state and providing some
--   global settings
-- * `lconf`: a table providing settings dedicated to the layer. Some fields
--   are considered "standard" and are shared by all layers:
--      * `dim_in`: an ordered array of the dimension (width) of each input
--        port
--      * `dim_out`: an ordered array of the dimension (width) of each output
--        port
--      * `params`: optional, a table containing pairs of the manually bound
--        parameter name used by the layer and the parameter id used to find
--        the parameter in the parameter repo
--      * `pr`: optional, the parameter repo (see `nerv.ParamRepo`) in which
--        to find parameters while binding, used by `nerv.Layer.find_param`
-- * `mat_type`: the type of matrix to use when storing intermediate results;
--   taken from `gconf.mmat_type` when `gconf.use_cpu` is set, and from
--   `gconf.cumat_type` otherwise
-- * `loc_type`: a value from `nerv.ParamRepo.LOC_TYPES` indicating whether the
--   storage of `nerv.Param` instances is on host or device RAM (`ON_HOST`
--   when `gconf.use_cpu` is set, `ON_DEVICE` otherwise)
-- * `dim_in`: copied from `lconf.dim_in`
-- * `dim_out`: copied from `lconf.dim_out`
--
-- @param id the identifier
-- @param global_conf see `self.gconf`
-- @param layer_conf see `self.lconf`

function Layer:__init(id, global_conf, layer_conf)
    self.id, self.gconf, self.lconf = id, global_conf, layer_conf

    -- Select the matrix type and the parameter storage location together,
    -- depending on whether computation runs on the CPU.
    local gconf = self.gconf
    local loc_types = nerv.ParamRepo.LOC_TYPES
    if not gconf.use_cpu then
        self.mat_type = gconf.cumat_type
        self.loc_type = loc_types.ON_DEVICE
    else
        self.mat_type = gconf.mmat_type
        self.loc_type = loc_types.ON_HOST
    end

    self.dim_in, self.dim_out = layer_conf.dim_in, layer_conf.dim_out
end

--- Initialize the layer, called for each epoch. The base class provides no
-- implementation: the body only calls `nerv.error_method_not_implemented()`.
-- @param batch_size NOTE(review): presumably the number of rows in each
-- mini-batch the layer has to handle -- confirm against the subclasses

function Layer:init(batch_size)
    nerv.error_method_not_implemented()
end

--- Update (change the state of) the bound (tied) parameter according to the
-- calculation. The base class provides no implementation: the body only calls
-- `nerv.error_method_not_implemented()`.
-- @param bp_err an array of row-major matrices storing the error
-- back-propagated from the output ports
-- @param input an array of row-major matrices storing the input before the
-- forward propagation
-- @param output an array of row-major matrices storing the output after the
-- forward propagation
-- @param t BPTT time `t`

function Layer:update(bp_err, input, output, t)
    nerv.error_method_not_implemented()
end

--- Calculate the values in output ports according to the input.
-- @param input an array of row-major matrices storing the input before the
-- forward propagation
-- @param ouput an array of r