From 3707146f42ccc066edf214cc77118f91b687e47b Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 21 Feb 2020 18:23:40 -0500 Subject: adjust on_setup and on_teardown and fix minor bugs --- include/salticidae/conn.h | 34 +++++++++---- include/salticidae/network.h | 115 ++++++++++++++++++++++++------------------- include/salticidae/util.h | 1 + src/conn.cpp | 25 +++++----- src/util.cpp | 1 + test/test_p2p_stress.cpp | 35 ++++++++++--- 6 files changed, 132 insertions(+), 79 deletions(-) diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h index 44a1bf9..e5890f6 100644 --- a/include/salticidae/conn.h +++ b/include/salticidae/conn.h @@ -108,10 +108,6 @@ class ConnPool { static socket_io_func _send_data_tls_handshake; static socket_io_func _recv_data_dummy; - /** Close the IO and clear all on-going or planned events. Remove the - * connection from a Worker. */ - virtual void stop(); - public: Conn(): terminated(false), // recv_chunk_size initialized later @@ -184,9 +180,18 @@ class ConnPool { /** Called when new data is available. */ virtual void on_read(const conn_t &) {} /** Called when the underlying connection is established. */ - virtual void on_setup(const conn_t &) {} + virtual void on_worker_setup(const conn_t &) {} + /** Called when the underlying connection is established. */ + virtual void on_dispatcher_setup(const conn_t &) {} /** Called when the underlying connection breaks. */ - virtual void on_teardown(const conn_t &) {} + virtual void on_worker_teardown(const conn_t &conn) { + if (conn->worker) conn->worker->unfeed(); + if (conn->tls) conn->tls->shutdown(); + conn->ev_socket.clear(); + conn->send_buffer.get_queue().unreg_handler(); + } + /** Called when the underlying connection breaks. */ + virtual void on_dispatcher_teardown(const conn_t &) {} private: const int max_listen_backlog; @@ -212,6 +217,7 @@ class ConnPool { if (enable_tls) { conn->worker->get_tcall()->async_call([this, conn, ret](ThreadCall::Handle &) { + if (conn->is_terminated()) return; if (ret) { conn->recv_data_func = Conn::_recv_data_tls; @@ -223,6 +229,7 @@ class ConnPool { } else conn->worker->get_tcall()->async_call([conn](ThreadCall::Handle &) { + if (conn->is_terminated()) return; conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE); }); } @@ -306,9 +313,15 @@ class ConnPool { conn->send_data_func = Conn::_send_data; conn->recv_data_func = Conn::_recv_data; enable_send_buffer(conn, client_fd); + cpool->on_worker_setup(conn); cpool->disp_tcall->async_call([cpool, conn](ThreadCall::Handle &) { - cpool->on_setup(conn); - cpool->update_conn(conn, true); + try { + cpool->on_dispatcher_setup(conn); + cpool->update_conn(conn, true); + } catch (...) { + cpool->recoverable_error(std::current_exception(), -1); + cpool->disp_terminate(conn); + } }); } assert(conn->fd != -1); @@ -559,7 +572,8 @@ class ConnPool { for (auto it: pool) { auto &conn = it.second; - conn->stop(); + on_worker_teardown(conn); + //conn->stop(); conn->set_terminated(); release_conn(conn); } @@ -623,6 +637,8 @@ class ConnPool { } }); } + + const X509 *get_cert() const { return tls_cert.get(); } }; } diff --git a/include/salticidae/network.h b/include/salticidae/network.h index 40f17a1..19d6db0 100644 --- a/include/salticidae/network.h +++ b/include/salticidae/network.h @@ -89,10 +89,6 @@ class MsgNetwork: public ConnPool { mutable std::atomic nsentb; mutable std::atomic nrecvb; #endif - void stop() override { - ev_enqueue_poll.clear(); - ConnPool::Conn::stop(); - } public: Conn(): msg_state(HEADER), msg_sleep(false) @@ -138,12 +134,10 @@ class MsgNetwork: public ConnPool { ConnPool::Conn *create_conn() override { return new Conn(); } void on_read(const ConnPool::conn_t &) override; - void on_setup(const ConnPool::conn_t &_conn) override { + void on_worker_setup(const ConnPool::conn_t &_conn) override { auto conn = static_pointer_cast(_conn); - auto worker = conn->worker; - worker->get_tcall()->async_call([this, conn, worker](ThreadCall::Handle &) { - conn->ev_enqueue_poll = TimerEvent(worker->get_ec(), - [this, conn](TimerEvent &) { + conn->ev_enqueue_poll = TimerEvent(conn->worker->get_ec(), + [this, conn](TimerEvent &) { if (!incoming_msgs.enqueue(std::make_pair(conn->msg, conn), false)) { conn->msg_sleep = true; @@ -153,7 +147,12 @@ class MsgNetwork: public ConnPool { conn->msg_sleep = false; on_read(conn); }); - }); + } + + void on_worker_teardown(const ConnPool::conn_t &_conn) override { + auto conn = static_pointer_cast(_conn); + conn->ev_enqueue_poll.clear(); + ConnPool::on_worker_teardown(_conn); } public: @@ -287,8 +286,8 @@ class ClientNetwork: public MsgNetwork { protected: ConnPool::Conn *create_conn() override { return new Conn(); } - void on_setup(const ConnPool::conn_t &) override; - void on_teardown(const ConnPool::conn_t &) override; + void on_dispatcher_setup(const ConnPool::conn_t &) override; + void on_dispatcher_teardown(const ConnPool::conn_t &) override; public: using Config = typename MsgNet::Config; @@ -376,12 +375,6 @@ class PeerNetwork: public MsgNetwork { PeerNetwork *get_net() { return static_cast(ConnPool::Conn::get_pool()); } - - protected: - void stop() override { - ev_timeout.clear(); - MsgNet::Conn::stop(); - } }; using conn_t = ArcObj; @@ -520,8 +513,10 @@ class PeerNetwork: public MsgNetwork { protected: ConnPool::Conn *create_conn() override { return new Conn(); } - void on_setup(const ConnPool::conn_t &) override; - void on_teardown(const ConnPool::conn_t &) override; + void on_worker_setup(const ConnPool::conn_t &) override; + void on_worker_teardown(const ConnPool::conn_t &) override; + void on_dispatcher_setup(const ConnPool::conn_t &) override; + void on_dispatcher_teardown(const ConnPool::conn_t &) override; PeerId _get_peer_id(const X509 *cert, const NetAddr &addr) { if (!this->enable_tls || id_mode == ADDR_BASED) @@ -738,47 +733,65 @@ void PeerNetwork::tcall_reset_timeout(ConnPool::Worker *worker, }); } -/* begin: functions invoked by the dispatcher */ template -void PeerNetwork::on_setup(const ConnPool::conn_t &_conn) { - MsgNet::on_setup(_conn); +void PeerNetwork::on_worker_setup(const ConnPool::conn_t &_conn) { + MsgNet::on_worker_setup(_conn); auto conn = static_pointer_cast(_conn); auto worker = conn->worker; + auto &ev_timeout = conn->ev_timeout; + assert(!ev_timeout); + ev_timeout = TimerEvent(worker->get_ec(), [=](TimerEvent &) { + try { + SALTICIDAE_LOG_INFO("%s%s%s: peer ping-pong timeout", + tty_secondary_color, + id_hex.c_str(), + tty_reset_color); + this->worker_terminate(conn); + } catch (...) { worker->error_callback(std::current_exception()); } + }); +} + +template +void PeerNetwork::on_worker_teardown(const ConnPool::conn_t &_conn) { + auto conn = static_pointer_cast(_conn); + conn->ev_timeout.clear(); + MsgNet::on_worker_teardown(_conn); +} + +/* begin: functions invoked by the dispatcher */ + +/* the initial ping-pong to set up the connection */ +template +void PeerNetwork::on_dispatcher_setup(const ConnPool::conn_t &_conn) { + MsgNet::on_dispatcher_setup(_conn); + auto conn = static_pointer_cast(_conn); SALTICIDAE_LOG_INFO("%s%s%s: setup connection %s", tty_secondary_color, id_hex.c_str(), tty_reset_color, std::string(*conn).c_str()); - worker->get_tcall()->async_call([this, conn, worker](ThreadCall::Handle &) { - auto &ev_timeout = conn->ev_timeout; - assert(!ev_timeout); - ev_timeout = TimerEvent(worker->get_ec(), [=](TimerEvent &) { - try { - SALTICIDAE_LOG_INFO("%s%s%s: peer ping-pong timeout", - tty_secondary_color, - id_hex.c_str(), - tty_reset_color); - this->worker_terminate(conn); - } catch (...) { worker->error_callback(std::current_exception()); } - }); - }); - /* the initial ping-pong to set up the connection */ - tcall_reset_timeout(worker, conn, conn_timeout); + tcall_reset_timeout(conn->worker, conn, conn_timeout); if (conn->get_mode() == Conn::ConnMode::ACTIVE) { auto pid = get_peer_id(conn, conn->get_addr()); - pinfo_slock_t _g(known_peers_lock); - send_msg(MsgPing( - listen_addr, - known_peers.find(pid)->second->get_nonce()), conn); + auto it = known_peers.find(pid); + if (it == known_peers.end()) + throw PeerNetworkError(SALTI_ERROR_PEER_NOT_MATCH); + else + { + pinfo_slock_t _g(known_peers_lock); + send_msg(MsgPing( + listen_addr, + it->second->get_nonce()), conn); + } } else replace_pending_conn(conn); } template -void PeerNetwork::on_teardown(const ConnPool::conn_t &_conn) { - MsgNet::on_teardown(_conn); +void PeerNetwork::on_dispatcher_teardown(const ConnPool::conn_t &_conn) { + MsgNet::on_dispatcher_teardown(_conn); auto conn = static_pointer_cast(_conn); auto addr = conn->get_addr(); pending_peers.erase(addr); @@ -949,8 +962,7 @@ void PeerNetwork::ping_handler(MsgPing &&msg, const conn_t &conn) { this->user_tcall->async_call([this, addr=msg.claimed_addr, conn](ThreadCall::Handle &) { if (unknown_peer_cb) unknown_peer_cb(addr, conn->get_peer_cert()); }); - this->disp_terminate(conn); - return; + throw PeerNetworkError(SALTI_ERROR_PEER_NOT_MATCH); } auto &p = pit->second; if (p->state != Peer::State::DISCONNECTED || @@ -1018,8 +1030,7 @@ void PeerNetwork::pong_handler(MsgPong &&msg, const conn_t &conn) { SALTICIDAE_LOG_WARN( "%s%s%s: unexpected pong from an unknown peer", tty_secondary_color, id_hex.c_str(), tty_reset_color); - this->disp_terminate(conn); - return; + throw PeerNetworkError(SALTI_ERROR_PEER_NOT_MATCH); } auto &p = pit->second; assert(!p->addr.is_null() && p->addr == conn->get_addr()); @@ -1290,8 +1301,8 @@ inline int32_t PeerNetwork::_multicast_msg(Msg &&msg, const std::vecto /* end: functions invoked by the user loop */ template -void ClientNetwork::on_setup(const ConnPool::conn_t &_conn) { - MsgNet::on_setup(_conn); +void ClientNetwork::on_dispatcher_setup(const ConnPool::conn_t &_conn) { + MsgNet::on_dispatcher_setup(_conn); auto conn = static_pointer_cast(_conn); assert(conn->get_mode() == Conn::PASSIVE); const auto &addr = conn->get_addr(); @@ -1301,8 +1312,8 @@ void ClientNetwork::on_setup(const ConnPool::conn_t &_conn) { } template -void ClientNetwork::on_teardown(const ConnPool::conn_t &_conn) { - MsgNet::on_teardown(_conn); +void ClientNetwork::on_dispatcher_teardown(const ConnPool::conn_t &_conn) { + MsgNet::on_dispatcher_teardown(_conn); auto conn = static_pointer_cast(_conn); conn->get_net()->addr2conn.erase(conn->get_addr()); } diff --git a/include/salticidae/util.h b/include/salticidae/util.h index cb09c0e..8c8fcb9 100644 --- a/include/salticidae/util.h +++ b/include/salticidae/util.h @@ -91,6 +91,7 @@ enum SalticidaeErrorCode { SALTI_ERROR_PEER_ALREADY_EXISTS, SALTI_ERROR_PEER_NOT_EXIST, SALTI_ERROR_PEER_NOT_READY, + SALTI_ERROR_PEER_NOT_MATCH, SALTI_ERROR_CLIENT_NOT_EXIST, SALTI_ERROR_NETADDR_INVALID, SALTI_ERROR_OPTVAL_INVALID, diff --git a/src/conn.cpp b/src/conn.cpp index a5d60a7..af15276 100644 --- a/src/conn.cpp +++ b/src/conn.cpp @@ -251,9 +251,15 @@ void ConnPool::Conn::_recv_data_tls_handshake(const conn_t &conn, int, int) { conn->peer_cert = new X509(conn->tls->get_peer_cert()); conn->worker->enable_send_buffer(conn, conn->fd); auto cpool = conn->cpool; + cpool->on_worker_setup(conn); cpool->disp_tcall->async_call([cpool, conn](ThreadCall::Handle &) { - cpool->on_setup(conn); - cpool->update_conn(conn, true); + try { + cpool->on_dispatcher_setup(conn); + cpool->update_conn(conn, true); + } catch (...) { + cpool->recoverable_error(std::current_exception(), -1); + cpool->disp_terminate(conn); + } }); } else @@ -266,17 +272,11 @@ void ConnPool::Conn::_recv_data_tls_handshake(const conn_t &conn, int, int) { void ConnPool::Conn::_recv_data_dummy(const conn_t &, int, int) {} -void ConnPool::Conn::stop() { - if (worker) worker->unfeed(); - if (tls) tls->shutdown(); - ev_socket.clear(); - send_buffer.get_queue().unreg_handler(); -} - void ConnPool::worker_terminate(const conn_t &conn) { conn->worker->get_tcall()->async_call([this, conn](ThreadCall::Handle &) { if (!conn->set_terminated()) return; - conn->stop(); + on_worker_teardown(conn); + //conn->stop(); disp_tcall->async_call([this, conn](ThreadCall::Handle &) { del_conn(conn); }); @@ -292,7 +292,8 @@ void ConnPool::disp_terminate(const conn_t &conn) { else disp_tcall->async_call([this, conn](ThreadCall::Handle &) { if (!conn->set_terminated()) return; - conn->stop(); + on_worker_teardown(conn); + //conn->stop(); del_conn(conn); }); } @@ -440,7 +441,7 @@ void ConnPool::del_conn(const conn_t &conn) { void ConnPool::release_conn(const conn_t &conn) { /* inform the upper layer the connection will be destroyed */ conn->ev_connect.clear(); - on_teardown(conn); + on_dispatcher_teardown(conn); ::close(conn->fd); } diff --git a/src/util.cpp b/src/util.cpp index 8ca01aa..01f6b06 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -43,6 +43,7 @@ const char *SALTICIDAE_ERROR_STRINGS[] = { "peer already exists", "peer does not exist", "peer is not ready", + "peer id does not match the record", "client does not exist", "invalid NetAddr format", "invalid OptVal format", diff --git a/test/test_p2p_stress.cpp b/test/test_p2p_stress.cpp index d054a57..7a078eb 100644 --- a/test/test_p2p_stress.cpp +++ b/test/test_p2p_stress.cpp @@ -46,6 +46,7 @@ using salticidae::static_pointer_cast; using salticidae::Config; using salticidae::ThreadCall; using salticidae::BoxObj; +using salticidae::PKey; using std::placeholders::_1; using std::placeholders::_2; @@ -88,6 +89,8 @@ const uint8_t MsgAck::opcode; using MyNet = salticidae::PeerNetwork; +bool use_tls; +std::unordered_set valid_certs; std::vector addrs; struct TestContext { @@ -116,6 +119,11 @@ void install_proto(AppContext &app, const size_t &recv_chunk_size) { net.send_msg(std::move(msg), conn); }; net.reg_conn_handler([](const ConnPool::conn_t &conn, bool connected) { + if (connected && use_tls) + { + auto cert_hash = salticidae::get_hash(conn->get_peer_cert()->get_der()); + return valid_certs.count(cert_hash) > 0; + } return true; }); net.reg_peer_handler([&, send_rand](const MyNet::conn_t &conn, bool connected) { @@ -197,6 +205,7 @@ int main(int argc, char **argv) { auto opt_nworker = Config::OptValInt::create(2); auto opt_conn_timeout = Config::OptValDouble::create(5); auto opt_ping_peroid = Config::OptValDouble::create(2); + auto opt_tls = Config::OptValFlag::create(false); auto opt_help = Config::OptValFlag::create(false); config.add_opt("no-msg", opt_no_msg, Config::SWITCH_ON); config.add_opt("npeers", opt_npeers, Config::SET_VAL); @@ -204,6 +213,7 @@ int main(int argc, char **argv) { config.add_opt("nworker", opt_nworker, Config::SET_VAL); config.add_opt("conn-timeout", opt_conn_timeout, Config::SET_VAL); config.add_opt("ping-period", opt_ping_peroid, Config::SET_VAL); + config.add_opt("tls", opt_tls, Config::SWITCH_ON, 't'); config.add_opt("help", opt_help, Config::SWITCH_ON, 'h', "show this help info"); config.parse(argc, argv); if (opt_help->get()) @@ -216,13 +226,24 @@ int main(int argc, char **argv) { addrs.push_back(NetAddr("127.0.0.1:" + std::to_string(12345 + i))); std::vector apps; std::vector threads; + use_tls = opt_tls->get(); apps.resize(addrs.size()); for (size_t i = 0; i < apps.size(); i++) { auto &a = apps[i]; a.addr = addrs[i]; - a.net = new MyNet(a.ec, MyNet::Config( - salticidae::ConnPool::Config() + salticidae::ConnPool::Config cfg{}; + if (use_tls) + { + auto tls_key = new PKey(PKey::create_privkey_rsa(2048)); + auto tls_cert = new salticidae::X509(salticidae::X509::create_self_signed_from_pubkey(*tls_key)); + cfg.enable_tls(true) + .tls_key(tls_key) + .tls_cert(tls_cert); + valid_certs.insert(salticidae::get_hash(tls_cert->get_der())); + } + else cfg.enable_tls(false); + a.net = new MyNet(a.ec, MyNet::Config(cfg .nworker(opt_nworker->get()) .recv_chunk_size(recv_chunk_size)) .conn_timeout(opt_conn_timeout->get()) @@ -238,12 +259,14 @@ int main(int argc, char **argv) { threads.push_back(std::thread([&]() { masksigs(); a.net->listen(a.addr); - for (auto &paddr: addrs) - if (paddr != a.addr) + for (auto &b: apps) + if (b.addr != a.addr) { - salticidae::PeerId pid{paddr}; + auto pid = use_tls ? + salticidae::PeerId(*b.net->get_cert()) : + salticidae::PeerId(b.addr); a.net->add_peer(pid); - a.net->set_peer_addr(pid, paddr); + a.net->set_peer_addr(pid, b.addr); a.net->conn_peer(pid); } a.ec.dispatch();})); -- cgit v1.2.3