aboutsummaryrefslogblamecommitdiff
path: root/nerv/layer/gru.lua
blob: e81d21a131206ecafa91ae4037f58da59af8c712 (plain) (tree)

















































                                                                               
                                  












































































                                                                                   
--- Gated recurrent unit layer, registered as `nerv.GRULayer`.
-- Derives from `nerv.Layer`; the actual computation is delegated to an
-- internal `nerv.DAGLayer` built in `__init`.
local GRULayer = nerv.class("nerv.GRULayer", "nerv.Layer")

--- Build a GRU layer as a DAG of primitive sub-layers.
-- Ports (all widths come from layer_conf):
--   input  [1] x  -- width dim_in[1]
--   input  [2] h  -- width dim_in[2]
--   output [1]    -- width dim_out[1] (must equal dim_in[2])
-- The wiring below computes, with `gate` = nerv.LSTMGateLayer
-- (presumably a sigmoid of an affine map -- confirm in that layer):
--   r  = gate(x, h)                  -- resetGateL
--   z  = gate(x, h)                  -- updateGateL
--   c  = tanh(affine(x, r .* h))     -- resetGMulL, mainAffineL, mainTanhL
--   out = h - z .* h + z .* c        -- updateGMulHL, updateGMulCL, updateMergeL
-- @param id          layer id; also used as prefix for all sub-layer ids
-- @param global_conf global configuration, forwarded to sub-layers
-- @param layer_conf  table with dim_in (2 entries), dim_out (1 entry) and
--                    optional pr (a nerv.ParamRepo; a fresh one is created
--                    when absent)
function GRULayer:__init(id, global_conf, layer_conf)
    self.id = id
    self.dim_in = layer_conf.dim_in
    self.dim_out = layer_conf.dim_out
    self.gconf = global_conf

    -- Check port counts first, so a malformed config gets a clear arity
    -- error instead of failing while indexing dim_in[2] / dim_out[1] below.
    self:check_dim_len(2, 1) -- inputs: x, h; output: updated h

    -- updateMergeL mixes h (width dim_in[2]) with tanh output (width
    -- dim_out[1]), so the two widths must agree.
    if self.dim_in[2] ~= self.dim_out[1] then
        nerv.error("dim_in[2](%d) mismatch with dim_out[1](%d)",
                    self.dim_in[2], self.dim_out[1])
    end

    -- Parameter repository shared by all parameterized sub-layers.
    local pr = layer_conf.pr
    if pr == nil then
        pr = nerv.ParamRepo()
    end

    -- Prefix a sub-layer/port name with this layer's id.
    local function ap(str)
        return self.id .. '.' .. str
    end
    local din1, din2 = self.dim_in[1], self.dim_in[2]
    local dout1 = self.dim_out[1]

    -- Sub-layers of the GRU graph, grouped by layer type.
    local layers = {
        ["nerv.CombinerLayer"] = {
            -- fan-out copies of x (reset gate, update gate, candidate)
            [ap("inputXDup")] = {{}, {dim_in = {din1},
                                      dim_out = {din1, din1, din1},
                                      lambda = {1}}},
            -- fan-out copies of h (two gates, reset mul, update mul, merge)
            [ap("inputHDup")] = {{}, {dim_in = {din2},
                                      dim_out = {din2, din2, din2, din2, din2},
                                      lambda = {1}}},
            -- fan-out copies of the update gate z
            [ap("updateGDup")] = {{}, {dim_in = {din2},
                                       dim_out = {din2, din2},
                                       lambda = {1}}},
            -- out = 1*h + (-1)*(z .* h) + 1*(z .* c)
            [ap("updateMergeL")] = {{}, {dim_in = {din2, din2, din2},
                                         dim_out = {dout1},
                                         lambda = {1, -1, 1}}},
        },
        ["nerv.AffineLayer"] = {
            [ap("mainAffineL")] = {{}, {dim_in = {din1, din2},
                                        dim_out = {dout1},
                                        pr = pr}},
        },
        ["nerv.TanhLayer"] = {
            [ap("mainTanhL")] = {{}, {dim_in = {dout1}, dim_out = {dout1}}},
        },
        ["nerv.LSTMGateLayer"] = {
            [ap("resetGateL")] = {{}, {dim_in = {din1, din2},
                                       dim_out = {din2},
                                       pr = pr}},
            [ap("updateGateL")] = {{}, {dim_in = {din1, din2},
                                        dim_out = {din2},
                                        pr = pr}},
        },
        ["nerv.ElemMulLayer"] = {
            [ap("resetGMulL")] = {{}, {dim_in = {din2, din2}, dim_out = {din2}}},
            [ap("updateGMulCL")] = {{}, {dim_in = {din2, din2}, dim_out = {din2}}},
            [ap("updateGMulHL")] = {{}, {dim_in = {din2, din2}, dim_out = {din2}}},
        },
    }

    local layerRepo = nerv.LayerRepo(layers, pr, global_conf)

    -- Edges of the DAG: source port -> destination port.
    local connections = {
        ["<input>[1]"] = ap("inputXDup[1]"),
        ["<input>[2]"] = ap("inputHDup[1]"),

        -- gates: r = gate(x, h), z = gate(x, h)
        [ap("inputXDup[1]")] = ap("resetGateL[1]"),
        [ap("inputHDup[1]")] = ap("resetGateL[2]"),
        [ap("inputXDup[2]")] = ap("updateGateL[1]"),
        [ap("inputHDup[2]")] = ap("updateGateL[2]"),
        [ap("updateGateL[1]")] = ap("updateGDup[1]"),

        -- r .* h
        [ap("resetGateL[1]")] = ap("resetGMulL[1]"),
        [ap("inputHDup[3]")] = ap("resetGMulL[2]"),

        -- candidate c = tanh(affine(x, r .* h))
        [ap("inputXDup[3]")] = ap("mainAffineL[1]"),
        [ap("resetGMulL[1]")] = ap("mainAffineL[2]"),
        [ap("mainAffineL[1]")] = ap("mainTanhL[1]"),

        -- z .* h and z .* c
        [ap("updateGDup[1]")] = ap("updateGMulHL[1]"),
        [ap("inputHDup[4]")] = ap("updateGMulHL[2]"),
        [ap("updateGDup[2]")] = ap("updateGMulCL[1]"),
        [ap("mainTanhL[1]")] = ap("updateGMulCL[2]"),

        -- out = h - z .* h + z .* c
        [ap("inputHDup[5]")] = ap("updateMergeL[1]"),
        [ap("updateGMulHL[1]")] = ap("updateMergeL[2]"),
        [ap("updateGMulCL[1]")] = ap("updateMergeL[3]"),

        [ap("updateMergeL[1]")] = "<output>[1]",
    }

    self.dag = nerv.DAGLayer(self.id, global_conf,
                             {dim_in = self.dim_in,
                              dim_out = self.dim_out,
                              sub_layers = layerRepo,
                              connections = connections})
