aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/layer/affine.lua18
-rw-r--r--nerv/layer/lstm_gate.lua14
2 files changed, 27 insertions, 5 deletions
diff --git a/nerv/layer/affine.lua b/nerv/layer/affine.lua
index b4358ca..3bf5a11 100644
--- a/nerv/layer/affine.lua
+++ b/nerv/layer/affine.lua
@@ -41,6 +41,9 @@ function MatrixParam:copy(copier)
end
function MatrixParam:_update(alpha, beta)
+ if self.no_update then
+ return
+ end
local gconf = self.gconf
-- momentum gain
local mmt_gain = 1.0 / (1.0 - gconf.momentum)
@@ -97,19 +100,28 @@ function AffineLayer:__init(id, global_conf, layer_conf)
end
function AffineLayer:bind_params()
+ local lconf = self.lconf
+ lconf.no_update_ltp1 = lconf.no_update_ltp1 or lconf.no_update_ltp
for i = 1, #self.dim_in do
local pid = "ltp" .. i
local pid_list = i == 1 and {pid, "ltp"} or pid
- self["ltp" .. i] = self:find_param(pid_list, self.lconf, self.gconf,
+ self["ltp" .. i] = self:find_param(pid_list, lconf, self.gconf,
nerv.LinearTransParam,
{self.dim_in[i], self.dim_out[1]})
+ local no_update = lconf["no_update_ltp"..i]
+ if (no_update ~= nil) and no_update or lconf.no_update_all then
+ self["ltp" .. i].no_update = true
+ end
end
self.ltp = self.ltp1 -- alias of ltp1
- self.bp = self:find_param("bp", self.lconf, self.gconf,
+ self.bp = self:find_param("bp", lconf, self.gconf,
nerv.BiasParam,
{1, self.dim_out[1]},
nerv.Param.gen_zero)
-
+ local no_update = lconf["no_update_bp"]
+ if (no_update ~= nil) and no_update or lconf.no_update_all then
+ self.bp.no_update = true
+ end
end
function AffineLayer:init(batch_size)
diff --git a/nerv/layer/lstm_gate.lua b/nerv/layer/lstm_gate.lua
index a3ae797..99bf3ca 100644
--- a/nerv/layer/lstm_gate.lua
+++ b/nerv/layer/lstm_gate.lua
@@ -9,17 +9,27 @@ function LSTMGateLayer:__init(id, global_conf, layer_conf)
end
function LSTMGateLayer:bind_params()
+ local lconf = self.lconf
+ lconf.no_update_ltp1 = lconf.no_update_ltp1 or lconf.no_update_ltp
for i = 1, #self.dim_in do
- self["ltp" .. i] = self:find_param("ltp" .. i, self.lconf, self.gconf,
+ self["ltp" .. i] = self:find_param("ltp" .. i, lconf, self.gconf,
nerv.LinearTransParam,
{self.dim_in[i], self.dim_out[1]})
if self.param_type[i] == 'D' then
self["ltp" .. i].trans:diagonalize()
end
+ local no_update = lconf["no_update_ltp"..i]
+ if (no_update ~= nil) and no_update or lconf.no_update_all then
+ self["ltp" .. i].no_update = true
+ end
end
- self.bp = self:find_param("bp", self.lconf, self.gconf,
+ self.bp = self:find_param("bp", lconf, self.gconf,
nerv.BiasParam, {1, self.dim_out[1]},
nerv.Param.gen_zero)
+ local no_update = lconf["no_update_bp"]
+ if (no_update ~= nil) and no_update or lconf.no_update_all then
+ self.bp.no_update = true
+ end
end
function LSTMGateLayer:init(batch_size)