blob: 16250fd3fdb6123b02beffdf8f158fc8d21ef7a3 (
plain) (
tree)
|
|
--- Contains parameter and layer classes related to linear (or affine)
-- transforms.
--- Base class for every parameter backed by a single matrix.
-- The wrapped matrix is stored in (and read from) the `trans` field.
-- Derives from `nerv.Param`.
-- @type nerv.MatrixParam
local MatrixParam = nerv.class(
    "nerv.MatrixParam", -- name of the new class
    "nerv.Param"        -- parent class
)
--- Verify where the wrapped matrix is stored. `nerv.ParamRepo` relies on
-- this hook.
-- @param checker callback invoked once with the `trans` matrix; its return
-- value is ignored
function MatrixParam:check(checker)
    local matrix = self.trans
    checker(matrix)
end
--- Load the wrapped matrix from a file handle (see `nerv.Param.read`).
-- The loader is the `load` function of `self.gconf.mmat_type`, and its
-- result replaces `self.trans`.
-- @param handle the file handle to read from
function MatrixParam:read(handle)
    local mmat_type = self.gconf.mmat_type
    self.trans = mmat_type.load(handle)
end
--- Save the wrapped matrix to a file handle (see `nerv.Param.write`).
-- Serialization is delegated to the matrix's own `save` method.
-- @param handle the file handle to write to
function MatrixParam:write(handle)
    local matrix = self.trans
    matrix:save(handle)
end
--- Prepare the training buffers.
-- Allocates `self.correction` through `self.trans:create()` and
-- `self.correction_acc` through `self.correction:create()`, then zero-fills
-- both. They are presumably the same shape as `trans`; confirm this against
-- the matrix `create` implementation.
function MatrixParam:train_init()
    local corr = self.trans:create()
    self.correction = corr
    local corr_acc = corr:create()
    self.correction_acc = corr_acc
    -- Both buffers start from zero before any update accumulates into them.
    corr:fill(0)
    corr_acc:fill(0)
end
--- Produce a new parameter that holds a copy of the wrapped matrix.
-- Only `id`, `gconf` and `trans` are carried over; the training buffers
-- (`correction`, `correction_acc`) are not copied.
-- NOTE(review): this always constructs a plain `nerv.MatrixParam`, even when
-- `self` is an instance of a subclass. Confirm that this is intended.
-- @param copier function mapping `self.trans` to the matrix stored in the copy
-- @return the new `nerv.MatrixParam`
function MatrixParam:copy(copier)
    local duplicate = nerv.MatrixParam(self.id, self.gconf)
    duplicate.trans = copier(self.trans)
    return duplicate
end
function MatrixParam:_update(alpha, beta)
if self.
|