path: root/nerv/layer/lstm.lua
local LSTMLayer = nerv.class('nerv.LSTMLayer', 'nerv.Layer')

function LSTMLayer:__init(id, global_conf, layer_conf)
    -- input1:x
    -- input2:h
    -- input3:c
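    -- output1:h
    -- output2:c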
    nerv.Layer.__init(self, id, global_conf, layer_conf)
    -- prepare a DAGLayer to hold the lstm structure
    local pr = layer_conf.pr
    if pr == nil then
        pr = nerv.ParamRepo({}, self.loc_type)
    end
    
    local function ap(str)
        return self.id .. '.' .. str
    end
    local din1, din2, din3 = self.dim_in[1], self.dim_in[2], self.dim_in[3]
    local dout1, dout2 = self.dim_out[1], self.dim_out[2]
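    -- sub-layers of the internal DAG: CombinerLayers fan the three inputs out
    -- to the branches that need them, an affine + tanh pair computes the
    -- candidate cell update, three LSTMGateLayers compute the input, forget
    -- and output gates, and ElemMulLayers apply the gating element-wise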
    local layers = {
        ["nerv.CombinerLayer"] = {
            [ap("inputXDup")] = {dim_in = {din1},
                                 dim_out = {din1, din1, din1, din1},
                                 lambda = {1}},

            [ap("inputHDup")] = {dim_in = {din2},
                                 dim_out = {din2, din2, din2, din2},
                                 lambda = {1}},

            [ap("inputCDup")] = {dim_in = {din3},
                                 dim_out = {din3, din3, din3},
                                 lambda = {1}},

            [ap("mainCDup")] = {dim_in = {din3, din3},
                                dim_out = {din3, din3, din3},
                                lambda = {1, 1}},
        },
        ["nerv.AffineLayer"] = {
            [ap("mainAffineL")] = {dim_in = {din1, din2},
                                   dim_out = {dout1},
                                   pr = pr},
        },
        ["nerv.TanhLayer"] = {
            [ap("mainTanhL")] = {dim_in = {dout1}, dim_out = {dout1}},
            [ap("outputTanhL")] = {dim_in = {dout1}, dim_out = {dout1}},
        },
        ["nerv.LSTMGateLayer"] = {
            [ap("forgetGateL")] = {dim_in = {din1, din2, din3},
                                   dim_out = {din3}, pr = pr},
            [ap("inputGateL")] = {dim_in = {din1, din2, din3},
                                  dim_out = {din3}, pr = pr},
            [ap("outputGateL")] = {dim_in = {din1, din2, din3},
                                   dim_out = {din3}, pr = pr},
        },
        ["nerv.ElemMulLayer"] = {
            [ap("inputGMulL")] = {dim_in = {din3, din3},
                                  dim_out = {din3}},
            [ap("forgetGMulL")] = {dim_in = {din3, din3},
                                   dim_out = {din3}},
            [ap("outputGMulL")] = {dim_in = {din3, din3},
                                   dim_out = {din3}},
        },
    }
    
    self.lrepo = nerv.LayerRepo(layers, pr, global_conf)

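    -- the connections implement the standard LSTM recurrence:
    --   c(t) = forgetGate .* c(t-1) + inputGate .* tanh(affine(x, h(t-1)))
    --   h(t) = outputGate .* tanh(c(t))
    -- the input and forget gates peek at c(t-1), while the output gate peeks
    -- at the freshly computed c(t); <output>[1] is h(t) and <output>[2] is c(t)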
    local connections = {
        ["<input>[1]"] = ap("inputXDup[1]"),
        ["<input>[2]"] = ap("inputHDup[1]"),
        ["<input>[3]"] = ap("inputCDup[1]"),

        [ap("inputXDup[1]")] = ap("mainAffineL[1]"),
        [ap("inputHDup[1]")] = ap("mainAffineL[2]"),
        [ap("mainAffineL[1]")] = ap("mainTanhL[1]"),

        [ap("inputXDup[2]")] = ap("inputGateL[1]"),
        [ap("inputHDup[2]")] = ap("inputGateL[2]"),
        [ap("inputCDup[1]")] = ap("inputGateL[3]"),
        
        [ap("inputXDup[3]")] = ap("forgetGateL[1]"),
        [ap("inputHDup[3]")] = ap("forgetGateL[2]"),
        [ap("inputCDup[2]")] = ap("forgetGateL[3]"),

        [ap("mainTanhL[1]")] = ap("inputGMulL[1]"),
        [ap("inputGateL[1]")] = ap("inputGMulL[2]"),
        
        [ap("inputCDup[3]")] = ap("forgetGMulL[1]"),
        [ap("forgetGateL[1]")] = ap("forgetGMulL[2]"),

        [ap("inputGMulL[1]")] = ap("mainCDup[1]"),
        [ap("forgetGMulL[1]")] = ap("mainCDup[2]"),

        [ap("inputXDup[4]")] = ap("outputGateL[1]"),
        [ap("inputHDup[4]")] = ap("outputGateL[2]"),
        [ap("mainCDup[3]")] = ap("outputGateL[3]"),

        [ap("mainCDup[2]")] = "<output>[2]",
        [ap("mainCDup[1]")] = ap("outputTanhL[1]"),
        
        [ap("outputTanhL[1]")] = ap("outputGMulL[1]"),
        [ap("outputGateL[1]")] = ap("outputGMulL[2]"),

        [ap("outputGMulL[1]")] = "<output>[1]",
    }
    self.dag = nerv.DAGLayer(self.id, global_conf,
                                {dim_in = self.dim_in,
                                 dim_out = self.dim_out,
                                 sub_layers = self.lrepo,
                                 connections = connections})
    
    self:check_dim_len(3, 2) -- inputs: x, h, c; outputs: h, c
end

function LSTMLayer:bind_params()
    -- the layer config passed to __init is kept by the base class as self.lconf
    -- (the bare name layer_conf is not in scope here)
    local pr = self.lconf.pr
    if pr == nil then
        pr = nerv.ParamRepo({}, self.loc_type)
    end
    self.lrepo:rebind(pr)
end

function LSTMLayer:init(batch_size, chunk_size)
    self.dag:init(batch_size, chunk_size)
end

function LSTMLayer:batch_resize(batch_size, chunk_size)
    self.dag:batch_resize(batch_size, chunk_size)
end

function LSTMLayer:update(bp_err, input, output, t)
    self.dag:update(bp_err, input, output, t)
end

function LSTMLayer:propagate(input, output, t)
    self.dag:propagate(input, output, t)
end

function LSTMLayer:back_propagate(bp_err, next_bp_err, input, output, t)
    self.dag:back_propagate(bp_err, next_bp_err, input, output, t)
end

function LSTMLayer:get_params()
    return self.dag:get_params()
end
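
--[[
Usage sketch (assumptions, not part of the original file): the three inputs
are x, h(t-1), c(t-1) and the two outputs are h(t), c(t).  Because the tanh
branch is multiplied element-wise with the input gate, dim_out[1] must equal
dim_in[3], and h and c normally share the same width.  The identifiers gconf,
batch_size and chunk_size stand for the usual nerv global config and training
loop variables; the dimensions below are made-up example values.

    local lstm = nerv.LSTMLayer("lstm1", gconf,
                                {dim_in = {620, 1024, 1024},  -- x, h, c widths
                                 dim_out = {1024, 1024}})     -- h, c widths
    lstm:init(batch_size, chunk_size)
--]]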