aboutsummaryrefslogtreecommitdiff
path: root/layer/affine.lua
diff options
context:
space:
mode:
authorDeterminant <ted.sybil@gmail.com>2015-05-26 23:58:32 +0800
committerDeterminant <ted.sybil@gmail.com>2015-05-26 23:58:32 +0800
commitf8543464c13584e39bfacee694ee1ed80ac121f4 (patch)
tree3e29ffd5205659fbf3f908b5406522e4bab1c2e9 /layer/affine.lua
parent910640c0ef7c43d586180241f79723973e0e7d35 (diff)
fix a severe bug in memory management of userdata
Diffstat (limited to 'layer/affine.lua')
-rw-r--r--layer/affine.lua22
1 files changed, 15 insertions, 7 deletions
diff --git a/layer/affine.lua b/layer/affine.lua
index cd2ba0b..221aacd 100644
--- a/layer/affine.lua
+++ b/layer/affine.lua
@@ -14,23 +14,31 @@ function AffineLayer:__init(id, global_conf, ltp, bp)
self.ltp = ltp
self.bp = bp
self.gconf = global_conf
+end
+
+function AffineLayer:init()
-- linear transform correction
- self.ltc = ltp:create()
+ self.ltc = self.ltp.trans:create()
self.ltc:fill(0)
-- bias correction
- self.bc = bp:create()
+ self.bc = self.bp.trans:create()
self.bc:fill(0)
end
function nerv.AffineLayer:update(bp_err, input, output)
+ local ltp = self.ltp.trans
+ local bp = self.bp.trans
+ local ltc = self.ltc
+ local bc = self.bc
+ local gconf = self.gconf
-- momentum gain
local mmt_gain = 1.0 / (1.0 - gconf.momentum);
- local n = input.nrow() * mmt_gain
+ local n = input:nrow() * mmt_gain
-- update corrections (accumulated errors)
ltc:mul(input, bp_err, 1.0, gconf.momentum, 'T', 'N')
bc:add(bc, bp_err:colsum(), gconf.momentum, 1.0)
-- perform update
- ltp:add(lpc, ltc, 1.0, -gconf.lrate / n)
+ ltp:add(ltp, ltc, 1.0, -gconf.lrate / n)
bp:add(bp, bc, 1.0, -gconf.lrate / n)
-- weight decay
ltp:add(ltp, ltp, 1.0, -gconf.lrate * gconf.wcost)
@@ -38,11 +46,11 @@ end
function nerv.AffineLayer:propagate(input, output)
-- apply linear transform
- output:mul(input, self.ltp, 'N', 'N')
+ output:mul(input, self.ltp.trans, 1.0, 0.0, 'N', 'N')
-- add bias
- output:add_row(self.bp, 1.0)
+ output:add_row(self.bp.trans, 1.0)
end
function nerv.AffineLayer:back_propagate(next_bp_err, bp_err, input, output)
- next_bp_err:mul(bp_err, self.ltp, 'N', 'T')
+ next_bp_err:mul(bp_err, self.ltp.trans, 1.0, 0.0, 'N', 'T')
end