author     Determinant <ted.sybil@gmail.com>  2016-03-10 13:40:11 +0800
committer  Determinant <ted.sybil@gmail.com>  2016-03-10 13:40:11 +0800
commit     a32195e3e2ae9ca0f0c7a82e73e6bddb64568c05 (patch)
tree       a19f21f8cbadecff7357f9a102f160f5fe699b65 /nerv/layer/lstm_gate.lua
parent     4a6872601f05e9ecc059f83fb64a0a4887992b99 (diff)
major change: clearer param binding semantics; permit rebinding; enable
resuming from previous training
Diffstat (limited to 'nerv/layer/lstm_gate.lua')
-rw-r--r--  nerv/layer/lstm_gate.lua  17
1 file changed, 8 insertions(+), 9 deletions(-)
diff --git a/nerv/layer/lstm_gate.lua b/nerv/layer/lstm_gate.lua
index 1963eba..7a27bab 100644
--- a/nerv/layer/lstm_gate.lua
+++ b/nerv/layer/lstm_gate.lua
@@ -2,20 +2,19 @@ local LSTMGateLayer = nerv.class('nerv.LSTMGateLayer', 'nerv.Layer')
 -- NOTE: this is a full matrix gate
 function LSTMGateLayer:__init(id, global_conf, layer_conf)
-    self.id = id
-    self.dim_in = layer_conf.dim_in
-    self.dim_out = layer_conf.dim_out
-    self.gconf = global_conf
+    nerv.Layer.__init(self, id, global_conf, layer_conf)
+    self:check_dim_len(-1, 1) --accept multiple inputs
+    self:bind_params()
+end
 
+function LSTMGateLayer:bind_params()
     for i = 1, #self.dim_in do
-        self["ltp" .. i] = self:find_param("ltp" .. i, layer_conf, global_conf,
+        self["ltp" .. i] = self:find_param("ltp" .. i, self.lconf, self.gconf,
                                            nerv.LinearTransParam,
                                            {self.dim_in[i], self.dim_out[1]})
     end
-    self.bp = self:find_param("bp", layer_conf, global_conf,
+    self.bp = self:find_param("bp", self.lconf, self.gconf,
                               nerv.BiasParam, {1, self.dim_out[1]})
-
-    self:check_dim_len(-1, 1) --accept multiple inputs
 end
 
 function LSTMGateLayer:init(batch_size)
@@ -69,7 +68,7 @@ function LSTMGateLayer:update(bp_err, input, output)
 end
 
 function LSTMGateLayer:get_params()
-    local pr = nerv.ParamRepo({self.bp})
+    local pr = nerv.ParamRepo({self.bp}, self.loc_type)
     for i = 1, #self.dim_in do
         pr:add(self["ltp" .. i].id, self["ltp" .. i])
     end
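
The point of splitting bind_params() out of __init() is that parameter lookup
now goes through self.lconf / self.gconf, so the method can be invoked again
after construction to rebind the layer to a different parameter set, which is
what enables resuming from a previous training run. A minimal sketch of that
usage follows; the lconf.pr field and the load_saved_params helper are
illustrative assumptions, not confirmed nerv API:

    -- Sketch only: assumes layer_conf carries a param repo in its pr field
    -- and that load_saved_params (hypothetical) returns a nerv.ParamRepo
    -- holding the checkpointed 'ltp1', 'ltp2' and 'bp' parameters.
    local gate = nerv.LSTMGateLayer('gate1', gconf,
                                    {dim_in = {256, 256}, dim_out = {256}})

    -- Resume: point the layer config at the saved repo and rebind, so that
    -- find_param() resolves to checkpointed values instead of fresh ones.
    gate.lconf.pr = load_saved_params('checkpoint_iter10.nerv')
    gate:bind_params()

Note that __init() already calls bind_params() once with the initially
configured repo, so the resume path simply repeats the same binding step
against the saved parameters.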