diff options
author | Determinant <[email protected]> | 2015-05-26 23:58:32 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2015-05-26 23:58:32 +0800 |
commit | f8543464c13584e39bfacee694ee1ed80ac121f4 (patch) | |
tree | 3e29ffd5205659fbf3f908b5406522e4bab1c2e9 /layer/affine.lua | |
parent | 910640c0ef7c43d586180241f79723973e0e7d35 (diff) |
fix a severe bug in memory management of userdata
Diffstat (limited to 'layer/affine.lua')
-rw-r--r-- | layer/affine.lua | 22 |
1 files changed, 15 insertions, 7 deletions
diff --git a/layer/affine.lua b/layer/affine.lua index cd2ba0b..221aacd 100644 --- a/layer/affine.lua +++ b/layer/affine.lua @@ -14,23 +14,31 @@ function AffineLayer:__init(id, global_conf, ltp, bp) self.ltp = ltp self.bp = bp self.gconf = global_conf +end + +function AffineLayer:init() -- linear transform correction - self.ltc = ltp:create() + self.ltc = self.ltp.trans:create() self.ltc:fill(0) -- bias correction - self.bc = bp:create() + self.bc = self.bp.trans:create() self.bc:fill(0) end function nerv.AffineLayer:update(bp_err, input, output) + local ltp = self.ltp.trans + local bp = self.bp.trans + local ltc = self.ltc + local bc = self.bc + local gconf = self.gconf -- momentum gain local mmt_gain = 1.0 / (1.0 - gconf.momentum); - local n = input.nrow() * mmt_gain + local n = input:nrow() * mmt_gain -- update corrections (accumulated errors) ltc:mul(input, bp_err, 1.0, gconf.momentum, 'T', 'N') bc:add(bc, bp_err:colsum(), gconf.momentum, 1.0) -- perform update - ltp:add(lpc, ltc, 1.0, -gconf.lrate / n) + ltp:add(ltp, ltc, 1.0, -gconf.lrate / n) bp:add(bp, bc, 1.0, -gconf.lrate / n) -- weight decay ltp:add(ltp, ltp, 1.0, -gconf.lrate * gconf.wcost) @@ -38,11 +46,11 @@ end function nerv.AffineLayer:propagate(input, output) -- apply linear transform - output:mul(input, self.ltp, 'N', 'N') + output:mul(input, self.ltp.trans, 1.0, 0.0, 'N', 'N') -- add bias - output:add_row(self.bp, 1.0) + output:add_row(self.bp.trans, 1.0) end function nerv.AffineLayer:back_propagate(next_bp_err, bp_err, input, output) - next_bp_err:mul(bp_err, self.ltp, 'N', 'T') + next_bp_err:mul(bp_err, self.ltp.trans, 1.0, 0.0, 'N', 'T') end |