blob: 2dd2dc063c2c77d13703058b93c55a0837c4fa87 (
plain) (
tree)
|
|
--- Contains parameter and layer classes related to linear (or affine)
-- transforms.
--- The class for all matrix-based parameters. The class has a single matrix
-- which can be accessed by `self.trans`.
-- Declared before its subclasses below: they name `'nerv.MatrixParam'` as
-- their parent, and `nerv.class` presumably resolves the parent by name at
-- declaration time (as luaT-style class systems do). Declaring the parent
-- first is safe either way.
-- @type nerv.MatrixParam
local MatrixParam = nerv.class('nerv.MatrixParam', 'nerv.Param')
--- The class for linear transform parameter.
-- @type nerv.LinearTransParam
local LinearTransParam = nerv.class('nerv.LinearTransParam', 'nerv.MatrixParam')
--- The class for bias parameter (currently implemented as a one-row matrix).
-- @type nerv.BiasParam
local BiasParam = nerv.class('nerv.BiasParam', 'nerv.MatrixParam')
--- Validate where the contained matrix is stored. `nerv.ParamRepo` relies on
-- this method being present.
-- @param checker callback invoked exactly once, with the parameter matrix
-- (`self.trans`) as its only argument
function MatrixParam:check(checker)
    local mat = self.trans
    checker(mat)
end
--- Load the parameter matrix from a file handle (see `nerv.Param.read`).
-- Deserialization is delegated to the `load` function of the matrix type
-- configured in `self.gconf.mmat_type`; the result is stored in `self.trans`.
-- @param handle the file handle to read from
function MatrixParam:read(handle)
    local mat_type = self.gconf.mmat_type
    self.trans = mat_type.load(handle)
end
--- Save the parameter matrix to a file handle (see `nerv.Param.write`).
-- Serialization is delegated to the matrix's own `save` method.
-- @param handle the file handle to write to
function MatrixParam:write(handle)
    local mat = self.trans
    mat:save(handle)
end
--- Set up the training state for this parameter.
-- Creates `self.correction` via `self.trans:create()` and
-- `self.correction_acc` via `self.correction:create()`, then fills both with
-- zeros. (`create()` presumably allocates a matrix of the same dimensions;
-- confirm against the matrix API.) `self.correction` is later read by
-- `_update`.
function MatrixParam:train_init()
    local corr = self.trans:create()
    local acc = corr:create()
    self.correction, self.correction_acc = corr, acc
    corr:fill(0)
    acc:fill(0)
end
--- Build a copy of this parameter whose matrix is produced by `copier`.
-- The copy is constructed as `nerv.MatrixParam(self.id, self.gconf)`, and its
-- `trans` field is then set to `copier(self.trans)`.
-- NOTE(review): the result is always a plain `nerv.MatrixParam`, even when
-- `self` is a subclass such as `nerv.LinearTransParam` or `nerv.BiasParam`;
-- confirm callers do not rely on the subclass type being preserved.
-- @param copier function receiving `self.trans` and returning the matrix to
-- store in the copy
-- @return the newly constructed parameter object
function MatrixParam:copy(copier)
    local id, gconf = self.id, self.gconf
    local copied = nerv.MatrixParam(id, gconf)
    copied.trans = copier(self.trans)
    return copied
end
function MatrixParam:_update(alpha, beta)
local gconf = self.gconf
-- momentum gain
local mmt_gain = 1.0 / (1.0 - gconf.momentum)
local n = gconf.batch_size * mmt_gain
-- perform update
if gconf.momentum > 0 then
self.correction:add
|