path: root/nerv/layer/lstm_gate.lua
author    Qi Liu <liuq901@163.com>    2016-02-24 16:16:03 +0800
committer Qi Liu <liuq901@163.com>    2016-02-24 16:16:03 +0800
commit    a51498d2761714f4e034f036b84b4b89a88e9539 (patch)
tree      ff78fabd169c3c0346453c7005c84b176ac49ca6 /nerv/layer/lstm_gate.lua
parent    9642bd16922b288c81dee25f17373466ae6888c4 (diff)
update LSTM layer
Diffstat (limited to 'nerv/layer/lstm_gate.lua')
-rw-r--r--  nerv/layer/lstm_gate.lua  7
1 file changed, 7 insertions, 0 deletions
diff --git a/nerv/layer/lstm_gate.lua b/nerv/layer/lstm_gate.lua
index 1963eba..8785b4f 100644
--- a/nerv/layer/lstm_gate.lua
+++ b/nerv/layer/lstm_gate.lua
@@ -5,12 +5,16 @@ function LSTMGateLayer:__init(id, global_conf, layer_conf)
self.id = id
self.dim_in = layer_conf.dim_in
self.dim_out = layer_conf.dim_out
+ self.param_type = layer_conf.param_type
self.gconf = global_conf
for i = 1, #self.dim_in do
self["ltp" .. i] = self:find_param("ltp" .. i, layer_conf, global_conf,
nerv.LinearTransParam,
{self.dim_in[i], self.dim_out[1]})
+ if self.param_type[i] == 'D' then
+ self["ltp" .. i].trans:diagonalize()
+ end
end
self.bp = self:find_param("bp", layer_conf, global_conf,
nerv.BiasParam, {1, self.dim_out[1]})
@@ -64,6 +68,9 @@ function LSTMGateLayer:update(bp_err, input, output)
self.err_bakm:sigmoid_grad(bp_err[1], output[1])
for i = 1, #self.dim_in do
self["ltp" .. i]:update_by_err_input(self.err_bakm, input[i])
+ if self.param_type[i] == 'D' then
+ self["ltp" .. i].trans:diagonalize()
+ end
end
self.bp:update_by_gradient(self.err_bakm:colsum())
end
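
What the patch does: each input stream i of an LSTMGateLayer now carries a
param_type flag; when the flag is 'D', the corresponding linear transform
ltp<i> is forced to stay diagonal, both at construction and again after every
gradient update (a plain gradient step would otherwise reintroduce
off-diagonal entries). A diagonal transform acts element-wise, which is the
usual way to model peephole connections from the cell state into a gate.
Below is a minimal sketch in plain Lua of what trans:diagonalize() presumably
does; this is an assumption, since nerv implements it as a method on its own
matrix type, and the nested-table matrix here is only a stand-in.

-- Sketch (assumption): diagonalize() zeroes all off-diagonal entries
-- in place, keeping only the element-wise (diagonal) weights.
local function diagonalize(mat)
    for i = 1, #mat do
        for j = 1, #mat[i] do
            if i ~= j then
                mat[i][j] = 0
            end
        end
    end
    return mat
end

-- Usage: a 3x3 transform collapses to its diagonal.
local m = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}
diagonalize(m) -- m is now {{1, 0, 0}, {0, 5, 0}, {0, 0, 9}}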
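
For context, a hedged configuration sketch of how the new flag would be set.
Everything here except the param_type check shown in the diff is an
assumption for illustration: the sizes, the 'N' value for dense transforms
(the code only tests for 'D'), and the idea that the third input is a
peephole from the cell state are not taken from this commit.

-- Hypothetical sizes; gconf is the framework's global_conf table.
local input_size, hidden_size, cell_size = 430, 300, 300

-- param_type[i] = 'D' marks input i whose ltp<i> transform is kept
-- diagonal (e.g. a peephole from the cell state); any other value
-- (here 'N', assumed) leaves the transform dense.
local gate = nerv.LSTMGateLayer('forgetGate', gconf, {
    dim_in  = {input_size, hidden_size, cell_size}, -- x, h, c (assumed)
    dim_out = {cell_size},
    param_type = {'N', 'N', 'D'}, -- only the cell-state input is diagonal
})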