aboutsummaryrefslogblamecommitdiff
path: root/include/salticidae/conn.h
blob: ceec176612b3ecb03aca731330a60b646f482418 (plain) (tree)


























                                                                                  
                  

                  







                        


                  


                            
                             


                               
                              
 

                      
                                             
                
                 

               
                                                     
                                
                                                                          
                                                                      
                                                                                 
                                                       
                
                        



                                                                      
                                                

          
                  
                             
                        
                             
               
                       
                        
                                   

                     
                                    
                              
 

                                
                                

                                            




                                                                
                                     
    






                                                       
                                               
 
                                   
 


                                                                     
                              
 
               
                                                   

                                                             

                                    

                         
                                                                            

         
                                        









                                    
                                     
                                                        
                                                                     
                                                  
                                                    
                                                                  
 

                                                                              

                                                                             

         
                  


                                                                             
                                                 
                                 
                                                                    
                                  
                                                            
                                     
      
 





                                                                              


                                                                                      
 
 
            


                                     
                                

                          

                            
              
                                  

            
                            
                              
 
                                 
                      
                                         
                                                                 
 
                                                          
                                                                              







                                                                                         
                   
             
           

     

                        
                         
                           

                                  
                                                         

               
                                                          
 
                                                                              






                                                                 





                                                                  
                                                                
                                                                            

                                                           
                     


                                                                                   
                     














                                                                              











                                                                                
                                                                                  





                                                                                                  
                                                                     
                                
                                                                     



                                                                                     



                                                                           
               

         

                                  
                     
                                                                          



                                                    
                                                  


                                                        
                                           

      


                                         
                     
 
                                 
                                        
                                      
 
              

                                         
                                                                
                                                                  
                                               

           
 
            
 
                             











                                               
     
 
           






                                    
                               


                                   

                              

                                           





                                    
                        
                               


                               
                               


                                            



















                                               




                                          




                                    





























                                                      


                                                           
                   


                                                             
                                                   

                                           
                          
                                     
                             


                                       







                                                                
                                                                                        
                                          
                                                                     
         
                                 
                                      
                                        

                                            
                                    


                                                                      

                                                    
               

                                    















                                                                      
     
 
                           
 


                                        
                  
                                 


                                                       
                         

     
                         

                                      
                                                       


                                       
                              
                                            

                                     
                                            
                                           




                                     
                              
         



                       




                             

     
                                           

                                                               
         

                                                                              







                                                                   
                       

                                                               
         

            
                                                                       




                                                            
               



                           


                                                                          
                                      










                                                                             
     


                                                    
 


                                                      
                                        
                                                                   




                                                        

           




      

      
/**
 * Copyright (c) 2018 Cornell University.
 *
 * Author: Ted Yin <[email protected]>
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
 * of the Software, and to permit persons to whom the Software is furnished to do
 * so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#ifndef _SALTICIDAE_CONN_H
#define _SALTICIDAE_CONN_H

#ifdef __cplusplus
#include <cassert>
#include <cstdint>
#include <arpa/inet.h>
#include <unistd.h>

#include <string>
#include <unordered_map>
#include <list>
#include <algorithm>
#include <exception>
#include <mutex>
#include <thread>
#include <fcntl.h>

#include "salticidae/type.h"
#include "salticidae/ref.h"
#include "salticidae/event.h"
#include "salticidae/util.h"
#include "salticidae/netaddr.h"
#include "salticidae/msg.h"
#include "salticidae/buffer.h"

namespace salticidae {

/** Abstraction for connection management. */
class ConnPool {
    class Worker;
    public:
    class Conn;
    /** The handle to a bi-directional connection. */
    using conn_t = ArcObj<Conn>;
    /** The type of callback invoked when connection status is changed. */
    using conn_callback_t = std::function<bool(const conn_t &, bool)>;
    using error_callback_t = std::function<void(const std::exception_ptr, bool)>;
    /** Abstraction for a bi-directional connection. */
    class Conn {
        friend ConnPool;
        public:
        enum ConnMode {
            ACTIVE, /**< the connection is established by connect() */
            PASSIVE, /**< the connection is established by accept() */
            DEAD, /**< the connection is dead */
        };
    
        protected:
        size_t seg_buff_size;
        conn_t self_ref;
        std::mutex ref_mlock;
        int fd;
        Worker *worker;
        ConnPool *cpool;
        std::atomic<ConnMode> mode;
        NetAddr addr;

