aboutsummaryrefslogblamecommitdiff
path: root/nerv/tnn/layer_dag_t.lua
blob: b651f4ebad5bc49d97e5fe4de47cf9b8555b4d62 (plain) (tree)
1
2
3
4
5



                                                             
                                                                     


















                                                             
                           




























                                                                         





                                                                                    






                                                                                  





                                                                                    







                                                                                      






























































                                                                                                    
                                                 

                             
                                                                                     








































                                                              
                                                            



                                                                       
                                                             








































































                                                                                                                             


                                      










                                                      


                                       







                                            


                                          







                                                  


                                           










                                                   
                                       
                                                                             


       
                                              







                                 
                                                       
                                                                   













                                                                        
                                                                                                         




































                                                            
-- Class for a layer composed of sub-layers wired together as a DAG
-- (built in DAGLayerT:__init). The second argument to nerv.class is
-- presumably the parent class name -- confirm against nerv.class.
local DAGLayerT = nerv.class(
    "nerv.DAGLayerT",
    "nerv.LayerT"
)

--- Parse a connection endpoint written as "id[port]".
-- Regular sub-layer ids may only contain letters, digits, '_' and '.';
-- additionally the pseudo-ids "<input>" and "<output>" are accepted.
-- The WHOLE string must match: a string such as "a-b[1]" or "foo[1]x" is
-- rejected instead of silently yielding a partial id ("b" / "foo").
-- Raises via nerv.error on a malformed endpoint.
-- @param str endpoint string, e.g. "affine0[1]" or "<input>[2]"
-- @return id (string) and port (number)
local function parse_id(str)
    local id, port = string.match(str, "^([a-zA-Z0-9_.]+)%[([0-9]+)%]$")
    if id == nil then
        -- fall back to the pseudo-endpoints, whose ids contain '<' and '>'
        id, port = string.match(str, "^(.+)%[([0-9]+)%]$")
        if id ~= "<input>" and id ~= "<output>" then
            nerv.error("wrong format of connection id: %s", str)
        end
    end
    return id, tonumber(port)
end

--- Look up (or lazily create) the bookkeeping record of a sub-layer.
-- Returns nil for the pseudo-endpoints "<input>" / "<output>", which have
-- no backing layer. Otherwise the record cached in `layers[id]` is
-- returned; on first use the layer is fetched with
-- `layer_repo:get_layer(id)`, a fresh record is stored in `layers[id]`,
-- and that record is returned.
-- @param id         sub-layer id (as produced by parse_id)
-- @param layers     table id -> record, updated in place
-- @param layer_repo object providing get_layer(id)
-- @return the record, or nil for a pseudo-endpoint
local function discover(id, layers, layer_repo)
    if id == "<input>" or id == "<output>" then
        return nil
    end

    local cached = layers[id]
    if cached ~= nil then
        return cached
    end

    local sub_layer = layer_repo:get_layer(id)
    local in_dims, out_dims = sub_layer:get_dim()
    local record = {
        id = sub_layer.id,
        layer = sub_layer,
        inputs = {},
        outputs = {},
        err_inputs = {},
        err_outputs = {},
        next_layers = {},           -- successors in the DAG
        input_len = #in_dims,
        output_len = #out_dims,
        in_deg = 0,                 -- number of incoming internal edges
        visited = false,            -- set by the topological sort
    }
    layers[id] = record
    return record
end

