aboutsummaryrefslogtreecommitdiff
path: root/layer
diff options
context:
space:
mode:
Diffstat (limited to 'layer')
-rw-r--r--layer/affine.lua22
-rw-r--r--layer/init.lua4
2 files changed, 19 insertions, 7 deletions
diff --git a/layer/affine.lua b/layer/affine.lua
index cd2ba0b..221aacd 100644
--- a/layer/affine.lua
+++ b/layer/affine.lua
@@ -14,23 +14,31 @@ function AffineLayer:__init(id, global_conf, ltp, bp)
self.ltp = ltp
self.bp = bp
self.gconf = global_conf
+end
+
+function AffineLayer:init()
-- linear transform correction
- self.ltc = ltp:create()
+ self.ltc = self.ltp.trans:create()
self.ltc:fill(0)
-- bias correction
- self.bc = bp:create()
+ self.bc = self.bp.trans:create()
self.bc:fill(0)
end
function nerv.AffineLayer:update(bp_err, input, output)
+ local ltp = self.ltp.trans
+ local bp = self.bp.trans
+ local ltc = self.ltc
+ local bc = self.bc
+ local gconf = self.gconf
-- momentum gain
local mmt_gain = 1.0 / (1.0 - gconf.momentum);
- local n = input.nrow() * mmt_gain
+ local n = input:nrow() * mmt_gain
-- update corrections (accumulated errors)
ltc:mul(input, bp_err, 1.0, gconf.momentum, 'T', 'N')
bc:add(bc, bp_err:colsum(), gconf.momentum, 1.0)
-- perform update
- ltp:add(lpc, ltc, 1.0, -gconf.lrate / n)
+ ltp:add(ltp, ltc, 1.0, -gconf.lrate / n)
bp:add(bp, bc, 1.0, -gconf.lrate / n)
-- weight decay
ltp:add(ltp, ltp, 1.0, -gconf.lrate * gconf.wcost)
@@ -38,11 +46,11 @@ end
function nerv.AffineLayer:propagate(input, output)
-- apply linear transform
- output:mul(input, self.ltp, 'N', 'N')
+ output:mul(input, self.ltp.trans, 1.0, 0.0, 'N', 'N')
-- add bias
- output:add_row(self.bp, 1.0)
+ output:add_row(self.bp.trans, 1.0)
end
function nerv.AffineLayer:back_propagate(next_bp_err, bp_err, input, output)
- next_bp_err:mul(bp_err, self.ltp, 'N', 'T')
+ next_bp_err:mul(bp_err, self.ltp.trans, 1.0, 0.0, 'N', 'T')
end
diff --git a/layer/init.lua b/layer/init.lua
index 6923dbd..0f0afe8 100644
--- a/layer/init.lua
+++ b/layer/init.lua
@@ -6,6 +6,10 @@ function nerv.Param:__init(id)
self.id = id
end
+function nerv.Param:init(id)
+ nerv.error_method_not_implemented()
+end
+
function nerv.Param:get_info()
return self.info
end