-- nerv/layer/gru.lua
local GRULayer = nerv.class('nerv.GRULayer', 'nerv.Layer')

function GRULayer:__init(id, global_conf, layer_conf)
    -- input1: x
    -- input2: h (previous hidden state)
    nerv.Layer.__init(self, id, global_conf, layer_conf)
    if self.dim_in[2] ~= self.dim_out[1] then
        nerv.error("dim_in[2](%d) mismatch with dim_out[1](%d)",
                    self.dim_in[2], self.dim_out[1])
    end

    -- prepare a DAGLayer to hold the GRU structure
    local pr = layer_conf.pr
    if pr == nil then
        pr = nerv.ParamRepo({}, self.loc_type)
    end
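    -- the gate and affine sub-layers look their parameters up in `pr`, so an
    -- existing ParamRepo passed via layer_conf.pr lets those parameters be
    -- shared instead of freshly created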
    
    local function ap(str)
        return self.id .. '.' .. str
    end
    local din1, din2 = self.dim_in[1], self.dim_in[2]
    local dout1 = self.dim_out[1]
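    -- The sub-layers below, wired together in `connections`, implement the
    -- usual GRU recurrence (LSTMGateLayer provides the sigmoid gates):
    --   r  = sigmoid(affine_r(x, h))      -- resetGateL
    --   z  = sigmoid(affine_z(x, h))      -- updateGateL
    --   h~ = tanh(affine(x, r .* h))      -- resetGMulL, mainAffineL, mainTanhL
    --   h' = (1 - z) .* h + z .* h~       -- updateGMulHL, updateGMulCL, updateMergeL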
    local layers = {
        ["nerv.CombinerLayer"] = {
            [ap("inputXDup")] = {{}, {dim_in = {din1},
                                      dim_out = {din1, din1, din1},
                                      lambda = {1}}},
            [ap("inputHDup")] = {{}, {dim_in = {din2},
                                      dim_out = {din2, din2, din2, din2, din2},
                                      lambda = {1}}},
            [ap("updateGDup")] = {{}, {dim_in = {din2},
                                       dim_out = {din2, din2},
                                       lambda = {1}}},
            [ap("updateMergeL")] = {{}, {dim_in = {din2, din2, din2},
                                         dim_out = {dout1},
                                         lambda = {1, -1, 1}}},
        },
        ["nerv.AffineLayer"] = {
            [ap("mainAffineL")] = {{}, {dim_in = {din1, din2},
                                        dim_out = {dout1},
                                        pr = pr}},
        },
        ["nerv.TanhLayer"] = {
            [ap("mainTanhL")] = {{}, {dim_in = {dout1}, dim_out = {dout1}}},
        },
        ["nerv.LSTMGateLayer"] = {
            [ap("resetGateL")] = {{}, {dim_in = {din1, din2},
                                       dim_out = {din2},
                                       pr = pr}},
            [ap("updateGateL")] = {{}, {dim_in = {din1, din2},
                                        dim_out = {din2},
                                        pr = pr}},
        },
        ["nerv.ElemMulLayer"] = {
            [ap("resetGMulL")] = {{}, {dim_in = {din2, din2}, dim_out = {din2}}},
            [ap("updateGMulCL")] = {{}, {dim_in = {din2, din2}, dim_out = {din2}}},
            [ap("updateGMulHL")] = {{}, {dim_in = {din2, din2}, dim_out = {din2}}},
        },
    }
    
    self.lrepo = nerv.LayerRepo(layers, pr, global_conf)

    local connections = {
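        -- each entry wires a source port to a destination port; "<input>" and
        -- "<output>" are the external ports of the DAG, and the bracketed
        -- index selects a port of the named sub-layer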
        ["<input>[1]"] = ap("inputXDup[1]"),
        ["<input>[2]"] = ap("inputHDup[1]"),

        [ap("inputXDup[1]")] = ap("resetGateL[1]"),
        [ap("inputHDup[1]")] = ap("resetGateL[2]"),
        [ap("inputXDup[2]")] = ap("updateGateL[1]"),
        [ap("inputHDup[2]")] = ap("updateGateL[2]"),
        [ap("updateGateL[1]")] = ap("updateGDup[1]"),

        [ap("resetGateL[1]")] = ap("resetGMulL[1]"),
        [ap("inputHDup[3]")] = ap("resetGMulL[2]"),

        [ap("inputXDup[3]")] = ap("mainAffineL[1]"),
        [ap("resetGMulL[1]")] = ap("mainAffineL[2]"),
        [ap("mainAffineL[1]")] = ap("mainTanhL[1]"),

        [ap("updateGDup[1]")] = ap("updateGMulHL[1]"),
        [ap("inputHDup[4]")] = ap("updateGMulHL[2]"),
        [ap("updateGDup[2]")] = ap("updateGMulCL[1]"),
        [ap("mainTanhL[1]")] = ap("updateGMulCL[2]"),
 
        [ap("inputHDup[5]")] = ap("updateMergeL[1]"),
        [ap("updateGMulHL[1]")] = ap("updateMergeL[2]"),
        [ap("updateGMulCL[1]")] = ap("updateMergeL[3]"),

        [ap("updateMergeL[1]")] = "<output>[1]",
    }

    self.dag = nerv.DAGLayer(self.id, global_conf,
                                {dim_in = self.dim_in,
                                dim_out = self.dim_out,
                                sub_layers = self.lrepo,
                                connections = connections})
    
    self:check_dim_len(2, 1) -- inputs: x, h; output: h
end

function GRULayer:bind_params()
    -- re-fetch parameters (possibly from a shared ParamRepo) and rebind the
    -- sub-layers to them
    local pr = self.lconf.pr
    if pr == nil then
        pr = nerv.ParamRepo({}, self.loc_type)
    end
    self.lrepo:rebind(pr)
end

function GRULayer:init(batch_size, chunk_size)
    self.dag:init(batch_size, chunk_size)
end

function GRULayer:batch_resize(batch_size, chunk_size)
    self.dag:batch_resize(batch_size, chunk_size)
end

function GRULayer:update(bp_err, input, output, t)
    self.dag:update(bp_err, input, output, t)
end

function GRULayer:propagate(input, output, t)
    self.dag:propagate(input, output, t)
end

function GRULayer:back_propagate(bp_err, next_bp_err, input, output, t)
    self.dag:back_propagate(bp_err, next_bp_err, input, output, t)
end

function GRULayer:get_params()
    return self.dag:get_params()
end
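
-- Usage sketch (a hedged example, not part of the layer): assumes a typical
-- nerv global_conf `gconf` and, optionally, a shared nerv.ParamRepo `pr`;
-- dim_in = {x_dim, h_dim} (input and previous hidden state), dim_out = {h_dim}.
--
--   local gru = nerv.GRULayer("gru1", gconf, {dim_in = {x_dim, h_dim},
--                                             dim_out = {h_dim},
--                                             pr = pr})
--   gru:init(batch_size, chunk_size)
--   -- per time step t: input = {x_t, h_prev}, output = {h_new}
--   gru:propagate(input, output, t)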