Diffstat (limited to 'nerv/layer')
-rw-r--r--  nerv/layer/affine.lua | 17
1 file changed, 9 insertions(+), 8 deletions(-)
diff --git a/nerv/layer/affine.lua b/nerv/layer/affine.lua
index a2809bf..0fcff36 100644
--- a/nerv/layer/affine.lua
+++ b/nerv/layer/affine.lua
@@ -24,21 +24,21 @@ function MatrixParam:update(gradient)
     local mmt_gain = 1.0 / (1.0 - gconf.momentum);
     local n = self.gconf.batch_size * mmt_gain
     -- perform update
-    self.trans:add(self.trans, self.correction, 1.0, -gconf.lrate / n)
+    self.trans:add(self.trans, self.correction, 1.0 - gconf.lrate*gconf.wcost/gconf.batch_size, -gconf.lrate / n)
 end
 
 function LinearTransParam:update(gradient)
     MatrixParam.update(self, gradient)
-    local gconf = self.gconf
-    -- weight decay
-    self.trans:add(self.trans, self.trans, 1.0, -gconf.lrate * gconf.wcost / gconf.batch_size)
+    -- local gconf = self.gconf
+    -- weight decay(put into MatrixParam:update)
+    -- self.trans:add(self.trans, self.trans, 1.0, -gconf.lrate * gconf.wcost / gconf.batch_size)
 end
 
 function BiasParam:update(gradient)
     MatrixParam.update(self, gradient)
-    local gconf = self.gconf
+    -- local gconf = self.gconf
     -- weight decay
-    self.trans:add(self.trans, self.trans, 1.0, -gconf.lrate * gconf.wcost / gconf.batch_size)
+    -- self.trans:add(self.trans, self.trans, 1.0, -gconf.lrate * gconf.wcost / gconf.batch_size)
 end
 
 function AffineLayer:__init(id, global_conf, layer_conf)
@@ -76,12 +76,13 @@ function AffineLayer:update(bp_err, input, output)
     local gconf = self.gconf
     if (gconf.momentum > 0) then
         self.ltp.correction:mul(input[1], bp_err[1], 1.0, gconf.momentum, 'T', 'N')
+        self.bp.correction:add(self.bp.correction, bp_err[1]:colsum(), gconf.momentum, 1)
         -- momentum gain
         local mmt_gain = 1.0 / (1.0 - gconf.momentum);
         local n = self.gconf.batch_size * mmt_gain
         -- perform update
-        self.ltp.trans:add(self.ltp.trans, self.ltp.correction, 1.0, -gconf.lrate / n)
-        self.bp.trans:add(self.bp.trans, bp_err[1]:colsum(), 1.0-gconf.lrate*gconf.wcost, -gconf.lrate / gconf.batch_size)
+        self.ltp.trans:add(self.ltp.trans, self.ltp.correction, 1.0-gconf.lrate*gconf.wcost/gconf.batch_size, -gconf.lrate / n)
+        self.bp.trans:add(self.bp.trans, self.bp.correction, 1.0-gconf.lrate*gconf.wcost/gconf.batch_size, -gconf.lrate / n)
     else
         self.ltp.trans:mul(input[1], bp_err[1], -gconf.lrate / gconf.batch_size, 1.0-gconf.lrate*gconf.wcost/gconf.batch_size, 'T', 'N')
         self.bp.trans:add(self.bp.trans, bp_err[1]:colsum(), 1.0-gconf.lrate*gconf.wcost/gconf.batch_size, -gconf.lrate / gconf.batch_size)
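For reference, the change folds the L2 weight-decay term into the coefficient that scales the old parameters, so a single add() call performs both the decay and the momentum step, and the bias now accumulates its own correction with momentum instead of taking a plain per-batch step. A minimal standalone Lua sketch of the scalar form of this update (illustrative only, with made-up example values; it does not use the NERV matrix API):

-- fused update: w <- (1 - lrate*wcost/batch_size) * w - (lrate/n) * correction
-- where n = batch_size / (1 - momentum) is the momentum-compensated batch size
local gconf = { lrate = 0.1, wcost = 1e-4, batch_size = 256, momentum = 0.9 }

local function update(w, correction)
    local mmt_gain = 1.0 / (1.0 - gconf.momentum)
    local n = gconf.batch_size * mmt_gain
    return w * (1.0 - gconf.lrate * gconf.wcost / gconf.batch_size)
           - correction * (gconf.lrate / n)
end

print(update(0.5, 2.0))

Because the decay factor multiplies the existing weights in the same call, the per-parameter update needs no separate weight-decay pass, which is why LinearTransParam:update and BiasParam:update are reduced to the shared MatrixParam:update above.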