        MPSCWriteBuffer send_buffer;
        SegBuffer recv_buffer;

        TimedFdEvent ev_connect;
        FdEvent ev_socket;
        TimerEvent ev_send_wait;
        /** does not need to wait if true */
        bool ready_send;

        typedef void (socket_io_func)(const conn_t &, int, int);
        socket_io_func *send_data_func;
        socket_io_func *recv_data_func;
        BoxObj<TLS> tls;
        BoxObj<const X509> peer_cert;
    
        static socket_io_func _recv_data;
        static socket_io_func _send_data;

        static socket_io_func _recv_data_tls;
        static socket_io_func _send_data_tls;
        static socket_io_func _recv_data_tls_handshake;
        static socket_io_func _send_data_tls_handshake;
        static socket_io_func _recv_data_dummy;

        void conn_server(int, int);

        /** Terminate the connection (from the worker thread). */
        void worker_terminate();
        /** Terminate the connection (from the dispatcher thread). */
        void disp_terminate();

        public:
        Conn(): worker(nullptr), ready_send(false),
            send_data_func(nullptr), recv_data_func(nullptr),
            tls(nullptr), peer_cert(nullptr) {}
        Conn(const Conn &) = delete;
        Conn(Conn &&other) = delete;
    
        virtual ~Conn() {
            SALTICIDAE_LOG_INFO("destroyed %s", std::string(*this).c_str());
        }

        /** Get the handle to itself. */
        conn_t self() {
            mutex_lg_t _(ref_mlock);
            return self_ref;
        }

        void release_self() {
            mutex_lg_t _(ref_mlock);
            self_ref = nullptr;
        }

        operator std::string() const;
        const NetAddr &get_addr() const { return addr; }
        const X509 *get_peer_cert() const { return peer_cert.get(); }
        ConnMode get_mode() const { return mode; }
        ConnPool *get_pool() const { return cpool; }
        MPSCWriteBuffer &get_send_buffer() { return send_buffer; }

        /** Write data to the connection (non-blocking). The data will be sent
         * whenever I/O is available. */
        bool write(bytearray_t &&data) {
            return send_buffer.push(std::move(data), !cpool->queue_capacity);
        }

        protected:
        /** Close the IO and clear all on-going or planned events. Remove the
         * connection from a Worker. */
        virtual void stop();
        /** Called when new data is available. */
        virtual void on_read() {}
        /** Called when the underlying connection is established. */
        virtual void on_setup() {}
        /** Called when the underlying connection breaks. */
        virtual void on_teardown() {}
    };

    protected:
    /** event context passed to the constructor; user_tcall is bound to it */
    EventContext ec;
    /** event context of workers[0], which acts as the dispatcher */
    EventContext disp_ec;
    /** thread-call channel of the dispatcher (workers[0]) */
    ThreadCall* disp_tcall;
    /** Should be implemented by derived class to return a new Conn object. */
    virtual Conn *create_conn() = 0;
    using worker_error_callback_t = std::function<void(const std::exception_ptr err)>;
    /** fatal-error handler installed on the dispatcher worker */
    worker_error_callback_t disp_error_cb;
    /** fatal-error handler installed on the other workers; it forwards to
     * disp_error_cb through disp_tcall */
    worker_error_callback_t worker_error_cb;


    private:
    /* the following values are copied from Config at construction */
    const int max_listen_backlog;
    const double conn_server_timeout;
    const size_t seg_buff_size;
    const size_t queue_capacity;
    const bool enable_tls;
    /** created in the constructor only when enable_tls is set */
    tls_context_t tls_ctx;

    /* owned by user loop */
    protected:
    BoxObj<ThreadCall> user_tcall;

    private:
    /** user callback registered through reg_conn_handler() */
    conn_callback_t conn_cb;
    /** user callback registered through reg_error_handler() */
    error_callback_t error_cb;

    /* owned by the dispatcher */
    FdEvent ev_listen;
    /** connections keyed by an int (presumably the socket fd -- confirm
     * against add_conn()) */
    std::unordered_map<int, conn_t> pool;
    int listen_fd;  /**< for accepting new network connections */

