summaryrefslogtreecommitdiff
path: root/nerv/layer/gru.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/layer/gru.lua')
-rw-r--r--nerv/layer/gru.lua128
1 files changed, 128 insertions, 0 deletions
diff --git a/nerv/layer/gru.lua b/nerv/layer/gru.lua
new file mode 100644
index 0000000..2162e28
--- /dev/null
+++ b/nerv/layer/gru.lua
@@ -0,0 +1,128 @@
+local GRULayer = nerv.class('nerv.GRULayer', 'nerv.Layer')
+
+function GRULayer:__init(id, global_conf, layer_conf)
+ -- input1:x
+ -- input2:h
+ -- input3:c (h^~)
+ self.id = id
+ self.dim_in = layer_conf.dim_in
+ self.dim_out = layer_conf.dim_out
+ self.gconf = global_conf
+
+ if self.dim_in[2] ~= self.dim_out[1] then
+ nerv.error("dim_in[2](%d) mismatch with dim_out[1](%d)",
+ self.dim_in[2], self.dim_out[1])
+ end
+
+ -- prepare a DAGLayer to hold the lstm structure
+ local pr = layer_conf.pr
+ if pr == nil then
+ pr = nerv.ParamRepo()
+ end
+
+ local function ap(str)
+ return self.id .. '.' .. str
+ end
+ local din1, din2 = self.dim_in[1], self.dim_in[2]
+ local dout1 = self.dim_out[1]
+ local layers = {
+ ["nerv.CombinerLayer"] = {
+ [ap("inputXDup")] = {{}, {dim_in = {din1},
+ dim_out = {din1, din1, din1},
+ lambda = {1}}},
+ [ap("inputHDup")] = {{}, {dim_in = {din2},
+ dim_out = {din2, din2, din2, din2, din2},
+ lambda = {1}}},
+ [ap("updateGDup")] = {{}, {dim_in = {din2},
+ dim_out = {din2, din2},
+ lambda = {1}}},
+ [ap("updateMergeL")] = {{}, {dim_in = {din2, din2, din2},
+ dim_out = {dout1},
+ lambda = {1, -1, 1}}},
+ },
+ ["nerv.AffineLayer"] = {
+ [ap("mainAffineL")] = {{}, {dim_in = {din1, din2},
+ dim_out = {dout1},
+ pr = pr}},
+ },
+ ["nerv.TanhLayer"] = {
+ [ap("mainTanhL")] = {{}, {dim_in = {dout1}, dim_out = {dout1}}},
+ },
+ ["nerv.GateFLayer"] = {
+ [ap("resetGateL")] = {{}, {dim_in = {din1, din2},
+ dim_out = {din2},
+ pr = pr}},
+ [ap("updateGateL")] = {{}, {dim_in = {din1, din2},
+ dim_out = {din2},
+ pr = pr}},
+ },
+ ["nerv.ElemMulLayer"] = {
+ [ap("resetGMulL")] = {{}, {dim_in = {din2, din2}, dim_out = {din2}}},
+ [ap("updateGMulCL")] = {{}, {dim_in = {din2, din2}, dim_out = {din2}}},
+ [ap("updateGMulHL")] = {{}, {dim_in = {din2, din2}, dim_out = {din2}}},
+ },
+ }
+
+ local layerRepo = nerv.LayerRepo(layers, pr, global_conf)
+
+ local connections = {
+ ["<input>[1]"] = ap("inputXDup[1]"),
+ ["<input>[2]"] = ap("inputHDup[1]"),
+
+ [ap("inputXDup[1]")] = ap("resetGateL[1]"),
+ [ap("inputHDup[1]")] = ap("resetGateL[2]"),
+ [ap("inputXDup[2]")] = ap("updateGateL[1]"),
+ [ap("inputHDup[2]")] = ap("updateGateL[2]"),
+ [ap("updateGateL[1]")] = ap("updateGDup[1]"),
+
+ [ap("resetGateL[1]")] = ap("resetGMulL[1]"),
+ [ap("inputHDup[3]")] = ap("resetGMulL[2]"),
+
+ [ap("inputXDup[3]")] = ap("mainAffineL[1]"),
+ [ap("resetGMulL[1]")] = ap("mainAffineL[2]"),
+ [ap("mainAffineL[1]")] = ap("mainTanhL[1]"),
+
+ [ap("updateGDup[1]")] = ap("updateGMulHL[1]"),
+ [ap("inputHDup[4]")] = ap("updateGMulHL[2]"),
+ [ap("updateGDup[2]")] = ap("updateGMulCL[1]"),
+ [ap("mainTanhL[1]")] = ap("updateGMulCL[2]"),
+
+ [ap("inputHDup[5]")] = ap("updateMergeL[1]"),
+ [ap("updateGMulHL[1]")] = ap("updateMergeL[2]"),
+ [ap("updateGMulCL[1]")] = ap("updateMergeL[3]"),
+
+ [ap("updateMergeL[1]")] = "<output>[1]",
+ }
+
+ self.dag = nerv.DAGLayer(self.id, global_conf,
+ {dim_in = self.dim_in,
+ dim_out = self.dim_out,
+ sub_layers = layerRepo,
+ connections = connections})
+
+ self:check_dim_len(2, 1) -- x, h and h
+end
+
+function GRULayer:init(batch_size, chunk_size)
+ self.dag:init(batch_size, chunk_size)
+end
+
+function GRULayer:batch_resize(batch_size, chunk_size)
+ self.dag:batch_resize(batch_size, chunk_size)
+end
+
+function GRULayer:update(bp_err, input, output, t)
+ self.dag:update(bp_err, input, output, t)
+end
+
+function GRULayer:propagate(input, output, t)
+ self.dag:propagate(input, output, t)
+end
+
+function GRULayer:back_propagate(bp_err, next_bp_err, input, output, t)
+ self.dag:back_propagate(bp_err, next_bp_err, input, output, t)
+end
+
+function GRULayer:get_params()
+ return self.dag:get_params()
+end