-- nerv.GRULayer: a gated recurrent unit layer derived from nerv.Layer.
-- It is implemented as a wrapper around an internal nerv.DAGLayer (self.dag)
-- assembled in __init; most of the other methods simply delegate to it.
local GRULayer = nerv.class('nerv.GRULayer', 'nerv.Layer')
--- Build a GRU layer as a DAG of nerv sub-layers.
-- Ports (enforced by check_dim_len(2, 1)): two inputs, one output.
--   input[1]:  x, the layer input (width dim_in[1])
--   input[2]:  h, the recurrent state (width dim_in[2]); presumably the
--              previous step's output fed back -- confirm with the caller
--   output[1]: the updated state (width dim_out[1], must equal dim_in[2])
-- The wiring below computes (r, z come from nerv.LSTMGateLayer, presumably
-- an affine transform followed by a sigmoid -- verify):
--   r  = resetGate(x, h)
--   z  = updateGate(x, h)
--   c  = tanh(Affine(x, r .* h))
--   h' = h - z .* h + z .* c   ( == (1 - z) .* h + z .* c )
-- @param id          layer id; also the prefix of every sub-layer id
-- @param global_conf global configuration, forwarded to the sub-layers
-- @param layer_conf  needs dim_in = {dx, dh} and dim_out = {dh}; an optional
--                    `pr` (nerv.ParamRepo) supplies the parameters
function GRULayer:__init(id, global_conf, layer_conf)
    nerv.Layer.__init(self, id, global_conf, layer_conf)
    -- Validate the port counts first: with too few entries dim_in[2] would be
    -- nil and the mismatch error below would itself fail while formatting %d.
    self:check_dim_len(2, 1)
    if self.dim_in[2] ~= self.dim_out[1] then
        nerv.error("dim_in[2](%d) mismatch with dim_out[1](%d)",
                   self.dim_in[2], self.dim_out[1])
    end
    -- Parameter repo shared by the affine and gate sub-layers; an empty one
    -- is created when the configuration does not supply it.
    local pr = layer_conf.pr
    if pr == nil then
        pr = nerv.ParamRepo({}, self.loc_type)
    end
    -- Namespace every sub-layer id under this layer's id.
    local function ap(str)
        return self.id .. '.' .. str
    end
    local din1, din2 = self.dim_in[1], self.dim_in[2]
    local dout1 = self.dim_out[1]
    -- Sub-layers of the GRU, grouped by layer class.
    local layers = {
        ["nerv.CombinerLayer"] = {
            -- fan x out to: reset gate, update gate, candidate affine
            [ap("inputXDup")] = {{}, {dim_in = {din1},
                                      dim_out = {din1, din1, din1},
                                      lambda = {1}}},
            -- fan h out to: reset gate, update gate, r .* h, z .* h, merge
            [ap("inputHDup")] = {{}, {dim_in = {din2},
                                      dim_out = {din2, din2, din2, din2, din2},
                                      lambda = {1}}},
            -- fan z out to the two update products
            [ap("updateGDup")] = {{}, {dim_in = {din2},
                                       dim_out = {din2, din2},
                                       lambda = {1}}},
            -- h' = 1*h + (-1)*(z .* h) + 1*(z .* c), assuming CombinerLayer
            -- forms the lambda-weighted sum of its inputs
            [ap("updateMergeL")] = {{}, {dim_in = {din2, din2, din2},
                                         dim_out = {dout1},
                                         lambda = {1, -1, 1}}},
        },
        ["nerv.AffineLayer"] = {
            -- affine part of the candidate state: Affine(x, r .* h)
            [ap("mainAffineL")] = {{}, {dim_in = {din1, din2},
                                        dim_out = {dout1},
                                        pr = pr}},
        },
        ["nerv.TanhLayer"] = {
            -- candidate state c = tanh(...)
            [ap("mainTanhL")] = {{}, {dim_in = {dout1}, dim_out = {dout1}}},
        },
        ["nerv.LSTMGateLayer"] = {
            [ap("resetGateL")] = {{}, {dim_in = {din1, din2},
                                       dim_out = {din2},
                                       pr = pr}},
            [ap("updateGateL")] = {{}, {dim_in = {din1, din2},
                                        dim_out = {din2},
                                        pr = pr}},
        },
        ["nerv.ElemMulLayer"] = {
            [ap("resetGMulL")] = {{}, {dim_in = {din2, din2}, dim_out = {din2}}},
            [ap("updateGMulCL")] = {{}, {dim_in = {din2, din2}, dim_out = {din2}}},
            [ap("updateGMulHL")] = {{}, {dim_in = {din2, din2}, dim_out = {din2}}},
        },
    }
    self.lrepo = nerv.LayerRepo(layers, pr, global_conf)
    -- Edges: key is a producer's output port, value a consumer's input port.
    local connections = {
        -- external ports
        ["<input>[1]"] = ap("inputXDup[1]"),
        ["<input>[2]"] = ap("inputHDup[1]"),
        -- reset gate r = resetGate(x, h)
        [ap("inputXDup[1]")] = ap("resetGateL[1]"),
        [ap("inputHDup[1]")] = ap("resetGateL[2]"),
        -- update gate z = updateGate(x, h), duplicated for two products
        [ap("inputXDup[2]")] = ap("updateGateL[1]"),
        [ap("inputHDup[2]")] = ap("updateGateL[2]"),
        [ap("updateGateL[1]")] = ap("updateGDup[1]"),
        -- candidate c = tanh(Affine(x, r .* h))
        [ap("resetGateL[1]")] = ap("resetGMulL[1]"),
        [ap("inputHDup[3]")] = ap("resetGMulL[2]"),
        [ap("inputXDup[3]")] = ap("mainAffineL[1]"),
        [ap("resetGMulL[1]")] = ap("mainAffineL[2]"),
        [ap("mainAffineL[1]")] = ap("mainTanhL[1]"),
        -- z .* h and z .* c
        [ap("updateGDup[1]")] = ap("updateGMulHL[1]"),
        [ap("inputHDup[4]")] = ap("updateGMulHL[2]"),
        [ap("updateGDup[2]")] = ap("updateGMulCL[1]"),
        [ap("mainTanhL[1]")] = ap("updateGMulCL[2]"),
        -- h' = h - z .* h + z .* c
        [ap("inputHDup[5]")] = ap("updateMergeL[1]"),
        [ap("updateGMulHL[1]")] = ap("updateMergeL[2]"),
        [ap("updateGMulCL[1]")] = ap("updateMergeL[3]"),
        [ap("updateMergeL[1]")] = "<output>[1]",
    }
    self.dag = nerv.DAGLayer(self.id, global_conf,
                             {dim_in = self.dim_in,
                              dim_out = self.dim_out,
                              sub_layers = self.lrepo,
                              connections = connections})
