diff options
Diffstat (limited to 'include/salticidae/network.h')
-rw-r--r-- | include/salticidae/network.h | 277 |
1 files changed, 167 insertions, 110 deletions
diff --git a/include/salticidae/network.h b/include/salticidae/network.h index 975084f..d3f3bae 100644 --- a/include/salticidae/network.h +++ b/include/salticidae/network.h @@ -261,7 +261,6 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { public: using MsgNet = MsgNetwork<OpcodeType>; using Msg = typename MsgNet::Msg; - using unknown_callback_t = std::function<void(const NetAddr &)>; enum IdentityMode { ADDR_BASED, @@ -282,6 +281,9 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { public: Conn(): MsgNet::Conn(), peer(nullptr) {} + NetAddr get_peer_addr() { + return peer ? peer->peer_addr : NetAddr(); + } PeerNetwork *get_net() { return static_cast<PeerNetwork *>(ConnPool::Conn::get_pool()); @@ -295,11 +297,13 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { }; using conn_t = ArcObj<Conn>; + using peer_callback_t = std::function<void(const conn_t &peer_conn, bool connected)>; + using unknown_peer_callback_t = std::function<void(const NetAddr &claimed_addr)>; private: - struct Peer { + class Peer { + friend PeerNetwork; /** connection addr, may be different due to passive mode */ - uint256_t nonce; uint256_t peer_id; NetAddr peer_addr; /** the underlying connection, may be invalid when connected = false */ @@ -316,17 +320,17 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { double ping_period; Peer() = delete; - Peer(conn_t conn, conn_t inbound_conn, conn_t outbound_conn, const PeerNetwork *pn): - conn(conn), - inbound_conn(inbound_conn), - outbound_conn(outbound_conn), - ev_ping_timer( - TimerEvent(pn->disp_ec, std::bind(&Peer::ping_timer, this, _1))), - connected(false), - outbound_handshake(false), - inbound_handshake(false), - ping_period(pn->ping_period) {} - ~Peer() {} + Peer(const uint256_t &peer_id, conn_t conn, conn_t inbound_conn, conn_t outbound_conn, const PeerNetwork *pn): + peer_id(peer_id), + conn(conn), + inbound_conn(inbound_conn), + outbound_conn(outbound_conn), + ev_ping_timer( + TimerEvent(pn->disp_ec, std::bind(&Peer::ping_timer, this, _1))), + connected(false), + outbound_handshake(false), + inbound_handshake(false), + ping_period(pn->ping_period) {} Peer &operator=(const Peer &) = delete; Peer(const Peer &) = delete; @@ -337,12 +341,15 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { if (ev_ping_timer) ev_ping_timer.del(); } + public: + ~Peer() {} }; std::unordered_map<NetAddr, conn_t> pending_peers; std::unordered_map<NetAddr, uint256_t> known_peers; std::unordered_map<uint256_t, BoxObj<Peer>> pid2peer; - unknown_callback_t unknown_peer_cb; + peer_callback_t peer_cb; + unknown_peer_callback_t unknown_peer_cb; const IdentityMode id_mode; double retry_conn_delay; @@ -350,41 +357,35 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { double conn_timeout; NetAddr listen_addr; bool allow_unknown_peer; - uint256_t my_pname; uint256_t my_nonce; struct MsgPing { static const OpcodeType opcode; DataStream serialized; - uint256_t pname; + NetAddr claimed_addr; uint256_t nonce; - uint256_t peer_id; - MsgPing() { serialized << false; } - MsgPing(const uint256_t &_pname, const uint256_t &_nonce) { - serialized << true << _pname << _nonce; + MsgPing() { serialized << (uint8_t)false; } + MsgPing(const NetAddr &_claimed_addr, const uint256_t &_nonce) { + serialized << (uint8_t)true << _claimed_addr << _nonce; } MsgPing(DataStream &&s) { uint8_t flag; s >> flag; if (flag) - { - s >> pname >> nonce; - DataStream tmp; - tmp << pname << nonce; - peer_id = tmp.get_hash(); - } + s >> claimed_addr >> nonce; } }; struct MsgPong: public MsgPing { static const OpcodeType opcode; MsgPong(): MsgPing() {} - MsgPong(const uint256_t &_pname, const uint256_t _nonce): MsgPing(_pname, _nonce) {} + MsgPong(const NetAddr &_claimed_addr, const uint256_t _nonce): + MsgPing(_claimed_addr, _nonce) {} MsgPong(DataStream &&s): MsgPing(std::move(s)) {} }; - void msg_ping(MsgPing &&msg, const conn_t &conn); - void msg_pong(MsgPong &&msg, const conn_t &conn); + void ping_handler(MsgPing &&msg, const conn_t &conn); + void pong_handler(MsgPong &&msg, const conn_t &conn); void _ping_msg_cb(const conn_t &conn, uint16_t port); void _pong_msg_cb(const conn_t &conn, uint16_t port); bool check_handshake(Peer *peer); @@ -400,6 +401,14 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { } void on_setup(const ConnPool::conn_t &) override; void on_teardown(const ConnPool::conn_t &) override; + uint256_t gen_peer_id(const conn_t &conn, const NetAddr &claimed_addr, const uint256_t &nonce) { + DataStream tmp; + if (!this->enable_tls || id_mode == ADDR_BASED) + tmp << nonce << claimed_addr; + else + tmp << conn->get_peer_cert()->get_der(); + return tmp.get_hash(); + } public: @@ -420,7 +429,7 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { _ping_period(30), _conn_timeout(180), _allow_unknown_peer(false), - _id_mode(ADDR_BASED) {} + _id_mode(CERT_BASED) {} Config &retry_conn_delay(double x) { @@ -450,14 +459,14 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { }; PeerNetwork(const EventContext &ec, const Config &config): - MsgNet(ec, config), - id_mode(config._id_mode), - retry_conn_delay(config._retry_conn_delay), - ping_period(config._ping_period), - conn_timeout(config._conn_timeout), - allow_unknown_peer(config._allow_unknown_peer) { - this->reg_handler(generic_bind(&PeerNetwork::msg_ping, this, _1, _2)); - this->reg_handler(generic_bind(&PeerNetwork::msg_pong, this, _1, _2)); + MsgNet(ec, config), + id_mode(config._id_mode), + retry_conn_delay(config._retry_conn_delay), + ping_period(config._ping_period), + conn_timeout(config._conn_timeout), + allow_unknown_peer(config._allow_unknown_peer) { + this->reg_handler(generic_bind(&PeerNetwork::ping_handler, this, _1, _2)); + this->reg_handler(generic_bind(&PeerNetwork::pong_handler, this, _1, _2)); } virtual ~PeerNetwork() { this->stop(); } @@ -481,6 +490,8 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { conn_t connect(const NetAddr &addr) = delete; template<typename Func> void reg_unknown_peer_handler(Func cb) { unknown_peer_cb = cb; } + template<typename Func> + void reg_peer_handler(Func cb) { peer_cb = cb; } }; /* this callback is run by a worker */ @@ -577,43 +588,63 @@ void PeerNetwork<O, _, __>::on_setup(const ConnPool::conn_t &_conn) { auto &ev_timeout = conn->ev_timeout; conn->ev_retry_timer.del(); assert(!ev_timeout); - ev_timeout = TimerEvent(worker->get_ec(), [worker, conn](TimerEvent &) { + ev_timeout = TimerEvent(worker->get_ec(), [listen_addr=this->listen_addr, worker, conn](TimerEvent &) { try { - SALTICIDAE_LOG_INFO("peer ping-pong timeout"); - conn->get_net()->worker_terminate(conn); + SALTICIDAE_LOG_INFO("peer ping-pong timeout %s <-> %s", + std::string(listen_addr).c_str(), + std::string(conn->get_peer_addr()).c_str()); + //conn->get_net()->worker_terminate(conn); } catch (...) { worker->error_callback(std::current_exception()); } }); /* the initial ping-pong to set up the connection */ tcall_reset_timeout(worker, conn, conn_timeout); pending_peers[conn->get_addr()] = conn; if (conn->get_mode() == Conn::ConnMode::ACTIVE) - send_msg(MsgPing(my_pname, my_nonce), conn); + send_msg(MsgPing(listen_addr, my_nonce), conn); } template<typename O, O _, O __> void PeerNetwork<O, _, __>::on_teardown(const ConnPool::conn_t &_conn) { MsgNet::on_teardown(_conn); auto conn = static_pointer_cast<Conn>(_conn); - const auto &addr = conn->get_addr(); + auto addr = conn->get_addr(); conn->ev_retry_timer.clear(); conn->ev_timeout.clear(); pending_peers.erase(addr); + SALTICIDAE_LOG_INFO("connection lost: %s", std::string(*conn).c_str()); auto p = conn->peer; + if (p) addr = p->peer_addr; + TimerEvent retry_timer(this->disp_ec, [this, addr](TimerEvent &) { + try { + start_active_conn(addr); + } catch (...) { this->disp_error_cb(std::current_exception()); } + }); + auto it = known_peers.find(addr); + if (it == known_peers.end()) return; + pending_peers[addr] = conn; if (p) { if (conn != p->conn) return; + p->inbound_conn = nullptr; + p->outbound_conn = nullptr; p->ev_ping_timer.del(); p->connected = false; + p->outbound_handshake = false; + p->inbound_handshake = false; known_peers[p->peer_addr] = uint256_t(); - // try to reconnect - conn->ev_retry_timer = TimerEvent(this->disp_ec, [this, addr](TimerEvent &) { - try { - start_active_conn(addr); - } catch (...) { this->disp_error_cb(std::current_exception()); } + pid2peer.erase(p->peer_id); + this->user_tcall->async_call([this, conn](ThreadCall::Handle &) { + if (peer_cb) peer_cb(conn, false); }); + conn->ev_retry_timer = std::move(retry_timer); + conn->ev_retry_timer.add(gen_conn_timeout()); + } + else + { + if (!it->second.is_null()) return; + conn->ev_retry_timer = std::move(retry_timer); conn->ev_retry_timer.add(gen_conn_timeout()); } - SALTICIDAE_LOG_INFO("connection lost: %s", std::string(*conn).c_str()); } template<typename O, O _, O __> @@ -666,13 +697,16 @@ bool PeerNetwork<O, _, __>::check_handshake(Peer *p) { color_begin = TTY_COLOR_BLUE; color_end = TTY_COLOR_RESET; } - SALTICIDAE_LOG_INFO("%sPeerNetwork: established connection with %s <-> %s via %s", + SALTICIDAE_LOG_INFO("%sPeerNetwork: established connection %s <-> %s via %s", color_begin, std::string(listen_addr).c_str(), std::string(p->peer_addr).c_str(), std::string(*(p->conn)).c_str(), color_end); } + this->user_tcall->async_call([this, conn=p->conn](ThreadCall::Handle &) { + if (peer_cb) peer_cb(conn, true); + }); return true; } @@ -684,52 +718,67 @@ void PeerNetwork<O, _, __>::start_active_conn(const NetAddr &addr) { template<typename O, O _, O __> inline typename PeerNetwork<O, _, __>::conn_t PeerNetwork<O, _, __>::_get_peer_conn(const NetAddr &addr) const { - auto it = pending_peers.find(addr); - if (it == pending_peers.end()) + auto it = known_peers.find(addr); + if (it == known_peers.end()) throw PeerNetworkError(SALTI_ERROR_PEER_NOT_EXIST); - return it->second; + auto it2 = pid2peer.find(it->second); + assert(it2 != pid2peer.end()); + return it2->second->conn; } /* end: functions invoked by the dispatcher */ /* begin: functions invoked by the user loop */ template<typename O, O _, O __> -void PeerNetwork<O, _, __>::msg_ping(MsgPing &&msg, const conn_t &conn) { +void PeerNetwork<O, _, __>::ping_handler(MsgPing &&msg, const conn_t &conn) { this->disp_tcall->async_call([this, conn, msg=std::move(msg)](ThreadCall::Handle &) { try { auto conn_mode = conn->get_mode(); if (conn_mode == ConnPool::Conn::DEAD) return; - if (!msg.peer_id.is_null()) + if (!msg.claimed_addr.is_null()) { + auto peer_id = gen_peer_id(conn, msg.claimed_addr, msg.nonce); if (conn_mode == Conn::ConnMode::PASSIVE) { - send_msg(MsgPong(my_pname, my_nonce), conn); + if (!known_peers.count(msg.claimed_addr)) + { + this->user_tcall->async_call([this, addr=msg.claimed_addr](ThreadCall::Handle &) { + if (unknown_peer_cb) unknown_peer_cb(addr); + }); + this->disp_terminate(conn); + return; + } SALTICIDAE_LOG_INFO("%s inbound handshake from %s", std::string(listen_addr).c_str(), std::string(*conn).c_str()); - auto it = pid2peer.find(msg.peer_id); + send_msg(MsgPong(listen_addr, my_nonce), conn); + auto it = pid2peer.find(peer_id); if (it != pid2peer.end()) { + auto p = it->second.get(); + if (p->connected) + { + //conn->get_net()->disp_terminate(conn); + return; + } + auto &old_conn = p->inbound_conn; + if (old_conn && old_conn->get_mode() != Conn::ConnMode::DEAD) + { + SALTICIDAE_LOG_DEBUG("%s terminating old connection %s", + std::string(listen_addr).c_str(), + std::string(*old_conn).c_str()); + old_conn->peer = nullptr; + old_conn->get_net()->disp_terminate(old_conn); + } + old_conn = conn; if (msg.nonce < my_nonce) { - auto p = it->second.get(); - auto &old_conn = p->inbound_conn; - if (old_conn && old_conn->get_mode() != Conn::ConnMode::DEAD) - { - SALTICIDAE_LOG_INFO("%s terminating old connection %s", - std::string(listen_addr).c_str(), - std::string(*old_conn).c_str()); - old_conn->peer = nullptr; - old_conn->get_net()->disp_terminate(old_conn); - } - old_conn = conn; p->conn = conn; } } else { - it = pid2peer.insert(std::make_pair( - msg.peer_id, - new Peer(conn, conn, nullptr, this))).first; + it = pid2peer.insert(std::make_pair(peer_id, + new Peer(peer_id, conn, conn, nullptr, this))).first; } auto p = it->second.get(); p->inbound_handshake = true; @@ -749,39 +798,45 @@ void PeerNetwork<O, _, __>::msg_ping(MsgPing &&msg, const conn_t &conn) { } template<typename O, O _, O __> -void PeerNetwork<O, _, __>::msg_pong(MsgPong &&msg, const conn_t &conn) { +void PeerNetwork<O, _, __>::pong_handler(MsgPong &&msg, const conn_t &conn) { this->disp_tcall->async_call([this, conn, msg=std::move(msg)](ThreadCall::Handle &) { try { auto conn_mode = conn->get_mode(); if (conn_mode == ConnPool::Conn::DEAD) return; - if (!msg.peer_id.is_null()) + if (!msg.claimed_addr.is_null()) { + auto peer_id = gen_peer_id(conn, msg.claimed_addr, msg.nonce); if (conn_mode == Conn::ConnMode::ACTIVE) { SALTICIDAE_LOG_INFO("%s outbound handshake to %s", std::string(listen_addr).c_str(), std::string(*conn).c_str()); - auto it = pid2peer.find(msg.peer_id); + auto it = pid2peer.find(peer_id); if (it != pid2peer.end()) { + auto p = it->second.get(); + if (p->connected) + { + conn->get_net()->disp_terminate(conn); + return; + } + auto &old_conn = p->outbound_conn; + if (old_conn && old_conn->get_mode() != Conn::ConnMode::DEAD) + { + SALTICIDAE_LOG_DEBUG("%s terminating old connection %s", + std::string(listen_addr).c_str(), + std::string(*old_conn).c_str()); + old_conn->peer = nullptr; + old_conn->get_net()->disp_terminate(old_conn); + } + old_conn = conn; if (my_nonce < msg.nonce) { - auto p = it->second.get(); - auto &old_conn = p->outbound_conn; - if (old_conn && old_conn->get_mode() != Conn::ConnMode::DEAD) - { - SALTICIDAE_LOG_INFO("%s terminating old connection %s", - std::string(listen_addr).c_str(), - std::string(*old_conn).c_str()); - old_conn->peer = nullptr; - old_conn->get_net()->disp_terminate(old_conn); - } - old_conn = conn; p->conn = conn; } else { - SALTICIDAE_LOG_INFO("%s terminating low connection %s", + SALTICIDAE_LOG_DEBUG("%s terminating low connection %s", std::string(listen_addr).c_str(), std::string(*conn).c_str()); conn->get_net()->disp_terminate(conn); @@ -789,9 +844,8 @@ void PeerNetwork<O, _, __>::msg_pong(MsgPong &&msg, const conn_t &conn) { } else { - it = pid2peer.insert(std::make_pair( - msg.peer_id, - new Peer(conn, nullptr, conn, this))).first; + it = pid2peer.insert(std::make_pair(peer_id, + new Peer(peer_id, conn, nullptr, conn, this))).first; } auto p = it->second.get(); p->outbound_handshake = true; @@ -806,7 +860,11 @@ void PeerNetwork<O, _, __>::msg_pong(MsgPong &&msg, const conn_t &conn) { else { auto p = conn->peer; - if (!p) return; + if (!p) + { + SALTICIDAE_LOG_WARN("unexpected poing mesage"); + return; + } p->pong_msg_ok = true; if (p->ping_timer_ok) { @@ -826,18 +884,6 @@ void PeerNetwork<O, _, __>::listen(NetAddr _listen_addr) { try { MsgNet::_listen(_listen_addr); listen_addr = _listen_addr; - DataStream pid; - if (id_mode == CERT_BASED) - { - if (!this->enable_tls) - throw PeerNetworkError(SALTI_ERROR_TLS_LOAD_CERT); - pid << this->tls_cert->get_der(); - } - else - { - pid << listen_addr; - } - my_pname = pid.get_hash(); uint8_t rand_bytes[32]; if (!RAND_bytes(rand_bytes, 32)) throw PeerNetworkError(SALTI_ERROR_RAND_SOURCE); @@ -868,14 +914,25 @@ template<typename O, O _, O __> void PeerNetwork<O, _, __>::del_peer(const NetAddr &addr) { this->disp_tcall->async_call([this, addr](ThreadCall::Handle &) { try { - if (!known_peers.erase(addr)) + auto it = known_peers.find(addr); + if (it == known_peers.end()) throw PeerNetworkError(SALTI_ERROR_PEER_NOT_EXIST); - auto it = pending_peers.find(addr); - assert(it != pending_peers.end()); - auto conn = it->second; - auto p = conn->peer; - if (p) pid2peer.erase(p->peer_id); - this->disp_terminate(conn); + auto peer_id = it->second; + known_peers.erase(it); + auto it2 = pending_peers.find(addr); + if (it2 != pending_peers.end()) + { + if (!it2->second->peer) + this->disp_terminate(it2->second); + pending_peers.erase(it2); + } + auto it3 = pid2peer.find(peer_id); + if (it3 != pid2peer.end()) + { + auto p = it3->second.get(); + this->disp_terminate(p->conn); + pid2peer.erase(it3); + } } catch (const PeerNetworkError &) { this->recoverable_error(std::current_exception()); } catch (...) { this->disp_error_cb(std::current_exception()); } |