diff options
-rw-r--r-- | include/salticidae/conn.h | 8 | ||||
-rw-r--r-- | include/salticidae/crypto.h | 55 | ||||
-rw-r--r-- | include/salticidae/event.h | 4 | ||||
-rw-r--r-- | include/salticidae/network.h | 277 | ||||
-rw-r--r-- | include/salticidae/stream.h | 70 | ||||
-rw-r--r-- | src/conn.cpp | 10 | ||||
-rw-r--r-- | src/network.cpp | 4 | ||||
-rw-r--r-- | test/.gitignore | 1 | ||||
-rw-r--r-- | test/CMakeLists.txt | 3 | ||||
-rw-r--r-- | test/test_p2p.cpp | 14 | ||||
-rw-r--r-- | test/test_p2p_stress.cpp | 26 | ||||
-rw-r--r-- | test/test_p2p_tls.cpp | 317 |
12 files changed, 620 insertions, 169 deletions
diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h index 87966ac..e39d31d 100644 --- a/include/salticidae/conn.h +++ b/include/salticidae/conn.h @@ -195,7 +195,12 @@ class ConnPool { if (enable_tls && connected) { conn->worker->get_tcall()->async_call([this, conn, ret](ThreadCall::Handle &) { - if (ret) conn->recv_data_func = Conn::_recv_data_tls; + if (ret) + { + conn->recv_data_func = Conn::_recv_data_tls; + conn->ev_socket.del(); + conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE); + } else worker_terminate(conn); }); } @@ -262,6 +267,7 @@ class ConnPool { conn->send_data_func = Conn::_send_data; conn->recv_data_func = Conn::_recv_data; enable_send_buffer(conn, client_fd); + cpool->on_setup(conn); cpool->update_conn(conn, true); } assert(conn->fd != -1); diff --git a/include/salticidae/crypto.h b/include/salticidae/crypto.h index d7231a4..fe4de4f 100644 --- a/include/salticidae/crypto.h +++ b/include/salticidae/crypto.h @@ -31,6 +31,7 @@ #ifdef __cplusplus #include <openssl/sha.h> #include <openssl/ssl.h> +#include <openssl/bn.h> namespace salticidae { @@ -137,10 +138,30 @@ static int _skip_CA_check(int, X509_STORE_CTX *) { class PKey { EVP_PKEY *key; friend class TLSContext; + friend class X509; public: PKey(EVP_PKEY *key): key(key) {} PKey(const PKey &) = delete; PKey(PKey &&other): key(other.key) { other.key = nullptr; } + + static PKey create_privkey_rsa(size_t nbits = 2048) { + auto key = EVP_PKEY_new(); + if (key == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + auto e = BN_new(); + BN_set_word(e, 17); + auto rsa = RSA_new(); + auto ret = RSA_generate_key_ex(rsa, nbits, e, nullptr); + BN_free(e); + if (!ret) + { + RSA_free(rsa); + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + } + EVP_PKEY_set1_RSA(key, rsa); + RSA_free(rsa); + return PKey(key); + } static PKey create_privkey_from_pem_file(std::string pem_fname, std::string *passwd = nullptr) { FILE *fp = fopen(pem_fname.c_str(), "r"); @@ -193,6 +214,14 @@ class PKey { return std::move(res); } + void save_privkey_to_file(const std::string &fname) { + FILE *fp = fopen(fname.c_str(), "wb"); + auto ret = PEM_write_PrivateKey(fp, key, nullptr, nullptr, 0, nullptr, nullptr); + fclose(fp); + if (!ret) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + } + ~PKey() { if (key) EVP_PKEY_free(key); } }; @@ -203,6 +232,24 @@ class X509 { X509(::X509 *x509): x509(x509) {} X509(const X509 &) = delete; X509(X509 &&other): x509(other.x509) { other.x509 = nullptr; } + + static X509 create_self_signed_from_pubkey(const PKey &pkey, const char *country = "US", const char *common_name = "localhost") { + auto x509 = X509_new(); + ASN1_INTEGER_set(X509_get_serialNumber(x509), 1); + X509_set_pubkey(x509, pkey.key); + X509_gmtime_adj(X509_get_notBefore(x509), 0); + X509_gmtime_adj(X509_get_notAfter(x509), 0); + auto name = X509_get_subject_name(x509); + X509_NAME_add_entry_by_txt(name, "C", MBSTRING_ASC, + (unsigned char *)country, -1, -1, 0); + X509_NAME_add_entry_by_txt(name, "O", MBSTRING_ASC, + (unsigned char *)"libsalticidae", -1, -1, 0); + X509_NAME_add_entry_by_txt(name, "CN", MBSTRING_ASC, + (unsigned char *)common_name, -1, -1, 0); + X509_set_issuer_name(x509, name); + X509_sign(x509, pkey.key, EVP_sha1()); + return X509(x509); + } static X509 create_from_pem_file(std::string pem_fname, std::string *passwd = nullptr) { FILE *fp = fopen(pem_fname.c_str(), "r"); @@ -251,6 +298,14 @@ class X509 { return std::move(res); } + void save_to_file(const std::string &fname) { + FILE *fp = fopen(fname.c_str(), "wb"); + auto ret = PEM_write_X509(fp, x509); + fclose(fp); + if (!ret) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + } + ~X509() { if (x509) X509_free(x509); } }; diff --git a/include/salticidae/event.h b/include/salticidae/event.h index ad78a6e..a7ea209 100644 --- a/include/salticidae/event.h +++ b/include/salticidae/event.h @@ -565,7 +565,7 @@ class MPSCQueueEventDriven: public MPSCQueue<T> { // memory barrier here, so any load/store in enqueue must be finialized if (wait_sig.exchange(false, std::memory_order_acq_rel)) { - SALTICIDAE_LOG_DEBUG("mpsc notify"); + //SALTICIDAE_LOG_DEBUG("mpsc notify"); write(fd, &dummy, 8); } return true; @@ -616,7 +616,7 @@ class MPMCQueueEventDriven: public MPMCQueue<T> { // memory barrier here, so any load/store in enqueue must be finialized if (wait_sig.exchange(false, std::memory_order_acq_rel)) { - SALTICIDAE_LOG_DEBUG("mpsc notify"); + //SALTICIDAE_LOG_DEBUG("mpmc notify"); write(fd, &dummy, 8); } return true; diff --git a/include/salticidae/network.h b/include/salticidae/network.h index 975084f..d3f3bae 100644 --- a/include/salticidae/network.h +++ b/include/salticidae/network.h @@ -261,7 +261,6 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { public: using MsgNet = MsgNetwork<OpcodeType>; using Msg = typename MsgNet::Msg; - using unknown_callback_t = std::function<void(const NetAddr &)>; enum IdentityMode { ADDR_BASED, @@ -282,6 +281,9 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { public: Conn(): MsgNet::Conn(), peer(nullptr) {} + NetAddr get_peer_addr() { + return peer ? peer->peer_addr : NetAddr(); + } PeerNetwork *get_net() { return static_cast<PeerNetwork *>(ConnPool::Conn::get_pool()); @@ -295,11 +297,13 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { }; using conn_t = ArcObj<Conn>; + using peer_callback_t = std::function<void(const conn_t &peer_conn, bool connected)>; + using unknown_peer_callback_t = std::function<void(const NetAddr &claimed_addr)>; private: - struct Peer { + class Peer { + friend PeerNetwork; /** connection addr, may be different due to passive mode */ - uint256_t nonce; uint256_t peer_id; NetAddr peer_addr; /** the underlying connection, may be invalid when connected = false */ @@ -316,17 +320,17 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { double ping_period; Peer() = delete; - Peer(conn_t conn, conn_t inbound_conn, conn_t outbound_conn, const PeerNetwork *pn): - conn(conn), - inbound_conn(inbound_conn), - outbound_conn(outbound_conn), - ev_ping_timer( - TimerEvent(pn->disp_ec, std::bind(&Peer::ping_timer, this, _1))), - connected(false), - outbound_handshake(false), - inbound_handshake(false), - ping_period(pn->ping_period) {} - ~Peer() {} + Peer(const uint256_t &peer_id, conn_t conn, conn_t inbound_conn, conn_t outbound_conn, const PeerNetwork *pn): + peer_id(peer_id), + conn(conn), + inbound_conn(inbound_conn), + outbound_conn(outbound_conn), + ev_ping_timer( + TimerEvent(pn->disp_ec, std::bind(&Peer::ping_timer, this, _1))), + connected(false), + outbound_handshake(false), + inbound_handshake(false), + ping_period(pn->ping_period) {} Peer &operator=(const Peer &) = delete; Peer(const Peer &) = delete; @@ -337,12 +341,15 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { if (ev_ping_timer) ev_ping_timer.del(); } + public: + ~Peer() {} }; std::unordered_map<NetAddr, conn_t> pending_peers; std::unordered_map<NetAddr, uint256_t> known_peers; std::unordered_map<uint256_t, BoxObj<Peer>> pid2peer; - unknown_callback_t unknown_peer_cb; + peer_callback_t peer_cb; + unknown_peer_callback_t unknown_peer_cb; const IdentityMode id_mode; double retry_conn_delay; @@ -350,41 +357,35 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { double conn_timeout; NetAddr listen_addr; bool allow_unknown_peer; - uint256_t my_pname; uint256_t my_nonce; struct MsgPing { static const OpcodeType opcode; DataStream serialized; - uint256_t pname; + NetAddr claimed_addr; uint256_t nonce; - uint256_t peer_id; - MsgPing() { serialized << false; } - MsgPing(const uint256_t &_pname, const uint256_t &_nonce) { - serialized << true << _pname << _nonce; + MsgPing() { serialized << (uint8_t)false; } + MsgPing(const NetAddr &_claimed_addr, const uint256_t &_nonce) { + serialized << (uint8_t)true << _claimed_addr << _nonce; } MsgPing(DataStream &&s) { uint8_t flag; s >> flag; if (flag) - { - s >> pname >> nonce; - DataStream tmp; - tmp << pname << nonce; - peer_id = tmp.get_hash(); - } + s >> claimed_addr >> nonce; } }; struct MsgPong: public MsgPing { static const OpcodeType opcode; MsgPong(): MsgPing() {} - MsgPong(const uint256_t &_pname, const uint256_t _nonce): MsgPing(_pname, _nonce) {} + MsgPong(const NetAddr &_claimed_addr, const uint256_t _nonce): + MsgPing(_claimed_addr, _nonce) {} MsgPong(DataStream &&s): MsgPing(std::move(s)) {} }; - void msg_ping(MsgPing &&msg, const conn_t &conn); - void msg_pong(MsgPong &&msg, const conn_t &conn); + void ping_handler(MsgPing &&msg, const conn_t &conn); + void pong_handler(MsgPong &&msg, const conn_t &conn); void _ping_msg_cb(const conn_t &conn, uint16_t port); void _pong_msg_cb(const conn_t &conn, uint16_t port); bool check_handshake(Peer *peer); @@ -400,6 +401,14 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { } void on_setup(const ConnPool::conn_t &) override; void on_teardown(const ConnPool::conn_t &) override; + uint256_t gen_peer_id(const conn_t &conn, const NetAddr &claimed_addr, const uint256_t &nonce) { + DataStream tmp; + if (!this->enable_tls || id_mode == ADDR_BASED) + tmp << nonce << claimed_addr; + else + tmp << conn->get_peer_cert()->get_der(); + return tmp.get_hash(); + } public: @@ -420,7 +429,7 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { _ping_period(30), _conn_timeout(180), _allow_unknown_peer(false), - _id_mode(ADDR_BASED) {} + _id_mode(CERT_BASED) {} Config &retry_conn_delay(double x) { @@ -450,14 +459,14 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { }; PeerNetwork(const EventContext &ec, const Config &config): - MsgNet(ec, config), - id_mode(config._id_mode), - retry_conn_delay(config._retry_conn_delay), - ping_period(config._ping_period), - conn_timeout(config._conn_timeout), - allow_unknown_peer(config._allow_unknown_peer) { - this->reg_handler(generic_bind(&PeerNetwork::msg_ping, this, _1, _2)); - this->reg_handler(generic_bind(&PeerNetwork::msg_pong, this, _1, _2)); + MsgNet(ec, config), + id_mode(config._id_mode), + retry_conn_delay(config._retry_conn_delay), + ping_period(config._ping_period), + conn_timeout(config._conn_timeout), + allow_unknown_peer(config._allow_unknown_peer) { + this->reg_handler(generic_bind(&PeerNetwork::ping_handler, this, _1, _2)); + this->reg_handler(generic_bind(&PeerNetwork::pong_handler, this, _1, _2)); } virtual ~PeerNetwork() { this->stop(); } @@ -481,6 +490,8 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { conn_t connect(const NetAddr &addr) = delete; template<typename Func> void reg_unknown_peer_handler(Func cb) { unknown_peer_cb = cb; } + template<typename Func> + void reg_peer_handler(Func cb) { peer_cb = cb; } }; /* this callback is run by a worker */ @@ -577,43 +588,63 @@ void PeerNetwork<O, _, __>::on_setup(const ConnPool::conn_t &_conn) { auto &ev_timeout = conn->ev_timeout; conn->ev_retry_timer.del(); assert(!ev_timeout); - ev_timeout = TimerEvent(worker->get_ec(), [worker, conn](TimerEvent &) { + ev_timeout = TimerEvent(worker->get_ec(), [listen_addr=this->listen_addr, worker, conn](TimerEvent &) { try { - SALTICIDAE_LOG_INFO("peer ping-pong timeout"); - conn->get_net()->worker_terminate(conn); + SALTICIDAE_LOG_INFO("peer ping-pong timeout %s <-> %s", + std::string(listen_addr).c_str(), + std::string(conn->get_peer_addr()).c_str()); + //conn->get_net()->worker_terminate(conn); } catch (...) { worker->error_callback(std::current_exception()); } }); /* the initial ping-pong to set up the connection */ tcall_reset_timeout(worker, conn, conn_timeout); pending_peers[conn->get_addr()] = conn; if (conn->get_mode() == Conn::ConnMode::ACTIVE) - send_msg(MsgPing(my_pname, my_nonce), conn); + send_msg(MsgPing(listen_addr, my_nonce), conn); } template<typename O, O _, O __> void PeerNetwork<O, _, __>::on_teardown(const ConnPool::conn_t &_conn) { MsgNet::on_teardown(_conn); auto conn = static_pointer_cast<Conn>(_conn); - const auto &addr = conn->get_addr(); + auto addr = conn->get_addr(); conn->ev_retry_timer.clear(); conn->ev_timeout.clear(); pending_peers.erase(addr); + SALTICIDAE_LOG_INFO("connection lost: %s", std::string(*conn).c_str()); auto p = conn->peer; + if (p) addr = p->peer_addr; + TimerEvent retry_timer(this->disp_ec, [this, addr](TimerEvent &) { + try { + start_active_conn(addr); + } catch (...) { this->disp_error_cb(std::current_exception()); } + }); + auto it = known_peers.find(addr); + if (it == known_peers.end()) return; + pending_peers[addr] = conn; if (p) { if (conn != p->conn) return; + p->inbound_conn = nullptr; + p->outbound_conn = nullptr; p->ev_ping_timer.del(); p->connected = false; + p->outbound_handshake = false; + p->inbound_handshake = false; known_peers[p->peer_addr] = uint256_t(); - // try to reconnect - conn->ev_retry_timer = TimerEvent(this->disp_ec, [this, addr](TimerEvent &) { - try { - start_active_conn(addr); - } catch (...) { this->disp_error_cb(std::current_exception()); } + pid2peer.erase(p->peer_id); + this->user_tcall->async_call([this, conn](ThreadCall::Handle &) { + if (peer_cb) peer_cb(conn, false); }); + conn->ev_retry_timer = std::move(retry_timer); + conn->ev_retry_timer.add(gen_conn_timeout()); + } + else + { + if (!it->second.is_null()) return; + conn->ev_retry_timer = std::move(retry_timer); conn->ev_retry_timer.add(gen_conn_timeout()); } - SALTICIDAE_LOG_INFO("connection lost: %s", std::string(*conn).c_str()); } template<typename O, O _, O __> @@ -666,13 +697,16 @@ bool PeerNetwork<O, _, __>::check_handshake(Peer *p) { color_begin = TTY_COLOR_BLUE; color_end = TTY_COLOR_RESET; } - SALTICIDAE_LOG_INFO("%sPeerNetwork: established connection with %s <-> %s via %s", + SALTICIDAE_LOG_INFO("%sPeerNetwork: established connection %s <-> %s via %s", color_begin, std::string(listen_addr).c_str(), std::string(p->peer_addr).c_str(), std::string(*(p->conn)).c_str(), color_end); } + this->user_tcall->async_call([this, conn=p->conn](ThreadCall::Handle &) { + if (peer_cb) peer_cb(conn, true); + }); return true; } @@ -684,52 +718,67 @@ void PeerNetwork<O, _, __>::start_active_conn(const NetAddr &addr) { template<typename O, O _, O __> inline typename PeerNetwork<O, _, __>::conn_t PeerNetwork<O, _, __>::_get_peer_conn(const NetAddr &addr) const { - auto it = pending_peers.find(addr); - if (it == pending_peers.end()) + auto it = known_peers.find(addr); + if (it == known_peers.end()) throw PeerNetworkError(SALTI_ERROR_PEER_NOT_EXIST); - return it->second; + auto it2 = pid2peer.find(it->second); + assert(it2 != pid2peer.end()); + return it2->second->conn; } /* end: functions invoked by the dispatcher */ /* begin: functions invoked by the user loop */ template<typename O, O _, O __> -void PeerNetwork<O, _, __>::msg_ping(MsgPing &&msg, const conn_t &conn) { +void PeerNetwork<O, _, __>::ping_handler(MsgPing &&msg, const conn_t &conn) { this->disp_tcall->async_call([this, conn, msg=std::move(msg)](ThreadCall::Handle &) { try { auto conn_mode = conn->get_mode(); if (conn_mode == ConnPool::Conn::DEAD) return; - if (!msg.peer_id.is_null()) + if (!msg.claimed_addr.is_null()) { + auto peer_id = gen_peer_id(conn, msg.claimed_addr, msg.nonce); if (conn_mode == Conn::ConnMode::PASSIVE) { - send_msg(MsgPong(my_pname, my_nonce), conn); + if (!known_peers.count(msg.claimed_addr)) + { + this->user_tcall->async_call([this, addr=msg.claimed_addr](ThreadCall::Handle &) { + if (unknown_peer_cb) unknown_peer_cb(addr); + }); + this->disp_terminate(conn); + return; + } SALTICIDAE_LOG_INFO("%s inbound handshake from %s", std::string(listen_addr).c_str(), std::string(*conn).c_str()); - auto it = pid2peer.find(msg.peer_id); + send_msg(MsgPong(listen_addr, my_nonce), conn); + auto it = pid2peer.find(peer_id); if (it != pid2peer.end()) { + auto p = it->second.get(); + if (p->connected) + { + //conn->get_net()->disp_terminate(conn); + return; + } + auto &old_conn = p->inbound_conn; + if (old_conn && old_conn->get_mode() != Conn::ConnMode::DEAD) + { + SALTICIDAE_LOG_DEBUG("%s terminating old connection %s", + std::string(listen_addr).c_str(), + std::string(*old_conn).c_str()); + old_conn->peer = nullptr; + old_conn->get_net()->disp_terminate(old_conn); + } + old_conn = conn; if (msg.nonce < my_nonce) { - auto p = it->second.get(); - auto &old_conn = p->inbound_conn; - if (old_conn && old_conn->get_mode() != Conn::ConnMode::DEAD) - { - SALTICIDAE_LOG_INFO("%s terminating old connection %s", - std::string(listen_addr).c_str(), - std::string(*old_conn).c_str()); - old_conn->peer = nullptr; - old_conn->get_net()->disp_terminate(old_conn); - } - old_conn = conn; p->conn = conn; } } else { - it = pid2peer.insert(std::make_pair( - msg.peer_id, - new Peer(conn, conn, nullptr, this))).first; + it = pid2peer.insert(std::make_pair(peer_id, + new Peer(peer_id, conn, conn, nullptr, this))).first; } auto p = it->second.get(); p->inbound_handshake = true; @@ -749,39 +798,45 @@ void PeerNetwork<O, _, __>::msg_ping(MsgPing &&msg, const conn_t &conn) { } template<typename O, O _, O __> -void PeerNetwork<O, _, __>::msg_pong(MsgPong &&msg, const conn_t &conn) { +void PeerNetwork<O, _, __>::pong_handler(MsgPong &&msg, const conn_t &conn) { this->disp_tcall->async_call([this, conn, msg=std::move(msg)](ThreadCall::Handle &) { try { auto conn_mode = conn->get_mode(); if (conn_mode == ConnPool::Conn::DEAD) return; - if (!msg.peer_id.is_null()) + if (!msg.claimed_addr.is_null()) { + auto peer_id = gen_peer_id(conn, msg.claimed_addr, msg.nonce); if (conn_mode == Conn::ConnMode::ACTIVE) { SALTICIDAE_LOG_INFO("%s outbound handshake to %s", std::string(listen_addr).c_str(), std::string(*conn).c_str()); - auto it = pid2peer.find(msg.peer_id); + auto it = pid2peer.find(peer_id); if (it != pid2peer.end()) { + auto p = it->second.get(); + if (p->connected) + { + conn->get_net()->disp_terminate(conn); + return; + } + auto &old_conn = p->outbound_conn; + if (old_conn && old_conn->get_mode() != Conn::ConnMode::DEAD) + { + SALTICIDAE_LOG_DEBUG("%s terminating old connection %s", + std::string(listen_addr).c_str(), + std::string(*old_conn).c_str()); + old_conn->peer = nullptr; + old_conn->get_net()->disp_terminate(old_conn); + } + old_conn = conn; if (my_nonce < msg.nonce) { - auto p = it->second.get(); - auto &old_conn = p->outbound_conn; - if (old_conn && old_conn->get_mode() != Conn::ConnMode::DEAD) - { - SALTICIDAE_LOG_INFO("%s terminating old connection %s", - std::string(listen_addr).c_str(), - std::string(*old_conn).c_str()); - old_conn->peer = nullptr; - old_conn->get_net()->disp_terminate(old_conn); - } - old_conn = conn; p->conn = conn; } else { - SALTICIDAE_LOG_INFO("%s terminating low connection %s", + SALTICIDAE_LOG_DEBUG("%s terminating low connection %s", std::string(listen_addr).c_str(), std::string(*conn).c_str()); conn->get_net()->disp_terminate(conn); @@ -789,9 +844,8 @@ void PeerNetwork<O, _, __>::msg_pong(MsgPong &&msg, const conn_t &conn) { } else { - it = pid2peer.insert(std::make_pair( - msg.peer_id, - new Peer(conn, nullptr, conn, this))).first; + it = pid2peer.insert(std::make_pair(peer_id, + new Peer(peer_id, conn, nullptr, conn, this))).first; } auto p = it->second.get(); p->outbound_handshake = true; @@ -806,7 +860,11 @@ void PeerNetwork<O, _, __>::msg_pong(MsgPong &&msg, const conn_t &conn) { else { auto p = conn->peer; - if (!p) return; + if (!p) + { + SALTICIDAE_LOG_WARN("unexpected poing mesage"); + return; + } p->pong_msg_ok = true; if (p->ping_timer_ok) { @@ -826,18 +884,6 @@ void PeerNetwork<O, _, __>::listen(NetAddr _listen_addr) { try { MsgNet::_listen(_listen_addr); listen_addr = _listen_addr; - DataStream pid; - if (id_mode == CERT_BASED) - { - if (!this->enable_tls) - throw PeerNetworkError(SALTI_ERROR_TLS_LOAD_CERT); - pid << this->tls_cert->get_der(); - } - else - { - pid << listen_addr; - } - my_pname = pid.get_hash(); uint8_t rand_bytes[32]; if (!RAND_bytes(rand_bytes, 32)) throw PeerNetworkError(SALTI_ERROR_RAND_SOURCE); @@ -868,14 +914,25 @@ template<typename O, O _, O __> void PeerNetwork<O, _, __>::del_peer(const NetAddr &addr) { this->disp_tcall->async_call([this, addr](ThreadCall::Handle &) { try { - if (!known_peers.erase(addr)) + auto it = known_peers.find(addr); + if (it == known_peers.end()) throw PeerNetworkError(SALTI_ERROR_PEER_NOT_EXIST); - auto it = pending_peers.find(addr); - assert(it != pending_peers.end()); - auto conn = it->second; - auto p = conn->peer; - if (p) pid2peer.erase(p->peer_id); - this->disp_terminate(conn); + auto peer_id = it->second; + known_peers.erase(it); + auto it2 = pending_peers.find(addr); + if (it2 != pending_peers.end()) + { + if (!it2->second->peer) + this->disp_terminate(it2->second); + pending_peers.erase(it2); + } + auto it3 = pid2peer.find(peer_id); + if (it3 != pid2peer.end()) + { + auto p = it3->second.get(); + this->disp_terminate(p->conn); + pid2peer.erase(it3); + } } catch (const PeerNetworkError &) { this->recoverable_error(std::current_exception()); } catch (...) { this->disp_error_cb(std::current_exception()); } diff --git a/include/salticidae/stream.h b/include/salticidae/stream.h index dc47792..2efc532 100644 --- a/include/salticidae/stream.h +++ b/include/salticidae/stream.h @@ -198,8 +198,40 @@ class DataStream { inline uint256_t get_hash() const; }; +class Serializable { + public: + virtual ~Serializable() = default; + virtual void serialize(DataStream &s) const = 0; + virtual void unserialize(DataStream &s) = 0; + + virtual void from_bytes(const bytearray_t &raw_bytes) { + DataStream s(raw_bytes); + s >> *this; + } + + virtual void from_bytes(bytearray_t &&raw_bytes) { + DataStream s(std::move(raw_bytes)); + s >> *this; + } + + + virtual void from_hex(const std::string &hex_str) { + DataStream s; + s.load_hex(hex_str); + s >> *this; + } + + bytearray_t to_bytes() const { + DataStream s; + s << *this; + return std::move(s); + } + + inline std::string to_hex() const; +}; + template<size_t N, typename T = uint64_t> -class Blob { +class Blob: public Serializable { using _impl_type = T; static const size_t bit_per_datum = sizeof(_impl_type) * 8; static_assert(!(N % bit_per_datum), "N must be divisible by bit_per_datum"); @@ -252,7 +284,7 @@ class Blob { size_t cheap_hash() const { return *data; } - void serialize(DataStream &s) const { + void serialize(DataStream &s) const override { if (loaded) { for (const _impl_type *ptr = data; ptr < data + _len; ptr++) @@ -265,7 +297,7 @@ class Blob { } } - void unserialize(DataStream &s) { + void unserialize(DataStream &s) override { for (_impl_type *ptr = data; ptr < data + _len; ptr++) { _impl_type x; @@ -424,37 +456,7 @@ inline bytearray_t from_hex(const std::string &hex_str) { return std::move(s); } -class Serializable { - public: - virtual ~Serializable() = default; - virtual void serialize(DataStream &s) const = 0; - virtual void unserialize(DataStream &s) = 0; - - virtual void from_bytes(const bytearray_t &raw_bytes) { - DataStream s(raw_bytes); - s >> *this; - } - - virtual void from_bytes(bytearray_t &&raw_bytes) { - DataStream s(std::move(raw_bytes)); - s >> *this; - } - - - virtual void from_hex(const std::string &hex_str) { - DataStream s; - s.load_hex(hex_str); - s >> *this; - } - - bytearray_t to_bytes() const { - DataStream s; - s << *this; - return std::move(s); - } - - std::string to_hex() const { return get_hex(*this); } -}; +inline std::string Serializable::to_hex() const { return get_hex(*this); } } diff --git a/src/conn.cpp b/src/conn.cpp index ab08399..84f08a4 100644 --- a/src/conn.cpp +++ b/src/conn.cpp @@ -223,9 +223,13 @@ void ConnPool::Conn::_recv_data_tls_handshake(const conn_t &conn, int, int) { /* finishing TLS handshake */ conn->send_data_func = _send_data_tls; conn->recv_data_func = _recv_data_dummy; + conn->ev_socket.del(); + conn->ev_socket.add(FdEvent::WRITE); conn->peer_cert = new X509(conn->tls->get_peer_cert()); conn->worker->enable_send_buffer(conn, conn->fd); - conn->cpool->update_conn(conn, true); + auto cpool = conn->cpool; + cpool->on_setup(conn); + cpool->update_conn(conn, true); } else { @@ -301,7 +305,6 @@ void ConnPool::accept_client(int fd, int) { SALTICIDAE_LOG_INFO("accepted %s", std::string(*conn).c_str()); auto &worker = select_worker(); conn->worker = &worker; - on_setup(conn); worker.feed(conn, client_fd); } } catch (...) { recoverable_error(std::current_exception()); } @@ -315,14 +318,13 @@ void ConnPool::conn_server(const conn_t &conn, int fd, int events) { SALTICIDAE_LOG_INFO("connected to remote %s", std::string(*conn).c_str()); auto &worker = select_worker(); conn->worker = &worker; - on_setup(conn); worker.feed(conn, fd); } else { if (events & TimedFdEvent::TIMEOUT) SALTICIDAE_LOG_INFO("%s connect timeout", std::string(*conn).c_str()); - throw SalticidaeError(SALTI_ERROR_CONNECT); + throw SalticidaeError(SALTI_ERROR_CONNECT, errno); } } catch (...) { disp_terminate(conn); diff --git a/src/network.cpp b/src/network.cpp index 74f4df9..a02ead1 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -250,8 +250,8 @@ void peernetwork_listen(peernetwork_t *self, const netaddr_t *listen_addr, Salti void peernetwork_reg_unknown_peer_handler(peernetwork_t *self, msgnetwork_unknown_peer_callback_t cb, void *userdata) { - self->reg_unknown_peer_handler([=](const NetAddr &addr) { - cb(&addr, userdata); + self->reg_unknown_peer_handler([=](const NetAddr &claimed_addr) { + cb(&claimed_addr, userdata); }); } diff --git a/test/.gitignore b/test/.gitignore index 7cebf5d..f50f029 100644 --- a/test/.gitignore +++ b/test/.gitignore @@ -2,6 +2,7 @@ test_msg test_bits test_msgnet test_p2p +test_p2p_tls test_p2p_stress test_queue bench_network diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index d112b7a..0a1d3f1 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -35,6 +35,9 @@ target_link_libraries(test_msgnet_tls salticidae_static) add_executable(test_p2p test_p2p.cpp) target_link_libraries(test_p2p salticidae_static) +add_executable(test_p2p_tls test_p2p_tls.cpp) +target_link_libraries(test_p2p_tls salticidae_static) + add_executable(test_p2p_stress test_p2p_stress.cpp) target_link_libraries(test_p2p_stress salticidae_static) diff --git a/test/test_p2p.cpp b/test/test_p2p.cpp index 14304eb..7f80f85 100644 --- a/test/test_p2p.cpp +++ b/test/test_p2p.cpp @@ -87,8 +87,14 @@ struct Net { this->id, fatal ? "fatal" : "recoverable", err.what()); } }); - net->reg_unknown_peer_handler([this](const NetAddr &addr) { - fprintf(stdout, "net %lu: unknown peer attempts to connnect %s\n", this->id, std::string(addr).c_str()); + net->reg_peer_handler([this](const PeerNetwork::conn_t &conn, bool connected) { + fprintf(stdout, "net %lu: %s peer %s\n", this->id, + connected ? "connected to" : "disconnected from", + std::string(conn->get_peer_addr()).c_str()); + }); + net->reg_unknown_peer_handler([this](const NetAddr &claimed_addr) { + fprintf(stdout, "net %lu: unknown peer %s attempts to connnect\n", + this->id, std::string(claimed_addr).c_str()); }); th = std::thread([=](){ try { @@ -258,8 +264,8 @@ int main(int argc, char **argv) { fprintf(stdout, "add <node-id> <port> -- start a node (create a PeerNetwork instance)\n" "addpeer <node-id> <peer-id> -- add a peer to a given node\n" - "rmpeer <node-id> <peer-id> -- add a peer to a given node\n" - "rm <node-id> -- remove a node (destroy a PeerNetwork instance)\n" + "delpeer <node-id> <peer-id> -- add a peer to a given node\n" + "del <node-id> -- remove a node (destroy a PeerNetwork instance)\n" "msg <node-id> <peer-id> <msg> -- send a text message to a node\n" "ls -- list all node ids\n" "exit -- quit the program\n" diff --git a/test/test_p2p_stress.cpp b/test/test_p2p_stress.cpp index 1cb2ca3..1eb4a0d 100644 --- a/test/test_p2p_stress.cpp +++ b/test/test_p2p_stress.cpp @@ -97,29 +97,29 @@ void install_proto(AppContext &app, const size_t &seg_buff_size) { auto &ec = app.ec; auto &net = *app.net; auto send_rand = [&](int size, const MyNet::conn_t &conn) { - auto &tc = app.tc[conn->get_addr()]; + auto addr = conn->get_peer_addr(); + assert(!addr.is_null()); + auto &tc = app.tc[addr]; MsgRand msg(size); tc.hash = msg.serialized.get_hash(); net.send_msg(std::move(msg), conn); }; - net.reg_conn_handler([&, send_rand](const ConnPool::conn_t &conn, bool connected) { + net.reg_peer_handler([&, send_rand](const MyNet::conn_t &conn, bool connected) { if (connected) { - if (conn->get_mode() == ConnPool::Conn::ACTIVE) - { - auto &tc = app.tc[conn->get_addr()]; - tc.state = 1; - SALTICIDAE_LOG_INFO("increasing phase"); - send_rand(tc.state, static_pointer_cast<MyNet::Conn>(conn)); - } + auto addr = conn->get_peer_addr(); + assert(!addr.is_null()); + auto &tc = app.tc[addr]; + tc.state = 1; + SALTICIDAE_LOG_INFO("increasing phase"); + send_rand(tc.state, conn); } - return true; }); net.reg_error_handler([ec](const std::exception_ptr _err, bool fatal) { try { std::rethrow_exception(_err); } catch (const std::exception & err) { - SALTICIDAE_LOG_WARN("main thread captured %s error: %s", + SALTICIDAE_LOG_WARN("captured %s error: %s", fatal ? "fatal" : "recoverable", err.what()); } }); @@ -128,7 +128,9 @@ void install_proto(AppContext &app, const size_t &seg_buff_size) { net.send_msg(MsgAck(hash), conn); }); net.reg_handler([&, send_rand](MsgAck &&msg, const MyNet::conn_t &conn) { - auto &tc = app.tc[conn->get_addr()]; + auto addr = conn->get_peer_addr(); + assert(!addr.is_null()); + auto &tc = app.tc[addr]; if (msg.hash != tc.hash) { SALTICIDAE_LOG_ERROR("corrupted I/O!"); diff --git a/test/test_p2p_tls.cpp b/test/test_p2p_tls.cpp new file mode 100644 index 0000000..9fe0aec --- /dev/null +++ b/test/test_p2p_tls.cpp @@ -0,0 +1,317 @@ +/** + * Copyright (c) 2019 Ava Labs, Inc. + * + * Author: Ted Yin <[email protected]> + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies + * of the Software, and to permit persons to whom the Software is furnished to do + * so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include <cstdio> +#include <cstdint> +#include <string> +#include <functional> +#include <unordered_map> + +#include "salticidae/msg.h" +#include "salticidae/event.h" +#include "salticidae/network.h" +#include "salticidae/stream.h" + +using salticidae::NetAddr; +using salticidae::DataStream; +using salticidae::htole; +using salticidae::letoh; +using salticidae::EventContext; +using salticidae::ThreadCall; +using salticidae::PKey; +using std::placeholders::_1; +using std::placeholders::_2; + +using PeerNetwork = salticidae::PeerNetwork<uint8_t>; + +struct MsgText { + static const uint8_t opcode = 0x0; + DataStream serialized; + uint64_t id; + std::string text; + + MsgText(uint64_t id, const std::string &text) { + serialized << salticidae::htole(id) << salticidae::htole((uint32_t)text.length()) << text; + } + + MsgText(DataStream &&s) { + uint32_t len; + s >> id; + id = salticidae::letoh(id); + s >> len; + len = salticidae::letoh(len); + text = std::string((const char *)s.get_data_inplace(len), len); + } +}; + +const uint8_t MsgText::opcode; + +struct Net { + uint64_t id; + EventContext ec; + ThreadCall tc; + std::thread th; + PeerNetwork *net; + const std::string listen_addr; + + Net(uint64_t id, uint16_t port): id(id), tc(ec), listen_addr("127.0.0.1:"+ std::to_string(port)) { + auto tls_key = new PKey(PKey::create_privkey_rsa(2048)); + auto tls_cert = new salticidae::X509(salticidae::X509::create_self_signed_from_pubkey(*tls_key)); + tls_key->save_privkey_to_file(std::to_string(port) + "_pkey.pem"); + tls_cert->save_to_file(std::to_string(port) + ".pem"); + net = new PeerNetwork(ec, PeerNetwork::Config(salticidae::ConnPool::Config() + .enable_tls(true) + .tls_key(tls_key) + .tls_cert(tls_cert) + ).conn_timeout(5) + .ping_period(2) + .id_mode(PeerNetwork::IdentityMode::ADDR_BASED)); + net->reg_handler([this](const MsgText &msg, const PeerNetwork::conn_t &) { + fprintf(stdout, "net %lu: peer %lu says %s\n", this->id, msg.id, msg.text.c_str()); + }); + net->reg_conn_handler([this](const salticidae::ConnPool::conn_t &conn, bool connected) { + if (connected) + { + fprintf(stdout, "net %lu: peer's cert is %s\n", this->id, + salticidae::get_hash(conn->get_peer_cert()->get_der()).to_hex().c_str()); + } + return true; + }); + net->reg_error_handler([this](const std::exception_ptr _err, bool fatal) { + try { + std::rethrow_exception(_err); + } catch (const std::exception &err) { + fprintf(stdout, "net %lu: captured %s error during an async call: %s\n", + this->id, fatal ? "fatal" : "recoverable", err.what()); + } + }); + net->reg_peer_handler([this](const PeerNetwork::conn_t &conn, bool connected) { + fprintf(stdout, "net %lu: %s peer %s\n", this->id, + connected ? "connected to" : "disconnected from", + std::string(conn->get_peer_addr()).c_str()); + }); + net->reg_unknown_peer_handler([this](const NetAddr &claimed_addr) { + fprintf(stdout, "net %lu: unknown peer %s attempts to connnect\n", + this->id, std::string(claimed_addr).c_str()); + }); + th = std::thread([=](){ + try { + net->start(); + net->listen(NetAddr(listen_addr)); + fprintf(stdout, "net %lu: listen to %s\n", id, listen_addr.c_str()); + ec.dispatch(); + } catch (std::exception &err) { + fprintf(stdout, "net %lu: got error during a sync call: %s\n", id, err.what()); + } + fprintf(stdout, "net %lu: main loop ended\n", id); + }); + } + + void add_peer(const std::string &listen_addr) { + try { + net->add_peer(NetAddr(listen_addr)); + } catch (std::exception &err) { + fprintf(stdout, "net %lu: got error during a sync call: %s\n", id, err.what()); + } + } + + void del_peer(const std::string &listen_addr) { + try { + net->del_peer(NetAddr(listen_addr)); + } catch (std::exception &err) { + fprintf(stdout, "net %lu: got error during a sync call: %s\n", id, err.what()); + } + } + + void stop_join() { + tc.async_call([ec=this->ec](ThreadCall::Handle &) { ec.stop(); }); + th.join(); + } + + ~Net() { delete net; } +}; + +std::unordered_map<uint64_t, Net *> nets; +std::unordered_map<std::string, std::function<void(char *)> > cmd_map; + +int read_int(char *buff) { + scanf("%64s", buff); + try { + int t = std::stoi(buff); + if (t < 0) throw std::invalid_argument("negative"); + return t; + } catch (std::invalid_argument) { + fprintf(stdout, "expect a non-negative integer\n"); + return -1; + } +} + +int main(int argc, char **argv) { + int i; + fprintf(stdout, "p2p network library playground (type help for more info)\n"); + fprintf(stdout, "========================================================\n"); + + auto cmd_exit = [](char *) { + for (auto &p: nets) + { + p.second->stop_join(); + delete p.second; + } + exit(0); + }; + + auto cmd_add = [](char *buff) { + int id = read_int(buff); + if (id < 0) return; + if (nets.count(id)) + { + fprintf(stdout, "net id already exists\n"); + return; + } + int port = read_int(buff); + if (port < 0) return; + if (port >= 65536) + { + fprintf(stdout, "port should be < 65536\n"); + return; + } + nets.insert(std::make_pair(id, new Net(id, port))); + }; + + auto cmd_ls = [](char *) { + for (auto &p: nets) + fprintf(stdout, "%d -> %s\n", p.first, p.second->listen_addr.c_str()); + }; + + auto cmd_del = [](char *buff) { + int id = read_int(buff); + if (id < 0) return; + auto it = nets.find(id); + if (it == nets.end()) + { + fprintf(stdout, "net id does not exist\n"); + return; + } + it->second->stop_join(); + delete it->second; + nets.erase(it); + }; + + auto cmd_addpeer = [](char *buff) { + int id = read_int(buff); + if (id < 0) return; + auto it = nets.find(id); + if (it == nets.end()) + { + fprintf(stdout, "net id does not exist\n"); + return; + } + int id2 = read_int(buff); + if (id2 < 0) return; + auto it2 = nets.find(id2); + if (it2 == nets.end()) + { + fprintf(stdout, "net id does not exist\n"); + return; + } + it->second->add_peer(it2->second->listen_addr); + }; + + auto cmd_delpeer = [](char *buff) { + int id = read_int(buff); + if (id < 0) return; + auto it = nets.find(id); + if (it == nets.end()) + { + fprintf(stdout, "net id does not exist\n"); + return; + } + int id2 = read_int(buff); + if (id2 < 0) return; + auto it2 = nets.find(id2); + if (it2 == nets.end()) + { + fprintf(stdout, "net id does not exist\n"); + return; + } + it->second->del_peer(it2->second->listen_addr); + }; + + auto cmd_msg = [](char *buff) { + int id = read_int(buff); + if (id < 0) return; + auto it = nets.find(id); + if (it == nets.end()) + { + fprintf(stdout, "net id does not exist\n"); + return; + } + int id2 = read_int(buff); + if (id2 < 0) return; + auto it2 = nets.find(id2); + if (it2 == nets.end()) + { + fprintf(stdout, "net id does not exist\n"); + return; + } + scanf("%64s", buff); + it->second->net->send_msg(MsgText(id, buff), it2->second->listen_addr); + }; + + auto cmd_help = [](char *) { + fprintf(stdout, + "add <node-id> <port> -- start a node (create a PeerNetwork instance)\n" + "addpeer <node-id> <peer-id> -- add a peer to a given node\n" + "delpeer <node-id> <peer-id> -- add a peer to a given node\n" + "del <node-id> -- remove a node (destroy a PeerNetwork instance)\n" + "msg <node-id> <peer-id> <msg> -- send a text message to a node\n" + "ls -- list all node ids\n" + "exit -- quit the program\n" + "help -- show this info\n" + ); + }; + + cmd_map.insert(std::make_pair("add", cmd_add)); + cmd_map.insert(std::make_pair("addpeer", cmd_addpeer)); + cmd_map.insert(std::make_pair("del", cmd_del)); + cmd_map.insert(std::make_pair("delpeer", cmd_delpeer)); + cmd_map.insert(std::make_pair("msg", cmd_msg)); + cmd_map.insert(std::make_pair("ls", cmd_ls)); + cmd_map.insert(std::make_pair("exit", cmd_exit)); + cmd_map.insert(std::make_pair("help", cmd_help)); + + for (;;) + { + fprintf(stdout, "> "); + char buff[128]; + if (scanf("%64s", buff) == EOF) break; + auto it = cmd_map.find(buff); + if (it == cmd_map.end()) + fprintf(stdout, "invalid comand \"%s\"\n", buff); + else + (it->second)(buff); + } + + return 0; +} |