    /** Report a connection status change to the user.
     *
     * The user's connection callback (if any) runs on the user thread. When
     * TLS is enabled and the connection has just come up, its verdict is then
     * posted to the connection's worker: an accepted connection switches to
     * the TLS receive routine, a rejected one is terminated.
     * @param conn the connection whose status changed.
     * @param connected the new status passed on to the callback. */
    void update_conn(const conn_t &conn, bool connected) {
        user_tcall->async_call([this, conn, connected](ThreadCall::Handle &) {
            /* no registered callback counts as acceptance */
            const bool accepted = conn_cb ? conn_cb(conn, connected) : true;
            if (!enable_tls || !connected) return;
            conn->worker->get_tcall()->async_call([conn, accepted](ThreadCall::Handle &) {
                if (!accepted)
                {
                    conn->worker_terminate();
                    return;
                }
                conn->recv_data_func = Conn::_recv_data_tls;
            });
        });
    }

    /** A thread that runs its own event loop and serves the connections fed
     * to it. Work is handed to the thread through its ThreadCall channel,
     * which is bound to the worker's own EventContext. */
    class Worker {
        EventContext ec;
        ThreadCall tcall;
        std::thread handle;
        /** set for the worker that also acts as the dispatcher */
        bool disp_flag;
        /** number of connections currently fed to this worker */
        std::atomic<size_t> nconn;
        /** invoked when feed() catches an exception while setting up a
         * connection */
        ConnPool::worker_error_callback_t on_fatal_error;

        public:
        Worker(): tcall(ec), disp_flag(false), nconn(0) {}

        /** Install the handler used to report fatal errors. */
        void set_error_callback(ConnPool::worker_error_callback_t _on_error) {
            on_fatal_error = std::move(_on_error);
        }

        /** Invoke the installed fatal-error handler with `err`. */
        void error_callback(const std::exception_ptr err) const {
            on_fatal_error(err);
        }

        /* the following functions are called by the dispatcher */
        /** Spawn the thread, which runs this worker's event loop. */
        void start() {
            handle = std::thread([this]() { ec.dispatch(); });
        }

