diff options
Diffstat (limited to 'include/salticidae/conn.h')
-rw-r--r-- | include/salticidae/conn.h | 392 |
1 files changed, 259 insertions, 133 deletions
diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h index f290e3d..26d19fe 100644 --- a/include/salticidae/conn.h +++ b/include/salticidae/conn.h @@ -36,6 +36,9 @@ #include <list> #include <algorithm> #include <exception> +#include <mutex> +#include <thread> +#include <fcntl.h> #include "salticidae/type.h" #include "salticidae/ref.h" @@ -43,110 +46,10 @@ #include "salticidae/util.h" #include "salticidae/netaddr.h" #include "salticidae/msg.h" +#include "salticidae/buffer.h" namespace salticidae { -class SegBuffer { - struct buffer_entry_t { - bytearray_t data; - bytearray_t::iterator offset; - buffer_entry_t(bytearray_t &&_data): data(std::move(_data)) { - offset = data.begin(); - } - - buffer_entry_t(buffer_entry_t &&other) { - size_t _offset = other.offset - other.data.begin(); - data = std::move(other.data); - offset = data.begin() + _offset; - } - - buffer_entry_t(const buffer_entry_t &other): data(other.data) { - offset = data.begin() + (other.offset - other.data.begin()); - } - - size_t length() const { return data.end() - offset; } - }; - - std::list<buffer_entry_t> buffer; - size_t _size; - - public: - SegBuffer(): _size(0) {} - ~SegBuffer() { clear(); } - - void swap(SegBuffer &other) { - std::swap(buffer, other.buffer); - std::swap(_size, other._size); - } - - SegBuffer(const SegBuffer &other): - buffer(other.buffer), _size(other._size) {} - - SegBuffer(SegBuffer &&other): - buffer(std::move(other.buffer)), _size(other._size) { - other._size = 0; - } - - SegBuffer &operator=(SegBuffer &&other) { - if (this != &other) - { - SegBuffer tmp(std::move(other)); - tmp.swap(*this); - } - return *this; - } - - SegBuffer &operator=(const SegBuffer &other) { - if (this != &other) - { - SegBuffer tmp(other); - tmp.swap(*this); - } - return *this; - } - - void rewind(bytearray_t &&data) { - _size += data.size(); - buffer.push_front(buffer_entry_t(std::move(data))); - } - - void push(bytearray_t &&data) { - _size += data.size(); - buffer.push_back(buffer_entry_t(std::move(data))); - } - - bytearray_t move_pop() { - auto res = std::move(buffer.front().data); - buffer.pop_front(); - return std::move(res); - } - - bytearray_t pop(size_t len) { - bytearray_t res; - auto i = buffer.begin(); - while (len && i != buffer.end()) - { - size_t copy_len = std::min(i->length(), len); - res.insert(res.end(), i->offset, i->offset + copy_len); - i->offset += copy_len; - len -= copy_len; - if (i->offset == i->data.end()) - i++; - } - buffer.erase(buffer.begin(), i); - _size -= res.size(); - return std::move(res); - } - - size_t size() const { return _size; } - bool empty() const { return buffer.empty(); } - - void clear() { - buffer.clear(); - _size = 0; - } -}; - struct ConnPoolError: public SalticidaeError { using SalticidaeError::SalticidaeError; }; @@ -156,7 +59,7 @@ class ConnPool { public: class Conn; /** The handle to a bi-directional connection. */ - using conn_t = RcObj<Conn>; + using conn_t = ArcObj<Conn>; /** The type of callback invoked when connection status is changed. */ using conn_callback_t = std::function<void(Conn &)>; @@ -177,7 +80,8 @@ class ConnPool { ConnMode mode; NetAddr addr; - SegBuffer send_buffer; + // TODO: send_buffer should be a thread-safe mpsc queue + MPSCWriteBuffer send_buffer; SegBuffer recv_buffer; Event ev_read; @@ -190,6 +94,9 @@ class ConnPool { void send_data(evutil_socket_t, short); void conn_server(evutil_socket_t, short); + /** Terminate the connection. */ + void terminate(); + public: Conn(): ready_send(false) {} Conn(const Conn &) = delete; @@ -206,7 +113,8 @@ class ConnPool { const NetAddr &get_addr() const { return addr; } ConnMode get_mode() const { return mode; } ConnPool *get_pool() const { return cpool; } - SegBuffer &read() { return recv_buffer; } + SegBuffer &get_recv_buffer() { return recv_buffer; } + MPSCWriteBuffer &get_send_buffer() { return send_buffer; } /** Set the buffer size used for send/receive data. */ void set_seg_buff_size(size_t size) { seg_buff_size = size; } @@ -214,17 +122,12 @@ class ConnPool { * whenever I/O is available. */ void write(bytearray_t &&data) { send_buffer.push(std::move(data)); - if (ready_send) - send_data(fd, EV_WRITE); - } - - /** Move the send buffer from the other (old) connection. */ - void move_send_buffer(conn_t other) { - send_buffer = std::move(other->send_buffer); } - /** Terminate the connection. */ - void terminate(); + ///** Move the send buffer from the other (old) connection. */ + //void move_send_buffer(conn_t other) { + // send_buffer = std::move(other->send_buffer); + //} protected: /** Close the IO and clear all on-going or planned events. */ @@ -238,35 +141,196 @@ class ConnPool { } /** Called when new data is available. */ - virtual void on_read() { - if (cpool->read_cb) cpool->read_cb(*this); - } + virtual void on_read() {} /** Called when the underlying connection is established. */ virtual void on_setup() { - if (cpool->conn_cb) cpool->conn_cb(*this); + cpool->update_conn(self()); } /** Called when the underlying connection breaks. */ virtual void on_teardown() { - if (cpool->conn_cb) cpool->conn_cb(*this); + cpool->update_conn(self()); } }; - + private: - int max_listen_backlog; - double conn_server_timeout; - size_t seg_buff_size; - conn_callback_t read_cb; - conn_callback_t conn_cb; + const int max_listen_backlog; + const double conn_server_timeout; + const size_t seg_buff_size; + + /* owned by user loop */ + int mlisten_fd[2]; /**< for connection events sent to the user loop */ + Event ev_mlisten; + conn_callback_t conn_cb; + /* owned by the dispatcher */ std::unordered_map<int, conn_t> pool; - int listen_fd; + int listen_fd; /**< for accepting new network connections */ + int dlisten_fd[2]; /**< for control command sent to the dispatcher */ Event ev_listen; + Event ev_dlisten; + std::mutex cp_mlock; + + void update_conn(const conn_t &conn) { + auto ptr = new conn_t(conn); + write(mlisten_fd[1], &ptr, sizeof(ptr)); + } + + struct Worker; + class WorkerFeed; + + class WorkerCmd { + public: + virtual ~WorkerCmd() = default; + virtual void exec(Worker *worker) = 0; + }; + + class Worker { + EventContext ec; + Event ev_ctl; + int ctl_fd[2]; /**< for control messages from dispatcher */ + std::thread handle; + + public: + Worker() { + if (pipe2(ctl_fd, O_NONBLOCK)) + throw ConnPoolError(std::string("failed to create worker pipe")); + ev_ctl = Event(ec, ctl_fd[0], EV_READ | EV_PERSIST, [this](int fd, short) { + WorkerCmd *dcmd; + read(fd, &dcmd, sizeof(dcmd)); + dcmd->exec(this); + delete dcmd; + }); + ev_ctl.add(); + } + + ~Worker() { + close(ctl_fd[0]); + close(ctl_fd[1]); + } + + /* the following functions are called by the dispatcher */ + void start() { + handle = std::thread([this]() { ec.dispatch(); }); + } + + void feed(const conn_t &conn, int client_fd) { + auto dcmd = new WorkerFeed(conn, client_fd); + write(ctl_fd[1], &dcmd, sizeof(dcmd)); + } + + void stop() { + auto dcmd = new WorkerStop(); + write(ctl_fd[1], &dcmd, sizeof(dcmd)); + } + + std::thread &get_handle() { return handle; } + const EventContext &get_ec() { return ec; } + }; + + class WorkerFeed: public WorkerCmd { + conn_t conn; + int client_fd; + + public: + WorkerFeed(const conn_t &conn, int client_fd): + conn(conn), client_fd(client_fd) {} + void exec(Worker *worker) override { + SALTICIDAE_LOG_INFO("worker %x got %s", + std::this_thread::get_id(), + std::string(*conn).c_str()); + auto &ec = worker->get_ec(); + conn->get_send_buffer() + .get_queue() + .reg_handler(ec, [conn=this->conn, + client_fd=this->client_fd](MPSCWriteBuffer::queue_t &) { + if (conn->ready_send) + conn->send_data(client_fd, EV_WRITE); + return false; + }); + auto conn_ptr = conn.get(); + conn->ev_read = Event(ec, client_fd, EV_READ, + std::bind(&Conn::recv_data, conn_ptr, _1, _2)); + conn->ev_write = Event(ec, client_fd, EV_WRITE, + std::bind(&Conn::send_data, conn_ptr, _1, _2)); + conn->ev_read.add(); + conn->ev_write.add(); + } + }; + + class WorkerStop: public WorkerCmd { + public: + void exec(Worker *worker) override { worker->get_ec().stop(); } + }; + + /* related to workers */ + size_t nworker; + salticidae::BoxObj<Worker[]> workers; void accept_client(evutil_socket_t, short); - conn_t add_conn(conn_t conn); + conn_t add_conn(const conn_t &conn); + conn_t _connect(const NetAddr &addr); + void _listen(NetAddr listen_addr); + void _post_terminate(int fd); + + class DispatchCmd { + public: + virtual ~DispatchCmd() = default; + virtual void exec(ConnPool *cpool) = 0; + }; + + // TODO: the following two are untested + class DspListen: public DispatchCmd { + const NetAddr addr; + public: + DspListen(const NetAddr &addr): addr(addr) {} + void exec(ConnPool *cpool) override { + cpool->_listen(addr); + } + }; + + class DspConnect: public DispatchCmd { + const NetAddr addr; + public: + DspConnect(const NetAddr &addr): addr(addr) {} + void exec(ConnPool *cpool) override { + cpool->update_conn(cpool->_connect(addr)); + } + }; + + class DspPostTerm: public DispatchCmd { + int fd; + public: + DspPostTerm(int fd): fd(fd) {} + void exec(ConnPool *cpool) override { + cpool->_post_terminate(fd); + } + }; + + class DspMulticast: public DispatchCmd { + std::vector<conn_t> receivers; + bytearray_t data; + public: + DspMulticast(std::vector<conn_t> &&receivers, bytearray_t &&data): + receivers(std::move(receivers)), + data(std::move(data)) {} + void exec(ConnPool *) override { + for (auto &r: receivers) r->write(bytearray_t(data)); + } + }; + + void post_terminate(int fd) { + auto dcmd = new DspPostTerm(fd); + write(dlisten_fd[1], &dcmd, sizeof(dcmd)); + } + + Worker &select_worker() { + return workers[1]; + } protected: EventContext ec; + EventContext dispatcher_ec; + std::mutex dsp_ec_mlock; /** Should be implemented by derived class to return a new Conn object. */ virtual Conn *create_conn() = 0; @@ -274,29 +338,91 @@ class ConnPool { ConnPool(const EventContext &ec, int max_listen_backlog = 10, double conn_server_timeout = 2, - size_t seg_buff_size = 4096): - max_listen_backlog(max_listen_backlog), - conn_server_timeout(conn_server_timeout), - seg_buff_size(seg_buff_size), - ec(ec) {} + size_t seg_buff_size = 4096, + size_t nworker = 2): + max_listen_backlog(max_listen_backlog), + conn_server_timeout(conn_server_timeout), + seg_buff_size(seg_buff_size), + listen_fd(-1), + nworker(std::min((size_t)1, nworker)), + ec(ec) { + if (pipe2(mlisten_fd, O_NONBLOCK)) + throw ConnPoolError(std::string("failed to create main pipe")); + if (pipe2(dlisten_fd, O_NONBLOCK)) + throw ConnPoolError(std::string("failed to create dispatcher pipe")); + + ev_mlisten = Event(ec, mlisten_fd[0], EV_READ | EV_PERSIST, [this](int fd, short) { + conn_t *conn_ptr; + read(fd, &conn_ptr, sizeof(conn_ptr)); + if (conn_cb) + conn_cb(**conn_ptr); + delete conn_ptr; + }); + ev_mlisten.add(); + + workers = new Worker[nworker]; + dispatcher_ec = workers[0].get_ec(); + + ev_dlisten = Event(dispatcher_ec, dlisten_fd[0], EV_READ | EV_PERSIST, [this](int fd, short) { + DispatchCmd *dcmd; + read(fd, &dcmd, sizeof(dcmd)); + dcmd->exec(this); + delete dcmd; + }); + ev_dlisten.add(); + + SALTICIDAE_LOG_INFO("starting all threads..."); + for (size_t i = 0; i < nworker; i++) + workers[i].start(); + } ~ConnPool() { + /* stop all workers */ + for (size_t i = 0; i < nworker; i++) + workers[i].stop(); + /* join all worker threads */ + for (size_t i = 0; i < nworker; i++) + workers[i].get_handle().join(); for (auto it: pool) { conn_t conn = it.second; conn->on_close(); } + if (listen_fd != -1) close(listen_fd); + for (int i = 0; i < 2; i++) + { + close(mlisten_fd[i]); + close(dlisten_fd[i]); + } } ConnPool(const ConnPool &) = delete; ConnPool(ConnPool &&) = delete; /** Actively connect to remote addr. */ - conn_t connect(const NetAddr &addr); + conn_t connect(const NetAddr &addr, bool blocking = true) { + if (blocking) + return _connect(addr); + else + { + auto dcmd = new DspConnect(addr); + write(dlisten_fd[1], &dcmd, sizeof(dcmd)); + return nullptr; + } + } + /** Listen for passive connections (connection initiated from remote). * Does not need to be called if do not want to accept any passive * connections. */ - void listen(NetAddr listen_addr); + void listen(NetAddr listen_addr, bool blocking = true) { + if (blocking) + _listen(listen_addr); + else + { + auto dcmd = new DspListen(listen_addr); + write(dlisten_fd[1], &dcmd, sizeof(dcmd)); + } + } template<typename Func> void reg_conn_handler(Func cb) { conn_cb = cb; } |