-- Declare the 'nerv.LSTMLayer' class through nerv.class. 'nerv.Layer' is
-- passed as the second argument (presumably the base class, since the
-- constructor delegates to nerv.Layer.__init -- confirm against nerv.class).
local LSTMLayer = nerv.class('nerv.LSTMLayer', 'nerv.Layer')
function LSTMLayer:__init(id, global_conf, layer_conf)
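-- input port [1]: x (wired to the inputXDup combiner)
-- input port [2]: h (wired to the inputHDup combiner)
-- input port [3]: c (wired to the inputCDup combiner)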
nerv.Layer.__init(self, id, global_conf, layer_conf)
-- prepare a DAGLayer to hold the lstm structure
local pr = layer_conf.pr
if pr == nil then
pr = nerv.ParamRepo({}, self.loc_type)
end
local function ap(str)
return self.id .. '.' .. str
end
local din1, din2, din3 = self.dim_in[1], self.dim_in[2], self.dim_in[3]
local dout1, dout2, dout3 = self.dim_out[1], self.dim_out[2], self.dim_out[3]
local layers = {
["nerv.CombinerLayer"] = {
[ap("inputXDup")] = {dim_in = {din1},
dim_out = {din1, din1, din1, din1},
lambda = {1}},
[ap("inputHDup")] = {dim_in = {din2},
dim_out = {din2, din2, din2, din2},
lambda = {1}},
[ap("inputCDup")] = {dim_in = {din3},
dim_out = {din3, din3, din3},
lambda = {1}},
[ap("mainCDup")] = {dim_in = {din3, din3},
dim_out = {din3, din3, din3},
lambda = {1, 1}},
},
["nerv.AffineLayer"] = {
[ap("mainAffineL")] = {dim_in = {din1, din2},
dim_out = {dout1},
pr = pr},
},
["nerv.TanhLayer"] = {
[ap("mainTanhL")] = {dim_in = {dout1}, dim_out = {dout1}},
[ap("outputTanhL")] = {dim_in = {dout1}, dim_out = {dout1}},
},
["nerv.LSTMGateLayer"] = {
[ap("forgetGateL")] = {dim_in = {din1, din2, din3},
dim_out = {din3}, pr = pr},
[ap("inputGateL")] = {dim_in = {din1, din2, din3},
dim_out = {din3}, pr = pr},
[ap("outputGateL")] = {dim_in = {din1, din2, din3},
dim_out = {din3}, pr = pr},
},
["nerv.ElemMulLayer"] = {
[ap("inputGMulL")] = {dim_in = {din3, din3},
dim_out = {din3}},
[ap("forgetGMulL")] = {dim_in = {din3, din3},
dim_out = {din3}},
[ap("outputGMulL")] = {dim_in = {din3, din3},
dim_out = {din3}},
},
}
self.lrepo = nerv.LayerRepo(layers, pr, global_conf)
local connections = {
["[1]"] = ap("inputXDup[1]"),
["[2]"] = ap("inputHDup[1]"),
["[3]"] = ap("inputCDup[1]"),
[ap("inputXDup[1]")] = ap("mainAffineL[1]"),
[ap("inputHDup[1]")] = ap("mainAffineL[2]"),
[ap("mainAffineL[1]")] = ap("mainTanhL[1]"),
[ap("inputXDup[2]")] = ap("inputGateL[1]"),
[ap("inputHDup[2]")] = ap("inputGateL[2]"),
[ap("inputCDup[1]")] = ap("inputGateL[3]"),
[ap("inputXDup[3]")] = ap("forgetGateL[1]"),
[ap("inputHDup[3]")] = ap("forgetGateL[2]"),
[ap("inputCDup[2]")] = ap("forgetGateL[3]"),
[ap("mainTanhL[1]")] = ap("inputGMulL[1]"),
[ap("inputGateL[1]")] = ap("inputGMulL[2]"),
[ap("inputCDup[3]")] = ap("forgetGMulL[1]"),
[ap("forgetGateL[1]")] = ap("forgetGMulL[2]"),
[ap("inputGMulL[1]")] = ap("mainCDup[1]"),
[ap("forgetGMulL[1]")] = ap("mainCDup[2]"),
[ap("inputXDup[4]")] = ap("outputGateL[1]"),
[ap("inputHDup[4]")] = ap("outputGateL[2]"),
[ap("mainCDup[3]")] = ap("outputGateL[3]"),
[ap("mainCDup[2]")] = "