aboutsummaryrefslogtreecommitdiff
path: root/nerv/layer/lstm.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/layer/lstm.lua')
-rw-r--r--nerv/layer/lstm.lua192
1 files changed, 67 insertions, 125 deletions
diff --git a/nerv/layer/lstm.lua b/nerv/layer/lstm.lua
index d8eee71..56f674a 100644
--- a/nerv/layer/lstm.lua
+++ b/nerv/layer/lstm.lua
@@ -1,143 +1,85 @@
-local LSTMLayer = nerv.class('nerv.LSTMLayer', 'nerv.Layer')
+local LSTMLayer = nerv.class('nerv.LSTMLayer', 'nerv.GraphLayer')
function LSTMLayer:__init(id, global_conf, layer_conf)
- -- input1:x
- -- input2:h
- -- input3:c
- self.id = id
- self.dim_in = layer_conf.dim_in
- self.dim_out = layer_conf.dim_out
- self.gconf = global_conf
+ nerv.Layer.__init(self, id, global_conf, layer_conf)
+ self:check_dim_len(1, 1)
+
+ local din = layer_conf.dim_in[1]
+ local dout = layer_conf.dim_out[1]
- -- prepare a DAGLayer to hold the lstm structure
local pr = layer_conf.pr
if pr == nil then
- pr = nerv.ParamRepo()
- end
-
- local function ap(str)
- return self.id .. '.' .. str
+ pr = nerv.ParamRepo({}, self.loc_type)
end
- local din1, din2, din3 = self.dim_in[1], self.dim_in[2], self.dim_in[3]
- local dout1, dout2 = self.dim_out[1], self.dim_out[2]
- local layers = {
- ["nerv.CombinerLayer"] = {
- [ap("inputXDup")] = {{}, {dim_in = {din1},
- dim_out = {din1, din1, din1, din1},
- lambda = {1}}},
- [ap("inputHDup")] = {{}, {dim_in = {din2},
- dim_out = {din2, din2, din2, din2},
- lambda = {1}}},
-
- [ap("inputCDup")] = {{}, {dim_in = {din3},
- dim_out = {din3, din3, din3},
- lambda = {1}}},
-
- [ap("mainCDup")] = {{}, {dim_in = {din3, din3},
- dim_out = {din3, din3, din3},
- lambda = {1, 1}}},
+ local layers = {
+ ['nerv.CombinerLayer'] = {
+ mainCombine = {dim_in = {dout, dout}, dim_out = {dout}, lambda = {1, 1}},
},
- ["nerv.AffineLayer"] = {
- [ap("mainAffineL")] = {{}, {dim_in = {din1, din2},
- dim_out = {dout1},
- pr = pr}},
+ ['nerv.DuplicateLayer'] = {
+ inputDup = {dim_in = {din}, dim_out = {din, din, din, din}},
+ outputDup = {dim_in = {dout}, dim_out = {dout, dout, dout, dout, dout}},
+ cellDup = {dim_in = {dout}, dim_out = {dout, dout, dout, dout, dout}},
},
- ["nerv.TanhLayer"] = {
- [ap("mainTanhL")] = {{}, {dim_in = {dout1}, dim_out = {dout1}}},
- [ap("outputTanhL")] = {{}, {dim_in = {dout1}, dim_out = {dout1}}},
+ ['nerv.AffineLayer'] = {
+ mainAffine = {dim_in = {din, dout}, dim_out = {dout}, pr = pr},
},
- ["nerv.LSTMGateLayer"] = {
- [ap("forgetGateL")] = {{}, {dim_in = {din1, din2, din3},
- dim_out = {din3}, pr = pr,
- param_type = {'N', 'N', 'D'}}},
- [ap("inputGateL")] = {{}, {dim_in = {din1, din2, din3},
- dim_out = {din3}, pr = pr,
- param_type = {'N', 'N', 'D'}}},
- [ap("outputGateL")] = {{}, {dim_in = {din1, din2, din3},
- dim_out = {din3}, pr = pr,
- param_type = {'N', 'N', 'D'}}},
-
+ ['nerv.TanhLayer'] = {
+ mainTanh = {dim_in = {dout}, dim_out = {dout}},
+ outputTanh = {dim_in = {dout}, dim_out = {dout}},
},
- ["nerv.ElemMulLayer"] = {
- [ap("inputGMulL")] = {{}, {dim_in = {din3, din3},
- dim_out = {din3}}},
- [ap("forgetGMulL")] = {{}, {dim_in = {din3, din3},
- dim_out = {din3}}},
- [ap("outputGMulL")] = {{}, {dim_in = {din3, din3},
- dim_out = {din3}}},
+ ['nerv.LSTMGateLayer'] = {
+ forgetGate = {dim_in = {din, dout, dout}, dim_out = {dout}, param_type = {'N', 'N', 'D'}, pr = pr},
+ inputGate = {dim_in = {din, dout, dout}, dim_out = {dout}, param_type = {'N', 'N', 'D'}, pr = pr},
+ outputGate = {dim_in = {din, dout, dout}, dim_out = {dout}, param_type = {'N', 'N', 'D'}, pr = pr},
+ },
+ ['nerv.ElemMulLayer'] = {
+ inputGateMul = {dim_in = {dout, dout}, dim_out = {dout}},
+ forgetGateMul = {dim_in = {dout, dout}, dim_out = {dout}},
+ outputGateMul = {dim_in = {dout, dout}, dim_out = {dout}},
},
}
- local layerRepo = nerv.LayerRepo(layers, pr, global_conf)
-
local connections = {
- ["<input>[1]"] = ap("inputXDup[1]"),
- ["<input>[2]"] = ap("inputHDup[1]"),
- ["<input>[3]"] = ap("inputCDup[1]"),
-
- [ap("inputXDup[1]")] = ap("mainAffineL[1]"),
- [ap("inputHDup[1]")] = ap("mainAffineL[2]"),
- [ap("mainAffineL[1]")] = ap("mainTanhL[1]"),
-
- [ap("inputXDup[2]")] = ap("inputGateL[1]"),
- [ap("inputHDup[2]")] = ap("inputGateL[2]"),
- [ap("inputCDup[1]")] = ap("inputGateL[3]"),
-
- [ap("inputXDup[3]")] = ap("forgetGateL[1]"),
- [ap("inputHDup[3]")] = ap("forgetGateL[2]"),
- [ap("inputCDup[2]")] = ap("forgetGateL[3]"),
-
- [ap("mainTanhL[1]")] = ap("inputGMulL[1]"),
- [ap("inputGateL[1]")] = ap("inputGMulL[2]"),
-
- [ap("inputCDup[3]")] = ap("forgetGMulL[1]"),
- [ap("forgetGateL[1]")] = ap("forgetGMulL[2]"),
-
- [ap("inputGMulL[1]")] = ap("mainCDup[1]"),
- [ap("forgetGMulL[1]")] = ap("mainCDup[2]"),
-
- [ap("inputXDup[4]")] = ap("outputGateL[1]"),
- [ap("inputHDup[4]")] = ap("outputGateL[2]"),
- [ap("mainCDup[3]")] = ap("outputGateL[3]"),
-
- [ap("mainCDup[2]")] = "<output>[2]",
- [ap("mainCDup[1]")] = ap("outputTanhL[1]"),
-
- [ap("outputTanhL[1]")] = ap("outputGMulL[1]"),
- [ap("outputGateL[1]")] = ap("outputGMulL[2]"),
-
- [ap("outputGMulL[1]")] = "<output>[1]",
+ -- lstm input
+ {'<input>[1]', 'inputDup[1]', 0},
+
+ -- input gate
+ {'inputDup[1]', 'inputGate[1]', 0},
+ {'outputDup[1]', 'inputGate[2]', 1},
+ {'cellDup[1]', 'inputGate[3]', 1},
+
+ -- forget gate
+ {'inputDup[2]', 'forgetGate[1]', 0},
+ {'outputDup[2]', 'forgetGate[2]', 1},
+ {'cellDup[2]', 'forgetGate[3]', 1},
+
+ -- lstm cell
+ {'forgetGate[1]', 'forgetGateMul[1]', 0},
+ {'cellDup[3]', 'forgetGateMul[2]', 1},
+ {'inputDup[3]', 'mainAffine[1]', 0},
+ {'outputDup[3]', 'mainAffine[2]', 1},
+ {'mainAffine[1]', 'mainTanh[1]', 0},
+ {'inputGate[1]', 'inputGateMul[1]', 0},
+ {'mainTanh[1]', 'inputGateMul[2]', 0},
+ {'inputGateMul[1]', 'mainCombine[1]', 0},
+ {'forgetGateMul[1]', 'mainCombine[2]', 0},
+ {'mainCombine[1]', 'cellDup[1]', 0},
+
+ -- forget gate
+ {'inputDup[4]', 'outputGate[1]', 0},
+ {'outputDup[4]', 'outputGate[2]', 1},
+ {'cellDup[4]', 'outputGate[3]', 0},
+
+ -- lstm output
+ {'cellDup[5]', 'outputTanh[1]', 0},
+ {'outputGate[1]', 'outputGateMul[1]', 0},
+ {'outputTanh[1]', 'outputGateMul[2]', 0},
+ {'outputGateMul[1]', 'outputDup[1]', 0},
+ {'outputDup[5]', '<output>[1]', 0},
}
- self.dag = nerv.DAGLayer(self.id, global_conf,
- {dim_in = self.dim_in,
- dim_out = self.dim_out,
- sub_layers = layerRepo,
- connections = connections})
-
- self:check_dim_len(3, 2) -- x, h, c and h, c
-end
-
-function LSTMLayer:init(batch_size, chunk_size)
- self.dag:init(batch_size, chunk_size)
-end
-
-function LSTMLayer:batch_resize(batch_size, chunk_size)
- self.dag:batch_resize(batch_size, chunk_size)
-end
-
-function LSTMLayer:update(bp_err, input, output, t)
- self.dag:update(bp_err, input, output, t)
-end
-
-function LSTMLayer:propagate(input, output, t)
- self.dag:propagate(input, output, t)
-end
-
-function LSTMLayer:back_propagate(bp_err, next_bp_err, input, output, t)
- self.dag:back_propagate(bp_err, next_bp_err, input, output, t)
-end
-function LSTMLayer:get_params()
- return self.dag:get_params()
+ self:add_prefix(layers, connections)
+ local layer_repo = nerv.LayerRepo(layers, pr, global_conf)
+ self:graph_init(layer_repo, connections)
end