From a32195e3e2ae9ca0f0c7a82e73e6bddb64568c05 Mon Sep 17 00:00:00 2001 From: Determinant Date: Thu, 10 Mar 2016 13:40:11 +0800 Subject: major change: clearer param binding semantics; permit rebinding; enable resuming from previous training --- nerv/layer/lstm_gate.lua | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) (limited to 'nerv/layer/lstm_gate.lua') diff --git a/nerv/layer/lstm_gate.lua b/nerv/layer/lstm_gate.lua index 1963eba..7a27bab 100644 --- a/nerv/layer/lstm_gate.lua +++ b/nerv/layer/lstm_gate.lua @@ -2,20 +2,19 @@ local LSTMGateLayer = nerv.class('nerv.LSTMGateLayer', 'nerv.Layer') -- NOTE: this is a full matrix gate function LSTMGateLayer:__init(id, global_conf, layer_conf) - self.id = id - self.dim_in = layer_conf.dim_in - self.dim_out = layer_conf.dim_out - self.gconf = global_conf + nerv.Layer.__init(self, id, global_conf, layer_conf) + self:check_dim_len(-1, 1) --accept multiple inputs + self:bind_params() +end +function LSTMGateLayer:bind_params() for i = 1, #self.dim_in do - self["ltp" .. i] = self:find_param("ltp" .. i, layer_conf, global_conf, + self["ltp" .. i] = self:find_param("ltp" .. i, self.lconf, self.gconf, nerv.LinearTransParam, {self.dim_in[i], self.dim_out[1]}) end - self.bp = self:find_param("bp", layer_conf, global_conf, + self.bp = self:find_param("bp", self.lconf, self.gconf, nerv.BiasParam, {1, self.dim_out[1]}) - - self:check_dim_len(-1, 1) --accept multiple inputs end function LSTMGateLayer:init(batch_size) @@ -69,7 +68,7 @@ function LSTMGateLayer:update(bp_err, input, output) end function LSTMGateLayer:get_params() - local pr = nerv.ParamRepo({self.bp}) + local pr = nerv.ParamRepo({self.bp}, self.loc_type) for i = 1, #self.dim_in do pr:add(self["ltp" .. i].id, self["ltp" .. i]) end -- cgit v1.2.3