diff options
author | Determinant <[email protected]> | 2019-06-27 20:33:06 -0400 |
---|---|---|
committer | Determinant <[email protected]> | 2019-06-27 20:33:06 -0400 |
commit | d15ec0b93def57e5f3832f429a3b948e86a62887 (patch) | |
tree | 407070a72607b6b0f0b1742e3676c731fb3f7e06 /include | |
parent | 85552ce1b0bc997f58341f21ab8bbcf7d937ab4b (diff) |
finish p2p & TLS integration and testing
Diffstat (limited to 'include')
-rw-r--r-- | include/salticidae/conn.h | 8 | ||||
-rw-r--r-- | include/salticidae/crypto.h | 55 | ||||
-rw-r--r-- | include/salticidae/event.h | 4 | ||||
-rw-r--r-- | include/salticidae/network.h | 277 | ||||
-rw-r--r-- | include/salticidae/stream.h | 70 |
5 files changed, 267 insertions, 147 deletions
diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h index 87966ac..e39d31d 100644 --- a/include/salticidae/conn.h +++ b/include/salticidae/conn.h @@ -195,7 +195,12 @@ class ConnPool { if (enable_tls && connected) { conn->worker->get_tcall()->async_call([this, conn, ret](ThreadCall::Handle &) { - if (ret) conn->recv_data_func = Conn::_recv_data_tls; + if (ret) + { + conn->recv_data_func = Conn::_recv_data_tls; + conn->ev_socket.del(); + conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE); + } else worker_terminate(conn); }); } @@ -262,6 +267,7 @@ class ConnPool { conn->send_data_func = Conn::_send_data; conn->recv_data_func = Conn::_recv_data; enable_send_buffer(conn, client_fd); + cpool->on_setup(conn); cpool->update_conn(conn, true); } assert(conn->fd != -1); diff --git a/include/salticidae/crypto.h b/include/salticidae/crypto.h index d7231a4..fe4de4f 100644 --- a/include/salticidae/crypto.h +++ b/include/salticidae/crypto.h @@ -31,6 +31,7 @@ #ifdef __cplusplus #include <openssl/sha.h> #include <openssl/ssl.h> +#include <openssl/bn.h> namespace salticidae { @@ -137,10 +138,30 @@ static int _skip_CA_check(int, X509_STORE_CTX *) { class PKey { EVP_PKEY *key; friend class TLSContext; + friend class X509; public: PKey(EVP_PKEY *key): key(key) {} PKey(const PKey &) = delete; PKey(PKey &&other): key(other.key) { other.key = nullptr; } + + static PKey create_privkey_rsa(size_t nbits = 2048) { + auto key = EVP_PKEY_new(); + if (key == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + auto e = BN_new(); + BN_set_word(e, 17); + auto rsa = RSA_new(); + auto ret = RSA_generate_key_ex(rsa, nbits, e, nullptr); + BN_free(e); + if (!ret) + { + RSA_free(rsa); + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + } + EVP_PKEY_set1_RSA(key, rsa); + RSA_free(rsa); + return PKey(key); + } static PKey create_privkey_from_pem_file(std::string pem_fname, std::string *passwd = nullptr) { FILE *fp = fopen(pem_fname.c_str(), "r"); @@ -193,6 +214,14 @@ class PKey { return std::move(res); } + void save_privkey_to_file(const std::string &fname) { + FILE *fp = fopen(fname.c_str(), "wb"); + auto ret = PEM_write_PrivateKey(fp, key, nullptr, nullptr, 0, nullptr, nullptr); + fclose(fp); + if (!ret) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + } + ~PKey() { if (key) EVP_PKEY_free(key); } }; @@ -203,6 +232,24 @@ class X509 { X509(::X509 *x509): x509(x509) {} X509(const X509 &) = delete; X509(X509 &&other): x509(other.x509) { other.x509 = nullptr; } + + static X509 create_self_signed_from_pubkey(const PKey &pkey, const char *country = "US", const char *common_name = "localhost") { + auto x509 = X509_new(); + ASN1_INTEGER_set(X509_get_serialNumber(x509), 1); + X509_set_pubkey(x509, pkey.key); + X509_gmtime_adj(X509_get_notBefore(x509), 0); + X509_gmtime_adj(X509_get_notAfter(x509), 0); + auto name = X509_get_subject_name(x509); + X509_NAME_add_entry_by_txt(name, "C", MBSTRING_ASC, + (unsigned char *)country, -1, -1, 0); + X509_NAME_add_entry_by_txt(name, "O", MBSTRING_ASC, + (unsigned char *)"libsalticidae", -1, -1, 0); + X509_NAME_add_entry_by_txt(name, "CN", MBSTRING_ASC, + (unsigned char *)common_name, -1, -1, 0); + X509_set_issuer_name(x509, name); + X509_sign(x509, pkey.key, EVP_sha1()); + return X509(x509); + } static X509 create_from_pem_file(std::string pem_fname, std::string *passwd = nullptr) { FILE *fp = fopen(pem_fname.c_str(), "r"); @@ -251,6 +298,14 @@ class X509 { return std::move(res); } + void save_to_file(const std::string &fname) { + FILE *fp = fopen(fname.c_str(), "wb"); + auto ret = PEM_write_X509(fp, x509); + fclose(fp); + if (!ret) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + } + ~X509() { if (x509) X509_free(x509); } }; diff --git a/include/salticidae/event.h b/include/salticidae/event.h index ad78a6e..a7ea209 100644 --- a/include/salticidae/event.h +++ b/include/salticidae/event.h @@ -565,7 +565,7 @@ class MPSCQueueEventDriven: public MPSCQueue<T> { // memory barrier here, so any load/store in enqueue must be finialized if (wait_sig.exchange(false, std::memory_order_acq_rel)) { - SALTICIDAE_LOG_DEBUG("mpsc notify"); + //SALTICIDAE_LOG_DEBUG("mpsc notify"); write(fd, &dummy, 8); } return true; @@ -616,7 +616,7 @@ class MPMCQueueEventDriven: public MPMCQueue<T> { // memory barrier here, so any load/store in enqueue must be finialized if (wait_sig.exchange(false, std::memory_order_acq_rel)) { - SALTICIDAE_LOG_DEBUG("mpsc notify"); + //SALTICIDAE_LOG_DEBUG("mpmc notify"); write(fd, &dummy, 8); } return true; diff --git a/include/salticidae/network.h b/include/salticidae/network.h index 975084f..d3f3bae 100644 --- a/include/salticidae/network.h +++ b/include/salticidae/network.h @@ -261,7 +261,6 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { public: using MsgNet = MsgNetwork<OpcodeType>; using Msg = typename MsgNet::Msg; - using unknown_callback_t = std::function<void(const NetAddr &)>; enum IdentityMode { ADDR_BASED, @@ -282,6 +281,9 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { public: Conn(): MsgNet::Conn(), peer(nullptr) {} + NetAddr get_peer_addr() { + return peer ? peer->peer_addr : NetAddr(); + } PeerNetwork *get_net() { return static_cast<PeerNetwork *>(ConnPool::Conn::get_pool()); @@ -295,11 +297,13 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { }; using conn_t = ArcObj<Conn>; + using peer_callback_t = std::function<void(const conn_t &peer_conn, bool connected)>; + using unknown_peer_callback_t = std::function<void(const NetAddr &claimed_addr)>; private: - struct Peer { + class Peer { + friend PeerNetwork; /** connection addr, may be different due to passive mode */ - uint256_t nonce; uint256_t peer_id; NetAddr peer_addr; /** the underlying connection, may be invalid when connected = false */ @@ -316,17 +320,17 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { double ping_period; Peer() = delete; - Peer(conn_t conn, conn_t inbound_conn, conn_t outbound_conn, const PeerNetwork *pn): - conn(conn), - inbound_conn(inbound_conn), - outbound_conn(outbound_conn), - ev_ping_timer( - TimerEvent(pn->disp_ec, std::bind(&Peer::ping_timer, this, _1))), - connected(false), - outbound_handshake(false), - inbound_handshake(false), - ping_period(pn->ping_period) {} - ~Peer() {} + Peer(const uint256_t &peer_id, conn_t conn, conn_t inbound_conn, conn_t outbound_conn, const PeerNetwork *pn): + peer_id(peer_id), + conn(conn), + inbound_conn(inbound_conn), + outbound_conn(outbound_conn), + ev_ping_timer( + TimerEvent(pn->disp_ec, std::bind(&Peer::ping_timer, this, _1))), + connected(false), + outbound_handshake(false), + inbound_handshake(false), + ping_period(pn->ping_period) {} Peer &operator=(const Peer &) = delete; Peer(const Peer &) = delete; @@ -337,12 +341,15 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { if (ev_ping_timer) ev_ping_timer.del(); } + public: + ~Peer() {} }; std::unordered_map<NetAddr, conn_t> pending_peers; std::unordered_map<NetAddr, uint256_t> known_peers; std::unordered_map<uint256_t, BoxObj<Peer>> pid2peer; - unknown_callback_t unknown_peer_cb; + peer_callback_t peer_cb; + unknown_peer_callback_t unknown_peer_cb; const IdentityMode id_mode; double retry_conn_delay; @@ -350,41 +357,35 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { double conn_timeout; NetAddr listen_addr; bool allow_unknown_peer; - uint256_t my_pname; uint256_t my_nonce; struct MsgPing { static const OpcodeType opcode; DataStream serialized; - uint256_t pname; + NetAddr claimed_addr; uint256_t nonce; - uint256_t peer_id; - MsgPing() { serialized << false; } - MsgPing(const uint256_t &_pname, const uint256_t &_nonce) { - serialized << true << _pname << _nonce; + MsgPing() { serialized << (uint8_t)false; } + MsgPing(const NetAddr &_claimed_addr, const uint256_t &_nonce) { + serialized << (uint8_t)true << _claimed_addr << _nonce; } MsgPing(DataStream &&s) { uint8_t flag; s >> flag; if (flag) - { - s >> pname >> nonce; - DataStream tmp; - tmp << pname << nonce; - peer_id = tmp.get_hash(); - } + s >> claimed_addr >> nonce; } }; struct MsgPong: public MsgPing { static const OpcodeType opcode; MsgPong(): MsgPing() {} - MsgPong(const uint256_t &_pname, const uint256_t _nonce): MsgPing(_pname, _nonce) {} + MsgPong(const NetAddr &_claimed_addr, const uint256_t _nonce): + MsgPing(_claimed_addr, _nonce) {} MsgPong(DataStream &&s): MsgPing(std::move(s)) {} }; - void msg_ping(MsgPing &&msg, const conn_t &conn); - void msg_pong(MsgPong &&msg, const conn_t &conn); + void ping_handler(MsgPing &&msg, const conn_t &conn); + void pong_handler(MsgPong &&msg, const conn_t &conn); void _ping_msg_cb(const conn_t &conn, uint16_t port); void _pong_msg_cb(const conn_t &conn, uint16_t port); bool check_handshake(Peer *peer); @@ -400,6 +401,14 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { } void on_setup(const ConnPool::conn_t &) override; void on_teardown(const ConnPool::conn_t &) override; + uint256_t gen_peer_id(const conn_t &conn, const NetAddr &claimed_addr, const uint256_t &nonce) { + DataStream tmp; + if (!this->enable_tls || id_mode == ADDR_BASED) + tmp << nonce << claimed_addr; + else + tmp << conn->get_peer_cert()->get_der(); + return tmp.get_hash(); + } public: @@ -420,7 +429,7 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { _ping_period(30), _conn_timeout(180), _allow_unknown_peer(false), - _id_mode(ADDR_BASED) {} + _id_mode(CERT_BASED) {} Config &retry_conn_delay(double x) { @@ -450,14 +459,14 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { }; PeerNetwork(const EventContext &ec, const Config &config): - MsgNet(ec, config), - id_mode(config._id_mode), - retry_conn_delay(config._retry_conn_delay), - ping_period(config._ping_period), - conn_timeout(config._conn_timeout), - allow_unknown_peer(config._allow_unknown_peer) { - this->reg_handler(generic_bind(&PeerNetwork::msg_ping, this, _1, _2)); - this->reg_handler(generic_bind(&PeerNetwork::msg_pong, this, _1, _2)); + MsgNet(ec, config), + id_mode(config._id_mode), + retry_conn_delay(config._retry_conn_delay), + ping_period(config._ping_period), + conn_timeout(config._conn_timeout), + allow_unknown_peer(config._allow_unknown_peer) { + this->reg_handler(generic_bind(&PeerNetwork::ping_handler, this, _1, _2)); + this->reg_handler(generic_bind(&PeerNetwork::pong_handler, this, _1, _2)); } virtual ~PeerNetwork() { this->stop(); } @@ -481,6 +490,8 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { conn_t connect(const NetAddr &addr) = delete; template<typename Func> void reg_unknown_peer_handler(Func cb) { unknown_peer_cb = cb; } + template<typename Func> + void reg_peer_handler(Func cb) { peer_cb = cb; } }; /* this callback is run by a worker */ @@ -577,43 +588,63 @@ void PeerNetwork<O, _, __>::on_setup(const ConnPool::conn_t &_conn) { auto &ev_timeout = conn->ev_timeout; conn->ev_retry_timer.del(); assert(!ev_timeout); - ev_timeout = TimerEvent(worker->get_ec(), [worker, conn](TimerEvent &) { + ev_timeout = TimerEvent(worker->get_ec(), [listen_addr=this->listen_addr, worker, conn](TimerEvent &) { try { - SALTICIDAE_LOG_INFO("peer ping-pong timeout"); - conn->get_net()->worker_terminate(conn); + SALTICIDAE_LOG_INFO("peer ping-pong timeout %s <-> %s", + std::string(listen_addr).c_str(), + std::string(conn->get_peer_addr()).c_str()); + //conn->get_net()->worker_terminate(conn); } catch (...) { worker->error_callback(std::current_exception()); } }); /* the initial ping-pong to set up the connection */ tcall_reset_timeout(worker, conn, conn_timeout); pending_peers[conn->get_addr()] = conn; if (conn->get_mode() == Conn::ConnMode::ACTIVE) - send_msg(MsgPing(my_pname, my_nonce), conn); + send_msg(MsgPing(listen_addr, my_nonce), conn); } template<typename O, O _, O __> void PeerNetwork<O, _, __>::on_teardown(const ConnPool::conn_t &_conn) { MsgNet::on_teardown(_conn); auto conn = static_pointer_cast<Conn>(_conn); - const auto &addr = conn->get_addr(); + auto addr = conn->get_addr(); conn->ev_retry_timer.clear(); conn->ev_timeout.clear(); pending_peers.erase(addr); + SALTICIDAE_LOG_INFO("connection lost: %s", std::string(*conn).c_str()); auto p = conn->peer; + if (p) addr = p->peer_addr; + TimerEvent retry_timer(this->disp_ec, [this, addr](TimerEvent &) { + try { + start_active_conn(addr); + } catch (...) { this->disp_error_cb(std::current_exception()); } + }); + auto it = known_peers.find(addr); + if (it == known_peers.end()) return; + pending_peers[addr] = conn; if (p) { if (conn != p->conn) return; + p->inbound_conn = nullptr; + p->outbound_conn = nullptr; p->ev_ping_timer.del(); p->connected = false; + p->outbound_handshake = false; + p->inbound_handshake = false; known_peers[p->peer_addr] = uint256_t(); - // try to reconnect - conn->ev_retry_timer = TimerEvent(this->disp_ec, [this, addr](TimerEvent &) { - try { - start_active_conn(addr); - } catch (...) { this->disp_error_cb(std::current_exception()); } + pid2peer.erase(p->peer_id); + this->user_tcall->async_call([this, conn](ThreadCall::Handle &) { + if (peer_cb) peer_cb(conn, false); }); + conn->ev_retry_timer = std::move(retry_timer); + conn->ev_retry_timer.add(gen_conn_timeout()); + } + else + { + if (!it->second.is_null()) return; + conn->ev_retry_timer = std::move(retry_timer); conn->ev_retry_timer.add(gen_conn_timeout()); } - SALTICIDAE_LOG_INFO("connection lost: %s", std::string(*conn).c_str()); } template<typename O, O _, O __> @@ -666,13 +697,16 @@ bool PeerNetwork<O, _, __>::check_handshake(Peer *p) { color_begin = TTY_COLOR_BLUE; color_end = TTY_COLOR_RESET; } - SALTICIDAE_LOG_INFO("%sPeerNetwork: established connection with %s <-> %s via %s", + SALTICIDAE_LOG_INFO("%sPeerNetwork: established connection %s <-> %s via %s", color_begin, std::string(listen_addr).c_str(), std::string(p->peer_addr).c_str(), std::string(*(p->conn)).c_str(), color_end); } + this->user_tcall->async_call([this, conn=p->conn](ThreadCall::Handle &) { + if (peer_cb) peer_cb(conn, true); + }); return true; } @@ -684,52 +718,67 @@ void PeerNetwork<O, _, __>::start_active_conn(const NetAddr &addr) { template<typename O, O _, O __> inline typename PeerNetwork<O, _, __>::conn_t PeerNetwork<O, _, __>::_get_peer_conn(const NetAddr &addr) const { - auto it = pending_peers.find(addr); - if (it == pending_peers.end()) + auto it = known_peers.find(addr); + if (it == known_peers.end()) throw PeerNetworkError(SALTI_ERROR_PEER_NOT_EXIST); - return it->second; + auto it2 = pid2peer.find(it->second); + assert(it2 != pid2peer.end()); + return it2->second->conn; } /* end: functions invoked by the dispatcher */ /* begin: functions invoked by the user loop */ template<typename O, O _, O __> -void PeerNetwork<O, _, __>::msg_ping(MsgPing &&msg, const conn_t &conn) { +void PeerNetwork<O, _, __>::ping_handler(MsgPing &&msg, const conn_t &conn) { this->disp_tcall->async_call([this, conn, msg=std::move(msg)](ThreadCall::Handle &) { try { auto conn_mode = conn->get_mode(); if (conn_mode == ConnPool::Conn::DEAD) return; - if (!msg.peer_id.is_null()) + if (!msg.claimed_addr.is_null()) { + auto peer_id = gen_peer_id(conn, msg.claimed_addr, msg.nonce); if (conn_mode == Conn::ConnMode::PASSIVE) { - send_msg(MsgPong(my_pname, my_nonce), conn); + if (!known_peers.count(msg.claimed_addr)) + { + this->user_tcall->async_call([this, addr=msg.claimed_addr](ThreadCall::Handle &) { + if (unknown_peer_cb) unknown_peer_cb(addr); + }); + this->disp_terminate(conn); + return; + } SALTICIDAE_LOG_INFO("%s inbound handshake from %s", std::string(listen_addr).c_str(), std::string(*conn).c_str()); - auto it = pid2peer.find(msg.peer_id); + send_msg(MsgPong(listen_addr, my_nonce), conn); + auto it = pid2peer.find(peer_id); if (it != pid2peer.end()) { + auto p = it->second.get(); + if (p->connected) + { + //conn->get_net()->disp_terminate(conn); + return; + } + auto &old_conn = p->inbound_conn; + if (old_conn && old_conn->get_mode() != Conn::ConnMode::DEAD) + { + SALTICIDAE_LOG_DEBUG("%s terminating old connection %s", + std::string(listen_addr).c_str(), + std::string(*old_conn).c_str()); + old_conn->peer = nullptr; + old_conn->get_net()->disp_terminate(old_conn); + } + old_conn = conn; if (msg.nonce < my_nonce) { - auto p = it->second.get(); - auto &old_conn = p->inbound_conn; - if (old_conn && old_conn->get_mode() != Conn::ConnMode::DEAD) - { - SALTICIDAE_LOG_INFO("%s terminating old connection %s", - std::string(listen_addr).c_str(), - std::string(*old_conn).c_str()); - old_conn->peer = nullptr; - old_conn->get_net()->disp_terminate(old_conn); - } - old_conn = conn; p->conn = conn; } } else { - it = pid2peer.insert(std::make_pair( - msg.peer_id, - new Peer(conn, conn, nullptr, this))).first; + it = pid2peer.insert(std::make_pair(peer_id, + new Peer(peer_id, conn, conn, nullptr, this))).first; } auto p = it->second.get(); p->inbound_handshake = true; @@ -749,39 +798,45 @@ void PeerNetwork<O, _, __>::msg_ping(MsgPing &&msg, const conn_t &conn) { } template<typename O, O _, O __> -void PeerNetwork<O, _, __>::msg_pong(MsgPong &&msg, const conn_t &conn) { +void PeerNetwork<O, _, __>::pong_handler(MsgPong &&msg, const conn_t &conn) { this->disp_tcall->async_call([this, conn, msg=std::move(msg)](ThreadCall::Handle &) { try { auto conn_mode = conn->get_mode(); if (conn_mode == ConnPool::Conn::DEAD) return; - if (!msg.peer_id.is_null()) + if (!msg.claimed_addr.is_null()) { + auto peer_id = gen_peer_id(conn, msg.claimed_addr, msg.nonce); if (conn_mode == Conn::ConnMode::ACTIVE) { SALTICIDAE_LOG_INFO("%s outbound handshake to %s", std::string(listen_addr).c_str(), std::string(*conn).c_str()); - auto it = pid2peer.find(msg.peer_id); + auto it = pid2peer.find(peer_id); if (it != pid2peer.end()) { + auto p = it->second.get(); + if (p->connected) + { + conn->get_net()->disp_terminate(conn); + return; + } + auto &old_conn = p->outbound_conn; + if (old_conn && old_conn->get_mode() != Conn::ConnMode::DEAD) + { + SALTICIDAE_LOG_DEBUG("%s terminating old connection %s", + std::string(listen_addr).c_str(), + std::string(*old_conn).c_str()); + old_conn->peer = nullptr; + old_conn->get_net()->disp_terminate(old_conn); + } + old_conn = conn; if (my_nonce < msg.nonce) { - auto p = it->second.get(); - auto &old_conn = p->outbound_conn; - if (old_conn && old_conn->get_mode() != Conn::ConnMode::DEAD) - { - SALTICIDAE_LOG_INFO("%s terminating old connection %s", - std::string(listen_addr).c_str(), - std::string(*old_conn).c_str()); - old_conn->peer = nullptr; - old_conn->get_net()->disp_terminate(old_conn); - } - old_conn = conn; p->conn = conn; } else { - SALTICIDAE_LOG_INFO("%s terminating low connection %s", + SALTICIDAE_LOG_DEBUG("%s terminating low connection %s", std::string(listen_addr).c_str(), std::string(*conn).c_str()); conn->get_net()->disp_terminate(conn); @@ -789,9 +844,8 @@ void PeerNetwork<O, _, __>::msg_pong(MsgPong &&msg, const conn_t &conn) { } else { - it = pid2peer.insert(std::make_pair( - msg.peer_id, - new Peer(conn, nullptr, conn, this))).first; + it = pid2peer.insert(std::make_pair(peer_id, + new Peer(peer_id, conn, nullptr, conn, this))).first; } auto p = it->second.get(); p->outbound_handshake = true; @@ -806,7 +860,11 @@ void PeerNetwork<O, _, __>::msg_pong(MsgPong &&msg, const conn_t &conn) { else { auto p = conn->peer; - if (!p) return; + if (!p) + { + SALTICIDAE_LOG_WARN("unexpected poing mesage"); + return; + } p->pong_msg_ok = true; if (p->ping_timer_ok) { @@ -826,18 +884,6 @@ void PeerNetwork<O, _, __>::listen(NetAddr _listen_addr) { try { MsgNet::_listen(_listen_addr); listen_addr = _listen_addr; - DataStream pid; - if (id_mode == CERT_BASED) - { - if (!this->enable_tls) - throw PeerNetworkError(SALTI_ERROR_TLS_LOAD_CERT); - pid << this->tls_cert->get_der(); - } - else - { - pid << listen_addr; - } - my_pname = pid.get_hash(); uint8_t rand_bytes[32]; if (!RAND_bytes(rand_bytes, 32)) throw PeerNetworkError(SALTI_ERROR_RAND_SOURCE); @@ -868,14 +914,25 @@ template<typename O, O _, O __> void PeerNetwork<O, _, __>::del_peer(const NetAddr &addr) { this->disp_tcall->async_call([this, addr](ThreadCall::Handle &) { try { - if (!known_peers.erase(addr)) + auto it = known_peers.find(addr); + if (it == known_peers.end()) throw PeerNetworkError(SALTI_ERROR_PEER_NOT_EXIST); - auto it = pending_peers.find(addr); - assert(it != pending_peers.end()); - auto conn = it->second; - auto p = conn->peer; - if (p) pid2peer.erase(p->peer_id); - this->disp_terminate(conn); + auto peer_id = it->second; + known_peers.erase(it); + auto it2 = pending_peers.find(addr); + if (it2 != pending_peers.end()) + { + if (!it2->second->peer) + this->disp_terminate(it2->second); + pending_peers.erase(it2); + } + auto it3 = pid2peer.find(peer_id); + if (it3 != pid2peer.end()) + { + auto p = it3->second.get(); + this->disp_terminate(p->conn); + pid2peer.erase(it3); + } } catch (const PeerNetworkError &) { this->recoverable_error(std::current_exception()); } catch (...) { this->disp_error_cb(std::current_exception()); } diff --git a/include/salticidae/stream.h b/include/salticidae/stream.h index dc47792..2efc532 100644 --- a/include/salticidae/stream.h +++ b/include/salticidae/stream.h @@ -198,8 +198,40 @@ class DataStream { inline uint256_t get_hash() const; }; +class Serializable { + public: + virtual ~Serializable() = default; + virtual void serialize(DataStream &s) const = 0; + virtual void unserialize(DataStream &s) = 0; + + virtual void from_bytes(const bytearray_t &raw_bytes) { + DataStream s(raw_bytes); + s >> *this; + } + + virtual void from_bytes(bytearray_t &&raw_bytes) { + DataStream s(std::move(raw_bytes)); + s >> *this; + } + + + virtual void from_hex(const std::string &hex_str) { + DataStream s; + s.load_hex(hex_str); + s >> *this; + } + + bytearray_t to_bytes() const { + DataStream s; + s << *this; + return std::move(s); + } + + inline std::string to_hex() const; +}; + template<size_t N, typename T = uint64_t> -class Blob { +class Blob: public Serializable { using _impl_type = T; static const size_t bit_per_datum = sizeof(_impl_type) * 8; static_assert(!(N % bit_per_datum), "N must be divisible by bit_per_datum"); @@ -252,7 +284,7 @@ class Blob { size_t cheap_hash() const { return *data; } - void serialize(DataStream &s) const { + void serialize(DataStream &s) const override { if (loaded) { for (const _impl_type *ptr = data; ptr < data + _len; ptr++) @@ -265,7 +297,7 @@ class Blob { } } - void unserialize(DataStream &s) { + void unserialize(DataStream &s) override { for (_impl_type *ptr = data; ptr < data + _len; ptr++) { _impl_type x; @@ -424,37 +456,7 @@ inline bytearray_t from_hex(const std::string &hex_str) { return std::move(s); } -class Serializable { - public: - virtual ~Serializable() = default; - virtual void serialize(DataStream &s) const = 0; - virtual void unserialize(DataStream &s) = 0; - - virtual void from_bytes(const bytearray_t &raw_bytes) { - DataStream s(raw_bytes); - s >> *this; - } - - virtual void from_bytes(bytearray_t &&raw_bytes) { - DataStream s(std::move(raw_bytes)); - s >> *this; - } - - - virtual void from_hex(const std::string &hex_str) { - DataStream s; - s.load_hex(hex_str); - s >> *this; - } - - bytearray_t to_bytes() const { - DataStream s; - s << *this; - return std::move(s); - } - - std::string to_hex() const { return get_hex(*this); } -}; +inline std::string Serializable::to_hex() const { return get_hex(*this); } } |