aboutsummaryrefslogblamecommitdiff
path: root/nerv/nn/layer_dag.lua
blob: 6ad7ae9cd36c7554d6bfe71c3f1ee26154a0b095 (plain) (tree)







































                                                                    
                                                     




                                      
                          































                                                                                    
 
                                     
                                                       



                                                                                                    
 
                    


                    
                                   

                                    
                                                    


















                                              
                                                                             
       

                                   

                                              
                                                   
           




                          
                




                                  




                                             

   
                                  





                                                  



                                         
                                                  








                                                

                                        
                                                                       



                                         
                                                                        


                                
                                  












                                                               










                                                                                              
                                                                        














                                                    
                                   
                              


                                                     





                                       
                                     
                               


                                                      





                                        
                                        






                                          
                                              






                                                
                                               


                               
                      
                                       
                              



                                                                 
                                          

                            
                     

                                 
                              
                                                          
       
              

   
                                                                    





                                     
                              
                                                                                          

       

                              
                          
                                       
                                                         
       
                                            
   











                                                         
                                 













                                                           
-- Class table for the DAG layer, created through the project's nerv.class
-- helper under the name "nerv.DAGLayer".
-- NOTE(review): the second argument ("nerv.Layer") is presumably the parent
-- class name -- confirm against nerv.class.
local DAGLayer = nerv.class("nerv.DAGLayer", "nerv.Layer")

--- Split a connection endpoint written as `name[port]` into its two parts.
-- Ordinary names consist of letters, digits and underscores. When that
-- pattern does not match, a looser pattern is tried and only the special
-- names "<input>" and "<output>" are accepted from it.
-- Neither pattern is anchored, so the match may occur anywhere inside `str`.
-- @param str endpoint description, e.g. "affine0[1]" or "<input>[1]"
-- @return the text captured in front of the bracket
-- @return the bracketed index, converted with tonumber
local function parse_id(str)
    local name, index = string.match(str, "([a-zA-Z0-9_]+)%[([0-9]+)%]")
    if not (name and index) then
        -- Fallback: accept anything before the bracket, then whitelist it.
        name, index = string.match(str, "(.+)%[([0-9]+)%]")
        if name ~= "<input>" and name ~= "<output>" then
            nerv.error("wrong format of connection id")
        end
    end
    local port_number = tonumber(index)
    return name, port_number
end

--- Fetch the bookkeeping record for a layer id, creating it on first use.
-- The pseudo ids "<input>" and "<output>" never get a record.
-- A new record is built from the layer returned by `layer_repo:get_layer(id)`
-- and the lengths of the two values returned by that layer's `get_dim()`,
-- then cached in `layers` under `id`.
-- @param id layer identifier
-- @param layers table mapping ids to records; updated in place on a miss
-- @param layer_repo object providing a `get_layer(id)` method
-- @return the record for `id`, or nil for "<input>" / "<output>"
local function discover(id, layers, layer_repo)
    -- The graph's own endpoints are not real layers.
    if id == "<input>" or id == "<output>" then
        return nil
    end

    local entry = layers[id]
    if entry ~= nil then
        return entry
    end

    local layer = layer_repo:get_layer(id)
    local dim_in, dim_out = layer:get_dim()
    entry = {
        layer = layer,
        input_len = #dim_in,
        output_len = #dim_out,
        inputs = {},
        outputs = {},
        err_inputs = {},
        err_outputs = {},
        next_layers = {},
        in_deg = 0,
        visited = false
    }
    layers[id] = entry
    return entry
end

function DAGLayer:__init(id, global_conf, layer_conf)
    local layers = {}
    local inputs = {}
    local outputs = {}
    local dim_in = layer_conf.dim_in
    local dim_out = layer_conf.dim_out
    local parsed_conn = {}
    for from, to in pairs(layer_conf.connections) do
        local id_from, port_from =