aboutsummaryrefslogtreecommitdiff
path: root/nerv/layer
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/layer')
-rw-r--r--nerv/layer/lstm.lua6
-rw-r--r--nerv/layer/lstm_gate.lua7
2 files changed, 10 insertions, 3 deletions
diff --git a/nerv/layer/lstm.lua b/nerv/layer/lstm.lua
index 5dbcc20..56f674a 100644
--- a/nerv/layer/lstm.lua
+++ b/nerv/layer/lstm.lua
@@ -29,9 +29,9 @@ function LSTMLayer:__init(id, global_conf, layer_conf)
outputTanh = {dim_in = {dout}, dim_out = {dout}},
},
['nerv.LSTMGateLayer'] = {
- forgetGate = {dim_in = {din, dout, dout}, dim_out = {dout}, pr = pr},
- inputGate = {dim_in = {din, dout, dout}, dim_out = {dout}, pr = pr},
- outputGate = {dim_in = {din, dout, dout}, dim_out = {dout}, pr = pr},
+ forgetGate = {dim_in = {din, dout, dout}, dim_out = {dout}, param_type = {'N', 'N', 'D'}, pr = pr},
+ inputGate = {dim_in = {din, dout, dout}, dim_out = {dout}, param_type = {'N', 'N', 'D'}, pr = pr},
+ outputGate = {dim_in = {din, dout, dout}, dim_out = {dout}, param_type = {'N', 'N', 'D'}, pr = pr},
},
['nerv.ElemMulLayer'] = {
inputGateMul = {dim_in = {dout, dout}, dim_out = {dout}},
diff --git a/nerv/layer/lstm_gate.lua b/nerv/layer/lstm_gate.lua
index 7a27bab..e690721 100644
--- a/nerv/layer/lstm_gate.lua
+++ b/nerv/layer/lstm_gate.lua
@@ -3,6 +3,7 @@ local LSTMGateLayer = nerv.class('nerv.LSTMGateLayer', 'nerv.Layer')
function LSTMGateLayer:__init(id, global_conf, layer_conf)
nerv.Layer.__init(self, id, global_conf, layer_conf)
+ self.param_type = layer_conf.param_type
self:check_dim_len(-1, 1) --accept multiple inputs
self:bind_params()
end
@@ -12,6 +13,9 @@ function LSTMGateLayer:bind_params()
self["ltp" .. i] = self:find_param("ltp" .. i, self.lconf, self.gconf,
nerv.LinearTransParam,
{self.dim_in[i], self.dim_out[1]})
+ if self.param_type[i] == 'D' then
+ self["ltp" .. i].trans:diagonalize()
+ end
end
self.bp = self:find_param("bp", self.lconf, self.gconf,
nerv.BiasParam, {1, self.dim_out[1]})
@@ -63,6 +67,9 @@ function LSTMGateLayer:update(bp_err, input, output)
self.err_bakm:sigmoid_grad(bp_err[1], output[1])
for i = 1, #self.dim_in do
self["ltp" .. i]:update_by_err_input(self.err_bakm, input[i])
+ if self.param_type[i] == 'D' then
+ self["ltp" .. i].trans:diagonalize()
+ end
end
self.bp:update_by_gradient(self.err_bakm:colsum())
end