--- Construct a DAG layer from sub-layers and the connections between them.
-- Every key/value pair of `layer_conf.connections` is an edge
-- "from" -> "to", each side written as "id[port]" (see parse_id). The
-- pseudo-ids "<input>" and "<output>" stand for this layer's own input and
-- output ports. Port dimensions are checked on every edge, duplicate port
-- bindings raise an error, and the sub-layers are ordered topologically
-- into `self.queue`.
-- @param id          identifier of this layer (stored as self.id)
-- @param global_conf global configuration (stored as self.gconf)
-- @param layer_conf  table read for `dim_in`, `dim_out`, `connections`
--                    and `sub_layers` (an object with get_layer(id))
function DAGLayerT:__init(id, global_conf, layer_conf)
    local layers = {}      -- sub-layer id -> record (see discover)
    local inputs = {}      -- <input> port -> {ref_to, port_to}
    local outputs = {}     -- <output> port -> {ref_from, port_from}
    local dim_in = layer_conf.dim_in
    local dim_out = layer_conf.dim_out
    local parsed_conn = {} -- internal edges {{ref_from, port_from}, {ref_to, port_to}}
    for from, to in pairs(layer_conf.connections) do
        local id_from, port_from = parse_id(from)
        local id_to, port_to = parse_id(to)
        -- "<output>" can only be a sink and "<input>" only a source, and a
        -- direct <input> -> <output> edge has no sub-layer to attach to;
        -- these previously crashed below with a nil-index error.
        if id_from == "<output>" or id_to == "<input>"
            or (id_from == "<input>" and id_to == "<output>") then
            nerv.error("invalid connection between %s and %s", from, to)
        end
        local ref_from = discover(id_from, layers, layer_conf.sub_layers)
        local ref_to = discover(id_to, layers, layer_conf.sub_layers)
        local input_dim, output_dim, _
        if id_from == "<input>" then
            input_dim, _ = ref_to.layer:get_dim()
            if dim_in[port_from] ~= input_dim[port_to] then
                nerv.error("mismatching data dimension between %s and %s", from, to)
            end
            inputs[port_from] = {ref_to, port_to}
            if ref_to.inputs[1] == nil then
                ref_to.inputs[1] = {}
            end
            if ref_to.inputs[1][port_to] ~= nil then
                nerv.error("port(%d) for layer(%s) already attached", port_to, to)
            end
            ref_to.inputs[1][port_to] = inputs -- just a place holder
        elseif id_to == "<output>" then
            _, output_dim = ref_from.layer:get_dim()
            if output_dim[port_from] ~= dim_out[port_to] then
                nerv.error("mismatching data dimension between %s and %s", from, to)
            end
            -- two different producers bound to the same <output> port
            if outputs[port_to] ~= nil then
                nerv.error("port(%d) for layer(%s) already attached", port_to, to)
            end
            outputs[port_to] = {ref_from, port_from}
            if ref_from.outputs[1] == nil then
                ref_from.outputs[1] = {}
            end
            if ref_from.outputs[1][port_from] ~= nil then
                nerv.error("port(%d) for layer(%s) already attached", port_from, from)
            end
            -- Do NOT reset ref_from.outputs[1] here: doing so would drop the
            -- placeholders of other ports of the same sub-layer that are
            -- already wired to <output>, and defeat the check above.
            ref_from.outputs[1][port_from] = outputs -- just a place holder
        else
            _, output_dim = ref_from.layer:get_dim()
            input_dim, _ = ref_to.layer:get_dim()
            if output_dim[port_from] ~= input_dim[port_to] then
                nerv.error("mismatching data dimension between %s and %s", from, to)
            end

            table.insert(parsed_conn,
            {{ref_from, port_from}, {ref_to, port_to}})
            table.insert(ref_from.next_layers, ref_to) -- add edge
            ref_to.in_deg = ref_to.in_deg + 1          -- increase the in-degree of the target layer
        end
    end

    -- Topological sort (Kahn's algorithm): seed the queue with layers that
    -- have no incoming internal edge, then repeatedly release successors
    -- whose in-degree drops to zero. Note: this consumes ref.in_deg.
    local queue = {}
    local l = 1
    local r = 1
    for layer_id, ref in pairs(layers) do
        if ref.in_deg == 0 then
            table.insert(queue, ref)
            nerv.info("adding source layer: %s", layer_id)
            r = r + 1
        end
    end
    if l == r then
        nerv.error("loop detected")
    end
    while l < r do
        local cur = queue[l]
        cur.visited = true
        l = l + 1
        for _, nl in pairs(cur.next_layers) do
            nl.in_deg = nl.in_deg - 1
            if nl.in_deg == 0 then
                table.insert(queue, nl)
                r = r + 1
            end
        end
    end
    for i = 1, #queue do
        -- tostring: Lua 5.1 / LuaJIT 2.0 string.format rejects a table for %s
        nerv.info("enqueued layer: %s %s", tostring(queue[i].layer), queue[i].layer.id)
    end

    for layer_id, ref in pairs(layers) do
        -- layers never reached by the sort (e.g. on or behind a cycle)
        if ref.visited == false then
            nerv.warning("layer %s is ignored", layer_id)
        end
    end

    self.layers = layers
    self.inputs = inputs
    self.outputs = outputs
    self.id = id
    self.dim_in = dim_in
    self.dim_out = dim_out
    self.parsed_conn = parsed_conn
    self.queue = queue
    self.gconf = global_conf
end

function DAGLayerT:init(batch_size, chunk_size)
    nerv.info("initing DAGLayerT %s...", self.id)
    if chunk_size == nil then
        chunk_size = 1
        nerv.info("(Initing DAGLayerT) chunk_size is nil, setting it to default 1\n")
    end

    self.chunk_size = chunk_size

    for i, conn in ipairs(self.parsed_conn) do
        local _, output_dim
        local ref_from, port_from, ref_to, port_to
        ref_from, port_from = unpack(conn[1])
        ref_to, port_to = unpack(conn[2])
        _, output_dim = ref_from.layer:get_dim()
        local dim = 1
        if output_dim[port_from] > 0 then
            dim = output_dim[port_from]
        end

        for t = 1,