end
--- Rebind the sub-layers' parameters to the ParamRepo given as `pr` in the
-- layer configuration, or to a fresh empty repo when none was supplied
-- (mirroring how __init chooses its repo).
function GRULayer:bind_params()
    -- `layer_conf` is only a parameter of __init, not a global; read the
    -- stored configuration instead. self.lconf is assumed to be the
    -- layer_conf kept by nerv.Layer.__init -- confirm against the base class.
    local lconf = self.lconf or {}
    local pr = lconf.pr
    if pr == nil then
        pr = nerv.ParamRepo({}, self.loc_type)
    end
    self.lrepo:rebind(pr)
end
--- Initialize the layer for the given batch and chunk sizes.
-- All work is done by the wrapped DAG (self.dag).
function GRULayer:init(batch_size, chunk_size)
    local dag = self.dag
    dag:init(batch_size, chunk_size)
end
--- Resize for a new batch/chunk size by forwarding to the wrapped DAG.
function GRULayer:batch_resize(batch_size, chunk_size)
    local dag = self.dag
    dag:batch_resize(batch_size, chunk_size)
end
--- Parameter update step; the wrapped DAG receives the arguments unchanged.
function GRULayer:update(bp_err, input, output, t)
    local dag = self.dag
    dag:update(bp_err, input, output, t)
end
--- Forward pass, computed entirely by the wrapped DAG.
function GRULayer:propagate(input, output, t)
    local dag = self.dag
    dag:propagate(input, output, t)
end
--- Backward pass, computed entirely by the wrapped DAG.
function GRULayer:back_propagate(bp_err, next_bp_err, input, output, t)
    local dag = self.dag
    dag:back_propagate(bp_err, next_bp_err, input, output, t)
end
--- Return whatever parameter collection the wrapped DAG reports.
function GRULayer:get_params()
    local dag = self.dag
    return dag:get_params()
end