--- LSTM cell implemented as a composite layer: the cell's sub-layers are
-- created in a LayerRepo and wired together inside an internal nerv.DAGLayer;
-- every runtime method simply delegates to that DAG.
--
-- Inputs : [1] x (dim_in[1]), [2] previous h (dim_in[2]), [3] previous c (dim_in[3])
-- Outputs: [1] new h (taken from outputGMulL), [2] new c (taken from mainCDup)
local LSTMLayer = nerv.class('nerv.LSTMLayer', 'nerv.Layer')

-- Return the ParamRepo given in layer_conf.pr, or a fresh empty ParamRepo
-- (created with loc_type) when the config does not provide one.
local function resolve_param_repo(layer_conf, loc_type)
    local pr = layer_conf.pr
    if pr == nil then
        pr = nerv.ParamRepo({}, loc_type)
    end
    return pr
end

--- Build the LSTM sub-layers and the DAG connecting them.
-- @param id          layer identifier; also used as a prefix for sub-layer ids
-- @param global_conf global configuration forwarded to every sub-layer
-- @param layer_conf  layer configuration (dim_in, dim_out, optional pr)
function LSTMLayer:__init(id, global_conf, layer_conf)
    -- input1:x -- input2:h -- input3:c
    nerv.Layer.__init(self, id, global_conf, layer_conf)
    -- Keep the constructor conf so bind_params can resolve the same ParamRepo
    -- later; the original bind_params referenced an out-of-scope `layer_conf`.
    self.lconf = layer_conf
    -- Validate arity before any dim_in/dim_out entry is read below.
    self:check_dim_len(3, 2) -- x, h, c and h, c

    -- prepare a DAGLayer to hold the lstm structure
    local pr = resolve_param_repo(layer_conf, self.loc_type)

    -- Prefix sub-layer names with this layer's id so they are unique per LSTM.
    local function ap(str)
        return self.id .. '.' .. str
    end

    local din1, din2, din3 = self.dim_in[1], self.dim_in[2], self.dim_in[3]
    -- NOTE(review): dout2 and dout3 are never used; the wiring below also
    -- appears to assume dout1 == din3 (tanh outputs feed din3-sized
    -- multipliers) -- confirm against the intended configs.
    local dout1, dout2, dout3 = self.dim_out[1], self.dim_out[2], self.dim_out[3]

    local layers = {
        -- Fan-out / sum nodes. mainCDup takes two inputs with lambda {1, 1};
        -- presumably a weighted sum producing the new cell state -- confirm
        -- against nerv.CombinerLayer.
        ["nerv.CombinerLayer"] = {
            [ap("inputXDup")] = {dim_in = {din1}, dim_out = {din1, din1, din1, din1}, lambda = {1}},
            [ap("inputHDup")] = {dim_in = {din2}, dim_out = {din2, din2, din2, din2}, lambda = {1}},
            [ap("inputCDup")] = {dim_in = {din3}, dim_out = {din3, din3, din3}, lambda = {1}},
            [ap("mainCDup")] = {dim_in = {din3, din3}, dim_out = {din3, din3, din3}, lambda = {1, 1}},
        },
        ["nerv.AffineLayer"] = {
            [ap("mainAffineL")] = {dim_in = {din1, din2}, dim_out = {dout1}, pr = pr},
        },
        ["nerv.TanhLayer"] = {
            [ap("mainTanhL")] = {dim_in = {dout1}, dim_out = {dout1}},
            [ap("outputTanhL")] = {dim_in = {dout1}, dim_out = {dout1}},
        },
        -- Gates each see x, h and a cell state (peephole-style third input).
        ["nerv.LSTMGateLayer"] = {
            [ap("forgetGateL")] = {dim_in = {din1, din2, din3}, dim_out = {din3}, pr = pr},
            [ap("inputGateL")] = {dim_in = {din1, din2, din3}, dim_out = {din3}, pr = pr},
            [ap("outputGateL")] = {dim_in = {din1, din2, din3}, dim_out = {din3}, pr = pr},
        },
        ["nerv.ElemMulLayer"] = {
            [ap("inputGMulL")] = {dim_in = {din3, din3}, dim_out = {din3}},
            [ap("forgetGMulL")] = {dim_in = {din3, din3}, dim_out = {din3}},
            [ap("outputGMulL")] = {dim_in = {din3, din3}, dim_out = {din3}},
        },
    }

    self.lrepo = nerv.LayerRepo(layers, pr, global_conf)

    -- "[k]" on either side denotes the DAG's own k-th input/output port.
    local connections = {
        ["[1]"] = ap("inputXDup[1]"),
        ["[2]"] = ap("inputHDup[1]"),
        ["[3]"] = ap("inputCDup[1]"),

        -- candidate: tanh(affine(x, h))
        [ap("inputXDup[1]")] = ap("mainAffineL[1]"),
        [ap("inputHDup[1]")] = ap("mainAffineL[2]"),
        [ap("mainAffineL[1]")] = ap("mainTanhL[1]"),

        -- input gate from (x, h, previous c)
        [ap("inputXDup[2]")] = ap("inputGateL[1]"),
        [ap("inputHDup[2]")] = ap("inputGateL[2]"),
        [ap("inputCDup[1]")] = ap("inputGateL[3]"),

        -- forget gate from (x, h, previous c)
        [ap("inputXDup[3]")] = ap("forgetGateL[1]"),
        [ap("inputHDup[3]")] = ap("forgetGateL[2]"),
        [ap("inputCDup[2]")] = ap("forgetGateL[3]"),

        -- new c combines candidate * input gate and previous c * forget gate
        [ap("mainTanhL[1]")] = ap("inputGMulL[1]"),
        [ap("inputGateL[1]")] = ap("inputGMulL[2]"),
        [ap("inputCDup[3]")] = ap("forgetGMulL[1]"),
        [ap("forgetGateL[1]")] = ap("forgetGMulL[2]"),
        [ap("inputGMulL[1]")] = ap("mainCDup[1]"),
        [ap("forgetGMulL[1]")] = ap("mainCDup[2]"),

        -- output gate sees (x, h, new c)
        [ap("inputXDup[4]")] = ap("outputGateL[1]"),
        [ap("inputHDup[4]")] = ap("outputGateL[2]"),
        [ap("mainCDup[3]")] = ap("outputGateL[3]"),

        -- new c is emitted directly as DAG output [2]
        [ap("mainCDup[2]")] = "[2]",

        -- new h: tanh(new c) * output gate, emitted as DAG output [1]
        [ap("mainCDup[1]")] = ap("outputTanhL[1]"),
        [ap("outputTanhL[1]")] = ap("outputGMulL[1]"),
        [ap("outputGateL[1]")] = ap("outputGMulL[2]"),
        [ap("outputGMulL[1]")] = "[1]",
    }

    self.dag = nerv.DAGLayer(self.id, global_conf,
        {dim_in = self.dim_in, dim_out = self.dim_out,
         sub_layers = self.lrepo, connections = connections})
end

--- Rebind the sub-layers to the ParamRepo from the constructor's layer_conf
-- (or to a fresh empty ParamRepo when layer_conf.pr was not given).
function LSTMLayer:bind_params()
    local pr = resolve_param_repo(self.lconf, self.loc_type)
    self.lrepo:rebind(pr)
end

--- Delegate to the internal DAG.
function LSTMLayer:init(batch_size, chunk_size)
    self.dag:init(batch_size, chunk_size)
end

--- Delegate to the internal DAG.
function LSTMLayer:batch_resize(batch_size, chunk_size)
    self.dag:batch_resize(batch_size, chunk_size)
end

--- Delegate to the internal DAG.
function LSTMLayer:update(bp_err, input, output, t)
    self.dag:update(bp_err, input, output, t)
end

--- Delegate to the internal DAG.
function LSTMLayer:propagate(input, output, t)
    self.dag:propagate(input, output, t)
end

--- Delegate to the internal DAG.
function LSTMLayer:back_propagate(bp_err, next_bp_err, input, output, t)
    self.dag:back_propagate(bp_err, next_bp_err, input, output, t)
end

--- Return whatever the internal DAG reports as its parameters.
function LSTMLayer:get_params()
    return self.dag:get_params()
end