diff options
Diffstat (limited to 'nerv/layer/lstm_gate.lua')
-rw-r--r-- | nerv/layer/lstm_gate.lua | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/nerv/layer/lstm_gate.lua b/nerv/layer/lstm_gate.lua index a3ae797..99bf3ca 100644 --- a/nerv/layer/lstm_gate.lua +++ b/nerv/layer/lstm_gate.lua @@ -9,17 +9,27 @@ function LSTMGateLayer:__init(id, global_conf, layer_conf) end function LSTMGateLayer:bind_params() + local lconf = self.lconf + lconf.no_update_ltp1 = lconf.no_update_ltp1 or lconf.no_update_ltp for i = 1, #self.dim_in do - self["ltp" .. i] = self:find_param("ltp" .. i, self.lconf, self.gconf, + self["ltp" .. i] = self:find_param("ltp" .. i, lconf, self.gconf, nerv.LinearTransParam, {self.dim_in[i], self.dim_out[1]}) if self.param_type[i] == 'D' then self["ltp" .. i].trans:diagonalize() end + local no_update = lconf["no_update_ltp"..i] + if (no_update ~= nil) and no_update or lconf.no_update_all then + self["ltp" .. i].no_update = true + end end - self.bp = self:find_param("bp", self.lconf, self.gconf, + self.bp = self:find_param("bp", lconf, self.gconf, nerv.BiasParam, {1, self.dim_out[1]}, nerv.Param.gen_zero) + local no_update = lconf["no_update_bp"] + if (no_update ~= nil) and no_update or lconf.no_update_all then + self.bp.no_update = true + end end function LSTMGateLayer:init(batch_size) |