diff options
Diffstat (limited to 'include')
-rw-r--r-- | include/salticidae/conn.h | 63 | ||||
-rw-r--r-- | include/salticidae/event.h | 5 | ||||
-rw-r--r-- | include/salticidae/network.h | 79 |
3 files changed, 74 insertions, 73 deletions
diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h index ceec176..9e2408f 100644 --- a/include/salticidae/conn.h +++ b/include/salticidae/conn.h @@ -59,6 +59,7 @@ class ConnPool { using conn_t = ArcObj<Conn>; /** The type of callback invoked when connection status is changed. */ using conn_callback_t = std::function<bool(const conn_t &, bool)>; + /** The type of callback invoked when an error occured (during async execution). */ using error_callback_t = std::function<void(const std::exception_ptr, bool)>; /** Abstraction for a bi-directional connection. */ class Conn { @@ -71,9 +72,8 @@ class ConnPool { }; protected: + std::atomic<bool> terminated; size_t seg_buff_size; - conn_t self_ref; - std::mutex ref_mlock; int fd; Worker *worker; ConnPool *cpool; @@ -85,7 +85,6 @@ class ConnPool { TimedFdEvent ev_connect; FdEvent ev_socket; - TimerEvent ev_send_wait; /** does not need to wait if true */ bool ready_send; @@ -104,15 +103,8 @@ class ConnPool { static socket_io_func _send_data_tls_handshake; static socket_io_func _recv_data_dummy; - void conn_server(int, int); - - /** Terminate the connection (from the worker thread). */ - void worker_terminate(); - /** Terminate the connection (from the dispatcher thread). */ - void disp_terminate(); - public: - Conn(): worker(nullptr), ready_send(false), + Conn(): terminated(false), worker(nullptr), ready_send(false), send_data_func(nullptr), recv_data_func(nullptr), tls(nullptr), peer_cert(nullptr) {} Conn(const Conn &) = delete; @@ -122,15 +114,12 @@ class ConnPool { SALTICIDAE_LOG_INFO("destroyed %s", std::string(*this).c_str()); } - /** Get the handle to itself. */ - conn_t self() { - mutex_lg_t _(ref_mlock); - return self_ref; + bool is_terminated() { + return terminated.load(std::memory_order_acquire); } - void release_self() { - mutex_lg_t _(ref_mlock); - self_ref = nullptr; + bool set_terminated() { + return !terminated.exchange(true, std::memory_order_acq_rel); } operator std::string() const; @@ -150,12 +139,6 @@ class ConnPool { /** Close the IO and clear all on-going or planned events. Remove the * connection from a Worker. */ virtual void stop(); - /** Called when new data is available. */ - virtual void on_read() {} - /** Called when the underlying connection is established. */ - virtual void on_setup() {} - /** Called when the underlying connection breaks. */ - virtual void on_teardown() {} }; protected: @@ -168,6 +151,18 @@ class ConnPool { worker_error_callback_t disp_error_cb; worker_error_callback_t worker_error_cb; + /** Terminate the connection (from the worker thread). */ + void worker_terminate(const conn_t &conn); + /** Terminate the connection (from the dispatcher thread). */ + void disp_terminate(const conn_t &conn); + + void conn_server(const conn_t &conn, int, int); + /** Called when new data is available. */ + virtual void on_read(const conn_t &) {} + /** Called when the underlying connection is established. */ + virtual void on_setup(const conn_t &) {} + /** Called when the underlying connection breaks. */ + virtual void on_teardown(const conn_t &) {} private: const int max_listen_backlog; @@ -195,11 +190,11 @@ class ConnPool { bool ret = !conn_cb || conn_cb(conn, connected); if (enable_tls && connected) { - conn->worker->get_tcall()->async_call([conn, ret](ThreadCall::Handle &) { + conn->worker->get_tcall()->async_call([this, conn, ret](ThreadCall::Handle &) { if (ret) conn->recv_data_func = Conn::_recv_data_tls; else - conn->worker_terminate(); + worker_terminate(conn); }); } }); @@ -214,6 +209,7 @@ class ConnPool { ConnPool::worker_error_callback_t on_fatal_error; public: + Worker(): tcall(ec), disp_flag(false), nconn(0) {} void set_error_callback(ConnPool::worker_error_callback_t _on_error) { @@ -233,6 +229,7 @@ class ConnPool { /* the caller should finalize all the preparation */ tcall.async_call([this, conn, client_fd](ThreadCall::Handle &) { try { + conn->ev_connect.clear(); if (conn->mode == Conn::ConnMode::DEAD) { SALTICIDAE_LOG_INFO("worker %x discarding dead connection", @@ -258,6 +255,7 @@ class ConnPool { SALTICIDAE_LOG_INFO("worker %x got %s", std::this_thread::get_id(), std::string(*conn).c_str()); + assert(conn->worker == this); conn->get_send_buffer() .get_queue() .reg_handler(this->ec, [conn, client_fd] @@ -270,7 +268,7 @@ class ConnPool { } return false; }); - conn->ev_socket = FdEvent(ec, client_fd, [this, conn=conn](int fd, int what) { + conn->ev_socket = FdEvent(ec, client_fd, [this, conn](int fd, int what) { try { if (what & FdEvent::READ) conn->recv_data_func(conn, fd, what); @@ -278,7 +276,7 @@ class ConnPool { conn->send_data_func(conn, fd, what); } catch (...) { conn->cpool->recoverable_error(std::current_exception()); - conn->worker_terminate(); + conn->cpool->worker_terminate(conn); } }); conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE); @@ -310,6 +308,7 @@ class ConnPool { void accept_client(int, int); conn_t add_conn(const conn_t &conn); void del_conn(const conn_t &conn); + void release_conn(const conn_t &conn); protected: conn_t _connect(const NetAddr &addr); @@ -494,6 +493,7 @@ class ConnPool { ConnPool(ConnPool &&) = delete; void start() { + std::atomic_thread_fence(std::memory_order_acq_rel); if (system_state) return; SALTICIDAE_LOG_INFO("starting all threads..."); for (size_t i = 0; i < nworker; i++) @@ -516,10 +516,9 @@ class ConnPool { workers[i].get_handle().join(); for (auto it: pool) { - conn_t conn = it.second; + auto &conn = it.second; conn->stop(); - conn->self_ref = nullptr; - ::close(conn->fd); + release_conn(conn); } } @@ -589,7 +588,7 @@ class ConnPool { void terminate(const conn_t &conn) { disp_tcall->async_call([this, conn](ThreadCall::Handle &) { try { - conn->disp_terminate(); + disp_terminate(conn); } catch (...) { disp_error_cb(std::current_exception()); } diff --git a/include/salticidae/event.h b/include/salticidae/event.h index b243865..ad78a6e 100644 --- a/include/salticidae/event.h +++ b/include/salticidae/event.h @@ -308,6 +308,7 @@ class TimedFdEvent: public FdEvent, public TimerEvent { void clear() { TimerEvent::clear(); FdEvent::clear(); + callback = nullptr; } using FdEvent::set_callback; @@ -532,7 +533,7 @@ class MPSCQueueEventDriven: public MPSCQueue<T> { fd(eventfd(0, EFD_NONBLOCK)) { if (fd == -1) throw SalticidaeError(SALTI_ERROR_FD); } - ~MPSCQueueEventDriven() { close(fd); } + ~MPSCQueueEventDriven() { close(fd); unreg_handler(); } template<typename Func> void reg_handler(const EventContext &ec, Func &&func) { @@ -587,7 +588,7 @@ class MPMCQueueEventDriven: public MPMCQueue<T> { fd(eventfd(0, EFD_NONBLOCK)) { if (fd == -1) throw SalticidaeError(SALTI_ERROR_FD); } - ~MPMCQueueEventDriven() { close(fd); } + ~MPMCQueueEventDriven() { close(fd); unreg_handlers(); } // this function is *NOT* thread-safe template<typename Func> diff --git a/include/salticidae/network.h b/include/salticidae/network.h index 07c6ba5..20dc696 100644 --- a/include/salticidae/network.h +++ b/include/salticidae/network.h @@ -109,7 +109,6 @@ class MsgNetwork: public ConnPool { #endif protected: - void on_read() override; }; using conn_t = ArcObj<Conn>; @@ -127,6 +126,7 @@ class MsgNetwork: public ConnPool { protected: ConnPool::Conn *create_conn() override { return new Conn(); } + void on_read(const ConnPool::conn_t &) override; public: @@ -287,12 +287,9 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { protected: void stop() override { - ev_timeout.clear(); + ev_timeout.del(); MsgNet::Conn::stop(); } - - void on_setup() override; - void on_teardown() override; }; using conn_t = ArcObj<Conn>; @@ -326,7 +323,7 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { if (ev_ping_timer) ev_ping_timer.del(); } - void reset_conn(conn_t conn); + void reset_conn(const conn_t &conn); }; std::unordered_map<NetAddr, BoxObj<Peer>> id2peer; @@ -381,6 +378,8 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { virtual double gen_conn_timeout() { return gen_rand_timeout(retry_conn_delay); } + void on_setup(const ConnPool::conn_t &) override; + void on_teardown(const ConnPool::conn_t &) override; public: @@ -466,11 +465,13 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { /* this callback is run by a worker */ template<typename OpcodeType> -void MsgNetwork<OpcodeType>::Conn::on_read() { - ConnPool::Conn::on_read(); - auto &recv_buffer = this->recv_buffer; - auto mn = get_net(); - while (self_ref) +void MsgNetwork<OpcodeType>::on_read(const ConnPool::conn_t &_conn) { + ConnPool::on_read(_conn); + auto conn = static_pointer_cast<Conn>(_conn); + auto &recv_buffer = conn->recv_buffer; + auto &msg = conn->msg; + auto &msg_state = conn->msg_state; + while (true) //(!conn->is_terminated()) { if (msg_state == Conn::HEADER) { @@ -493,8 +494,7 @@ void MsgNetwork<OpcodeType>::Conn::on_read() { return; } #endif - auto conn = static_pointer_cast<Conn>(self()); - while (!mn->incoming_msgs.enqueue(std::make_pair(msg, conn), false)) + while (!incoming_msgs.enqueue(std::make_pair(msg, conn), false)) std::this_thread::yield(); } } @@ -550,46 +550,47 @@ void PeerNetwork<O, _, __>::tcall_reset_timeout(ConnPool::Worker *worker, /* begin: functions invoked by the dispatcher */ template<typename O, O _, O __> -void PeerNetwork<O, _, __>::Conn::on_setup() { - MsgNet::Conn::on_setup(); - auto pn = get_net(); - auto conn = static_pointer_cast<Conn>(this->self()); - auto worker = this->worker; +void PeerNetwork<O, _, __>::on_setup(const ConnPool::conn_t &_conn) { + MsgNet::on_setup(_conn); + auto conn = static_pointer_cast<Conn>(_conn); + auto worker = conn->worker; + auto &ev_timeout = conn->ev_timeout; assert(!ev_timeout); ev_timeout = TimerEvent(worker->get_ec(), [worker, conn](TimerEvent &) { try { SALTICIDAE_LOG_INFO("peer ping-pong timeout"); - conn->worker_terminate(); + conn->get_net()->worker_terminate(conn); } catch (...) { worker->error_callback(std::current_exception()); } }); /* the initial ping-pong to set up the connection */ - tcall_reset_timeout(worker, conn, pn->conn_timeout); - pn->send_msg(MsgPing(pn->listen_port), conn); + tcall_reset_timeout(worker, conn, conn_timeout); + send_msg(MsgPing(listen_port), conn); } template<typename O, O _, O __> -void PeerNetwork<O, _, __>::Conn::on_teardown() { - MsgNet::Conn::on_teardown(); - auto pn = get_net(); - auto p = pn->get_peer(peer_id); +void PeerNetwork<O, _, __>::on_teardown(const ConnPool::conn_t &_conn) { + MsgNet::on_teardown(_conn); + auto conn = static_pointer_cast<Conn>(_conn); + conn->ev_timeout.clear(); + const auto &peer_id = conn->peer_id; + auto p = get_peer(peer_id); if (!p) return; - if (this != p->conn.get()) return; + if (conn != p->conn) return; p->ev_ping_timer.del(); p->connected = false; //p->conn = nullptr; - SALTICIDAE_LOG_INFO("connection lost: %s", std::string(*this).c_str()); + SALTICIDAE_LOG_INFO("connection lost: %s", std::string(*conn).c_str()); // try to reconnect - p->ev_retry_timer = TimerEvent(pn->disp_ec, - [pn, peer_id = this->peer_id](TimerEvent &) { + p->ev_retry_timer = TimerEvent(this->disp_ec, [this, peer_id](TimerEvent &) { try { - pn->start_active_conn(peer_id); - } catch (...) { pn->disp_error_cb(std::current_exception()); } + start_active_conn(peer_id); + } catch (...) { this->disp_error_cb(std::current_exception()); } }); - p->ev_retry_timer.add(pn->gen_conn_timeout()); + p->ev_retry_timer.add(gen_conn_timeout()); } template<typename O, O _, O __> -void PeerNetwork<O, _, __>::Peer::reset_conn(conn_t new_conn) { +void PeerNetwork<O, _, __>::Peer::reset_conn(const conn_t &new_conn) { if (conn != new_conn) { if (conn) @@ -597,7 +598,8 @@ void PeerNetwork<O, _, __>::Peer::reset_conn(conn_t new_conn) { //SALTICIDAE_LOG_DEBUG("moving send buffer"); //new_conn->move_send_buffer(conn); SALTICIDAE_LOG_INFO("terminating old connection %s", std::string(*conn).c_str()); - conn->disp_terminate(); + auto net = conn->get_net(); + net->disp_terminate(conn); } addr = new_conn->get_addr(); conn = new_conn; @@ -656,7 +658,7 @@ bool PeerNetwork<O, _, __>::check_new_conn(const conn_t &conn, uint16_t port) { } else { - conn->disp_terminate(); + this->disp_terminate(conn); return true; } } @@ -665,7 +667,7 @@ bool PeerNetwork<O, _, __>::check_new_conn(const conn_t &conn, uint16_t port) { { if (conn != p->conn) { - conn->disp_terminate(); + this->disp_terminate(conn); return true; } return false; @@ -797,7 +799,7 @@ void PeerNetwork<O, _, __>::del_peer(const NetAddr &addr) { auto it = id2peer.find(addr); if (it == id2peer.end()) throw PeerNetworkError(SALTI_ERROR_PEER_NOT_EXIST); - it->second->conn->disp_terminate(); + this->disp_terminate(it->second->conn); id2peer.erase(it); } catch (const PeerNetworkError &) { this->recoverable_error(std::current_exception()); @@ -899,8 +901,7 @@ void ClientNetwork<OpcodeType>::Conn::on_setup() { auto cn = get_net(); cn->addr2conn.erase(addr); cn->addr2conn.insert( - std::make_pair(addr, - static_pointer_cast<Conn>(this->self()))); + std::make_pair(addr, static_pointer_cast<Conn>(this->self()))); } template<typename OpcodeType> |