blob: d175d020d2ff66dea425351b13548529c1bd42ec (
plain) (
tree)
|
|
--- Implements the concept of groups of parameters (`nerv.Param`) and
-- computation nodes (`nerv.Layer`).
--- The class describing a group of parameters (an internal state) that can be
-- bound to layers. This class also implements the *chunk* interface (see
-- `nerv.ChunkFile`), which means instances of `nerv.Param` can be exported to
-- chunk files as chunks.
-- @type nerv.Param
local Param = nerv.class('nerv.Param')
--- Create a parameter group.
-- The constructor only records its two arguments on the new instance (as
-- `self.id` and `self.gconf`); no other field is initialized here.
-- @param id the identifier for the group of parameters
-- @param global_conf a table describing the computation state and providing
-- some global settings
function Param:__init(id, global_conf)
    self.gconf = global_conf
    self.id = id
end
--- Fetch the metadata attached to the parameter group. Part of the *chunk*
-- interface.
-- The value of `self.info` is handed back as is (no copy is made).
-- @return a table containing all metadata
function Param:get_info()
    local metadata = self.info
    return metadata
end
--- Set the metadata of the parameter group. This function implements the
-- *chunk* interface.
-- The given table is stored as is in `self.info` (no copy is made), which is
-- the field `Param:get_info` reads back.
-- @param info a table containing all metadata
function Param:set_info(info)
self.info = info
end
--- Read from the given file handle. This function implements the
-- *chunk* interface.
-- The base class provides no implementation: it only calls
-- `nerv.error_method_not_implemented()`, so concrete parameter classes are
-- presumably expected to override this method.
-- @param handle the file handle
function Param:read(handle)
nerv.error_method_not_implemented()
end
--- Write to the given file handle. This function implements the
-- *chunk* interface.
-- The base class provides no implementation: it only calls
-- `nerv.error_method_not_implemented()`, so concrete parameter classes are
-- presumably expected to override this method.
-- @param handle the file handle
function Param:write(handle)
nerv.error_method_not_implemented()
end
--- Produce the zero value.
-- Defined with `.` rather than `:`, so it takes no `self` argument.
-- @return the number `0`
function Param.gen_zero()
    local zero = 0
    return zero
end
--- The class describing a single computation node, which computes the values
-- of its output ports from its input ports; those outputs can in turn be the
-- input of other nodes.
-- @type nerv.Layer
local Layer = nerv.class('nerv.Layer')
--- The base constructor. Every inheriting class should call it so that the
-- following predefined fields of `self` are initialized:
--
-- * `id`: the identifier of the layer
-- * `gconf`: a table describing the computation state and providing some
--   global settings
-- * `lconf`: a table of settings dedicated to the layer. Some of its fields
--   are considered "standard" and shared by all layers:
--     * `dim_in`: an ordered array with the dimension (width) of each input
--       port
--     * `dim_out`: an ordered array with the dimension (width) of each output
--       port
--     * `params`: optional, a table containing pairs of the manually bound
--       parameter name used by the layer and the parameter id used to find
--       the parameter in the parameter repo
--     * `pr`: optional, the parameter repo (see `nerv.ParamRepo`) in which to
--       look for parameters while binding, used by `nerv.Layer.find_param`
-- * `mat_type`: the type of matrix to use when storing intermediate results;
--   taken from `gconf.mmat_type` when `gconf.use_cpu` is truthy and from
--   `gconf.cumat_type` otherwise
-- * `loc_type`: a value from `nerv.ParamRepo.LOC_TYPES` indicating whether
--   the storage of `nerv.Param` instances is on host or device RAM
--   (`ON_HOST` when `gconf.use_cpu` is truthy, `ON_DEVICE` otherwise)
-- * `dim_in`: copied from `layer_conf.dim_in`
-- * `dim_out`: copied from `layer_conf.dim_out`
--
-- @param id the identifier
-- @param global_conf see `self.gconf`
-- @param layer_conf see `self.lconf`
function Layer:__init(id, global_conf, layer_conf)
    self.id = id
    self.gconf = global_conf
    self.lconf = layer_conf

    -- The matrix type and the parameter storage location are chosen together,
    -- both driven by the single `use_cpu` switch of the global configuration.
    local gconf = self.gconf
    if not gconf.use_cpu then
        self.mat_type = gconf.cumat_type
        self.loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
    else
        self.mat_type = gconf.mmat_type
        self.loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
    end

    -- Mirror the port dimensions directly on the layer for convenient access.
    self.dim_in = layer_conf.dim_in
    self.dim_out = layer_conf.dim_out
end
--- Initialize the layer, called for each epoch.
-- The base class provides no implementation: it only calls
-- `nerv.error_method_not_implemented()`, so concrete layers are presumably
-- expected to override this method.
-- @param batch_size unused by this base implementation; presumably the number
-- of samples processed per batch (TODO confirm against the concrete layers)
function Layer:init(batch_size)
nerv.error_method_not_implemented()
end
--- Update (change the state of) the bound (tied) parameter according to the
-- calculation.
-- The base class provides no implementation: it only calls
-- `nerv.error_method_not_implemented()`, so concrete layers are presumably
-- expected to override this method.
-- @param bp_err an array of row-major matrices storing the error
-- back-propagated from the output ports
-- @param input an array of row-major matrices storing the input before the
-- forward propagation
-- @param output an array of row-major matrices storing the output after the
-- forward propagation
-- @param t BPTT time `t`
function Layer:update(bp_err, input, output, t)
nerv.error_method_not_implemented()
end
--- Calculate the values in output ports according to the input.
-- @param input an array of row-major matrices storing the input before the
-- forward propagation
-- @param output an array of row-major matrices storing the output after the
-- forward propagation
-- @param t BPTT time `t`
function Layer:propagate(input, output
|