end

--- Initialize the wrapped DAG.
-- Both arguments are forwarded unchanged to `self.dag:init`.
function GRULayer:init(batch_size, chunk_size)
    local dag = self.dag
    dag:init(batch_size, chunk_size)
end

--- Resize the wrapped DAG for a new batch configuration.
-- Both arguments are forwarded unchanged to `self.dag:batch_resize`.
function GRULayer:batch_resize(batch_size, chunk_size)
    local dag = self.dag
    dag:batch_resize(batch_size, chunk_size)
end

--- Parameter update step, delegated to the wrapped DAG.
-- All arguments are forwarded unchanged; nothing is returned.
function GRULayer:update(bp_err, input, output, t)
    local dag = self.dag
    dag:update(bp_err, input, output, t)
end

--- Forward pass, delegated to the wrapped DAG.
-- All arguments are forwarded unchanged; nothing is returned.
function GRULayer:propagate(input, output, t)
    local dag = self.dag
    dag:propagate(input, output, t)
end

--- Backward pass, delegated to the wrapped DAG.
-- All arguments are forwarded unchanged; nothing is returned.
function GRULayer:back_propagate(bp_err, next_bp_err, input, output, t)
    local dag = self.dag
    dag:back_propagate(bp_err, next_bp_err, input, output, t)
end

--- Return whatever the wrapped DAG reports as its parameters.
function GRULayer:get_params()
    local dag = self.dag
    return dag:get_params()
end