        /** Hand a connection over to this worker. All the setup below runs
         * asynchronously through tcall.
         * @param conn the connection to serve.
         * @param client_fd the socket descriptor to watch. */
        void feed(const conn_t &conn, int client_fd) {
            /* the caller should finalize all the preparation */
            tcall.async_call([this, conn, client_fd](ThreadCall::Handle &) {
                try {
                    /* the connection may have died before this closure ran */
                    if (conn->mode == Conn::ConnMode::DEAD)
                    {
                        SALTICIDAE_LOG_INFO("worker %x discarding dead connection",
                            std::this_thread::get_id());
                        return;
                    }
                    auto cpool = conn->cpool;
                    if (cpool->enable_tls)
                    {
                        /* start in handshake mode; update_conn() is not
                         * called on this path */
                        conn->tls = new TLS(
                                cpool->tls_ctx, client_fd,
                                conn->mode == Conn::ConnMode::PASSIVE);
                        conn->send_data_func = Conn::_send_data_tls_handshake;
                        conn->recv_data_func = Conn::_recv_data_tls_handshake;
                    }
                    else
                    {
                        /* plain socket: usable right away, so notify the user */
                        conn->send_data_func = Conn::_send_data;
                        conn->recv_data_func = Conn::_recv_data;
                        cpool->update_conn(conn, true);
                    }
                    assert(conn->fd != -1);
                    SALTICIDAE_LOG_INFO("worker %x got %s",
                            std::this_thread::get_id(),
                            std::string(*conn).c_str());
                    /* handler on the send queue: when the connection is
                     * ready to send, re-arm the socket event for both
                     * directions and send immediately */
                    conn->get_send_buffer()
                            .get_queue()
                            .reg_handler(this->ec, [conn, client_fd]
                                        (MPSCWriteBuffer::queue_t &) {
                        if (conn->ready_send)
                        {
                            conn->ev_socket.del();
                            conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE);
                            conn->send_data_func(conn, client_fd, FdEvent::WRITE);
                        }
                        return false;
                    });
                    /* socket event: READ goes to the receive routine,
                     * anything else to the send routine; an exception is
                     * reported as recoverable and ends the connection */
                    conn->ev_socket = FdEvent(ec, client_fd, [this, conn=conn](int fd, int what) {
                        try {
                            if (what & FdEvent::READ)
                                conn->recv_data_func(conn, fd, what);
                            else
                                conn->send_data_func(conn, fd, what);
                        } catch (...) {
                            conn->cpool->recoverable_error(std::current_exception());
                            conn->worker_terminate();
                        }
                    });
                    conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE);
                    nconn++;
                } catch (...) { on_fatal_error(std::current_exception()); }
            });
        }

        /** Account for one connection leaving this worker. */
        void unfeed() { nconn--; }

        /** Ask the event loop to stop; the request is posted through tcall. */
        void stop() {
            tcall.async_call([this](ThreadCall::Handle &) { ec.stop(); });
        }

        std::thread &get_handle() { return handle; }
        const EventContext &get_ec() { return ec; }
        ThreadCall *get_tcall() { return &tcall; }
        void set_dispatcher() { disp_flag = true; }
        bool is_dispatcher() const { return disp_flag; }
        size_t get_nconn() { return nconn; }
        void stop_tcall() { tcall.stop(); }
    };

    /* related to workers */
    /** number of workers, taken from Config (always at least 1) */
    size_t nworker;
    /** worker array; workers[0] is the dispatcher */
    salticidae::BoxObj<Worker[]> workers;
    /** 0 after construction, 1 after start(), 2 after stop_workers() */
    int system_state;

    /* the following are defined outside this header */
    void accept_client(int, int);
    conn_t add_conn(const conn_t &conn);
    void del_conn(const conn_t &conn);

    protected:
    /** Implementation behind connect(); connect() always invokes it through
     * disp_tcall. */
    conn_t _connect(const NetAddr &addr);
    /** Implementation behind listen(); listen() always invokes it through
     * disp_tcall. */
    void _listen(NetAddr listen_addr);
    void _listen(NetAddr listen_addr);
    /** Report a non-fatal error to the user. The registered error handler,
     * if there is one, is invoked on the user thread with its second
     * argument set to false.
     * @param err the exception to report. */
    void recoverable_error(const std::exception_ptr err) const {
        user_tcall->async_call([this, err](ThreadCall::Handle &) {
            if (!error_cb) return;
            error_cb(err, false);
        });
    }

    private:

    /** Pick the worker currently serving the fewest connections; on a tie
     * the worker with the lowest index wins.
     * @return a reference to the chosen worker. */
    Worker &select_worker() {
        size_t chosen = 0;
        size_t min_load = workers[chosen].get_nconn();
        for (size_t i = 0; i < nworker; ++i)
        {
            const size_t load = workers[i].get_nconn();
            if (load >= min_load) continue;
            min_load = load;
            chosen = i;
        }
        return workers[chosen];
    }

    public:

    /** Construction-time settings for a ConnPool. Every setter returns
     * *this so that calls can be chained. */
    class Config {
        friend ConnPool;
        int _max_listen_backlog;
        double _conn_server_timeout;
        size_t _seg_buff_size;
        size_t _nworker;
        size_t _queue_capacity;
        bool _enable_tls;
        std::string _tls_cert_file;
        std::string _tls_key_file;
        RcObj<X509> _tls_cert;
        RcObj<PKey> _tls_key;
        bool _tls_skip_ca_check;
        SSL_verify_cb _tls_verify_callback;

        public:
        /** Defaults: backlog 10, connect timeout 2, segment buffer 4096,
         * one worker, queue capacity 0, TLS disabled, no certificate or key,
         * CA check skipped, no verify callback. */
        Config():
            _max_listen_backlog(10),
            _conn_server_timeout(2),
            _seg_buff_size(4096),
            _nworker(1),
            _queue_capacity(0),
            _enable_tls(false),
            _tls_cert_file(""),
            _tls_key_file(""),
            _tls_cert(nullptr),
            _tls_key(nullptr),
            _tls_skip_ca_check(true),
            _tls_verify_callback(nullptr) {}

        Config &max_listen_backlog(int x) {
            _max_listen_backlog = x;
            return *this;
        }

        Config &conn_server_timeout(double x) {
            _conn_server_timeout = x;
            return *this;
        }

        Config &seg_buff_size(size_t x) {
            _seg_buff_size = x;
            return *this;
        }

        /** Set the number of workers; values below 1 are raised to 1. */
        Config &nworker(size_t x) {
            _nworker = std::max((size_t)1, x);
            return *this;
        }

        Config &queue_capacity(size_t x) {
            _queue_capacity = x;
            return *this;
        }

        Config &enable_tls(bool x) {
            _enable_tls = x;
            return *this;
        }

        /** Certificate file, used only when no certificate object is set. */
        Config &tls_cert_file(const std::string &x) {
            _tls_cert_file = x;
            return *this;
        }

        /** Private-key file, used only when no key object is set. */
        Config &tls_key_file(const std::string &x) {
            _tls_key_file = x;
            return *this;
        }

        /** Certificate object; takes precedence over the certificate file. */
        Config &tls_cert(X509 *x) {
            _tls_cert = x;
            return *this;
        }

        /** Private-key object; takes precedence over the key file. */
        Config &tls_key(PKey *x) {
            _tls_key = x;
            return *this;
        }

        /** Choose whether the CA check is skipped. The parameter is a plain
         * bool: the previous `bool *` signature stored the pointer's
         * non-nullness instead of the requested value, so the option could
         * not be switched off with an ordinary `false`. A non-null pointer
         * argument still converts to true, as before. */
        Config &tls_skip_ca_check(bool x) {
            _tls_skip_ca_check = x;
            return *this;
        }

        Config &tls_verify_callback(SSL_verify_cb x) {
            _tls_verify_callback = x;
            return *this;
        }
    };

    /** Build the pool from `config`. No thread is started here; see start().
     * @param ec the event context that the user-side ThreadCall is bound to.
     * @param config the settings to copy.
     * @throws SalticidaeError(SALTI_ERROR_TLS_KEY_NOT_MATCH) when TLS is
     * enabled and check_privkey() fails. */
    ConnPool(const EventContext &ec, const Config &config):
            ec(ec),
            max_listen_backlog(config._max_listen_backlog),
            conn_server_timeout(config._conn_server_timeout),
            seg_buff_size(config._seg_buff_size),
            queue_capacity(config._queue_capacity),
            enable_tls(config._enable_tls),
            tls_ctx(nullptr),
            listen_fd(-1),
            nworker(config._nworker),
            system_state(0) {
        if (enable_tls)
        {
            /* an in-memory certificate/key takes precedence over a file */
            tls_ctx = new TLSContext();
            if (config._tls_cert)
                tls_ctx->use_cert(*config._tls_cert);
            else
                tls_ctx->use_cert_file(config._tls_cert_file);
            if (config._tls_key)
                tls_ctx->use_privkey(*config._tls_key);
            else
                tls_ctx->use_privkey_file(config._tls_key_file);
            tls_ctx->set_verify(config._tls_skip_ca_check, config._tls_verify_callback);
            if (!tls_ctx->check_privkey())
                throw SalticidaeError(SALTI_ERROR_TLS_KEY_NOT_MATCH);
        }
        /* process-wide: SIGPIPE is ignored from here on */
        signal(SIGPIPE, SIG_IGN);
        workers = new Worker[nworker];
        user_tcall = new ThreadCall(ec);
        /* workers[0] doubles as the dispatcher */
        disp_ec = workers[0].get_ec();
        disp_tcall = workers[0].get_tcall();
        workers[0].set_dispatcher();
        /* fatal error seen by the dispatcher: on the user thread, stop all
         * workers and rethrow the exception; meanwhile stop the dispatcher's
         * own loop and its thread-call channel */
        disp_error_cb = [this](const std::exception_ptr err) {
            user_tcall->async_call([this, err](ThreadCall::Handle &) {
                stop_workers();
                std::rethrow_exception(err);
                //if (error_cb) error_cb(err, true);
            });
            disp_ec.stop();
            workers[0].stop_tcall();
        };

        worker_error_cb = [this](const std::exception_ptr err) {
            disp_tcall->async_call([this, err](ThreadCall::Handle &) {
                // forward to the dispatcher
                disp_error_cb(err);
            });
        };
        /* the dispatcher reports directly, every other worker goes through
         * the dispatcher */
        for (size_t i = 0; i < nworker; i++)
        {
            auto &worker = workers[i];
            if (worker.is_dispatcher())
                worker.set_error_callback(disp_error_cb);
            else
                worker.set_error_callback(worker_error_cb);
        }
    }

    /** Stop the pool (see stop()) on destruction. The destructor is virtual
     * because ConnPool is a polymorphic base class (create_conn() is pure
     * virtual): destroying a derived pool through a ConnPool pointer must run
     * the derived destructor as well. */
    virtual ~ConnPool() { stop(); }

    ConnPool(const ConnPool &) = delete;
    ConnPool(ConnPool &&) = delete;

    /** Launch the thread of every worker. Does nothing unless the pool is
     * still in its initial state (system_state == 0). */
    void start() {
        if (system_state != 0) return;
        SALTICIDAE_LOG_INFO("starting all threads...");
        size_t i = 0;
        while (i < nworker)
        {
            workers[i].start();
            ++i;
        }
        system_state = 1;
    }

    void stop_workers() {
        if (system_state != 1) return;
        system_state = 2;
        SALTICIDAE_LOG_INFO("stopping all threads...");
        /* stop the dispatcher */
        workers[0].stop();
        workers[0].get_handle().join();
        /* stop all workers */
        for (size_t i = 1; i < nworker; i++)
            workers[i].stop();
        /* join all worker threads */
        for (size_t i = 1; i < nworker; i++)
            workers[i].get_handle().join();
        for (auto it: pool)
        {
            conn_t conn = it.second;
            conn->stop();
            conn->self_ref = nullptr;
            ::close(conn->fd);
        }
    }

    /** Stop all workers and close the listening socket if one is open. */
    void stop() {
        stop_workers();
        if (listen_fd == -1) return;
        ::close(listen_fd);
        listen_fd = -1;
    }

    /** Actively connect to remote addr. The connection attempt itself always
     * runs through the dispatcher's thread-call channel.
     * @param addr the remote address.
     * @param blocking when true, wait for the dispatcher and return the
     * connection handle, rethrowing any exception raised by the attempt;
     * when false, return nullptr immediately and route a failure to the
     * dispatcher's fatal-error handler. */
    conn_t connect(const NetAddr &addr, bool blocking = true) {
        if (!blocking)
        {
            disp_tcall->async_call([this, addr](ThreadCall::Handle &) {
                try {
                    _connect(addr);
                } catch (...) {
                    disp_error_cb(std::current_exception());
                }
            });
            return nullptr;
        }
        using result_t = std::pair<conn_t, std::exception_ptr>;
        /* the result is copied out within the same full expression, while
         * the object returned by call() is still alive */
        auto outcome = *(static_cast<result_t *>(
                    disp_tcall->call([this, addr](ThreadCall::Handle &h) {
            conn_t conn;
            std::exception_ptr err = nullptr;
            try {
                conn = _connect(addr);
            } catch (...) {
                err = std::current_exception();
            }
            h.set_result(std::make_pair(std::move(conn), err));
        }).get()));
        if (outcome.second) std::rethrow_exception(outcome.second);
        return std::move(outcome.first);
    }

    /** Listen for passive connections (connection initiated from remote).
     * Does not need to be called if do not want to accept any passive
     * connections. The call waits for the dispatcher to finish and rethrows
     * any exception raised while setting up the listener.
     * @param listen_addr the local address to listen on. */
    void listen(NetAddr listen_addr) {
        /* the result is copied out within the same full expression, while
         * the object returned by call() is still alive */
        auto failure = *(static_cast<std::exception_ptr *>(
                disp_tcall->call([this, listen_addr](ThreadCall::Handle &h) {
            std::exception_ptr caught = nullptr;
            try {
                _listen(listen_addr);
            } catch (...) {
                caught = std::current_exception();
            }
            h.set_result(caught);
        }).get()));
        if (failure) std::rethrow_exception(failure);
    }

    /** Register the callback invoked when a connection's status changes.
     * The callable is taken by value and moved into place, so registering a
     * handler with heavy captures costs no extra copy.
     * @param cb a callable compatible with conn_callback_t. */
    template<typename Func>
    void reg_conn_handler(Func cb) { conn_cb = std::move(cb); }

    /** Register the callback invoked when an error is reported to the user.
     * The callable is taken by value and moved into place.
     * @param cb a callable compatible with error_callback_t. */
    template<typename Func>
    void reg_error_handler(Func cb) { error_cb = std::move(cb); }

    /** Ask the dispatcher to terminate a connection. The request is posted
     * asynchronously; an exception raised while terminating is handed to the
     * dispatcher's fatal-error handler.
     * @param conn the connection to terminate. */
    void terminate(const conn_t &conn) {
        disp_tcall->async_call([this, conn](ThreadCall::Handle &) {
            try
            {
                conn->disp_terminate();
            }
            catch (...)
            {
                disp_error_cb(std::current_exception());
            }
        });
    }
};

}

#endif

#endif