--- Contains parameter and layer classes related to linear (or affine)
-- transform.
--- Base class for every parameter that is backed by exactly one matrix.
-- The wrapped matrix lives in the `trans` field (`self.trans`).
-- @type nerv.MatrixParam
local MatrixParam = nerv.class("nerv.MatrixParam", "nerv.Param")
--- Verify where the wrapped matrix is stored. `nerv.ParamRepo` requires
-- every parameter to provide this method.
-- @param checker callback that receives the matrix to inspect
function MatrixParam:check(checker)
    -- only one matrix to hand over: `trans`
    local mat = self.trans
    checker(mat)
end
--- Load the wrapped matrix from a file handle (see `nerv.Param.read`).
-- The matrix type's `load` function is taken from `self.gconf.mmat_type`.
-- @param handle the file handle to read from
function MatrixParam:read(handle)
    local mat_type = self.gconf.mmat_type
    self.trans = mat_type.load(handle)
end
--- Save the wrapped matrix to a file handle (see `nerv.Param.write`).
-- @param handle the file handle to write to
function MatrixParam:write(handle)
    local mat = self.trans
    mat:save(handle)
end
--- Allocate the two training buffers and zero them:
-- `correction` (momentum buffer) and `correction_acc` (accumulated
-- gradient). Both are obtained through `create()` starting from `trans`.
function MatrixParam:train_init()
    -- `create()` presumably allocates a matrix shaped like its receiver
    -- without initializing it (hence the explicit fills below) -- TODO confirm
    local corr = self.trans:create()
    self.correction = corr
    local acc = corr:create()
    self.correction_acc = acc
    corr:fill(0)
    acc:fill(0)
end
--- Produce a copy of this parameter with the same id and global config.
-- @param copier function mapping the source matrix to its duplicate
-- @return a new `nerv.MatrixParam` holding `copier(self.trans)`
-- NOTE(review): the result is always a plain `nerv.MatrixParam`, even when
-- `self` is an instance of a subclass -- confirm callers do not rely on
-- the subclass type surviving a copy.
function MatrixParam:copy(copier)
    local result = nerv.MatrixParam(self.id, self.gconf)
    local duplicated = copier(self.trans)
    result.trans = duplicated
    return result
end
--- Apply the gradient accumulated in `correction_acc` to `trans`, then
-- reset the accumulator to zero. `train_init` must have been called first.
--
-- Assuming `M:add(a, b, x, y)` stores `x * a + y * b` into `M`, the update is
--   trans <- alpha * trans - (lrate / n) * beta * step
-- with `n = batch_size / (1 - momentum)`. `step` is the momentum buffer
-- `correction` (refreshed as `momentum * correction + correction_acc`)
-- when momentum > 0, otherwise the raw accumulated gradient.
-- @param alpha scale applied to the current `trans`
-- @param beta scale applied to the learning-rate-normalized step
-- @raise error if `gconf.momentum >= 1`
function MatrixParam:_update(alpha, beta)
    local gconf = self.gconf
    local momentum = gconf.momentum
    -- 1 / (1 - momentum) is infinite at momentum == 1 (the step silently
    -- becomes 0) and negative above it (the step flips sign and the
    -- parameters climb the gradient). Reject such configurations loudly.
    if momentum >= 1.0 then
        error(string.format("momentum must be less than 1, got %s",
                            tostring(momentum)), 2)
    end
    -- momentum gain: dividing the learning rate by 1 / (1 - momentum) makes
    -- the steady-state momentum step (for a constant gradient) match the
    -- magnitude of a plain SGD step
    local mmt_gain = 1.0 / (1.0 - momentum)
    local n = gconf.batch_size * mmt_gain
    local scale = -gconf.lrate / n * beta
    -- perform update
    if momentum > 0 then
        self.correction:add(self.correction, self.correction_acc, momentum, 1.0)
        self.trans:add(self.trans, self.correction, alpha, scale)
    else
        self.trans:add(self.trans, self.correction_acc, alpha, scale)
    end
    -- start accumulating from zero for the next update
    self.correction_acc:fill(0)
end
--- Add a gradient into the `correction_acc` accumulator. The accumulated
-- sum is applied and then reset to zero by `_update`.
-- @param gradient gradient matrix (presumably shaped like `trans` -- confirm)
function MatrixParam:back_propagate_by_gradient(gradient)
    local acc = self.correction_acc
    acc:add(acc, gradient, 1.0, 1.0)
end
function M