diff options
-rw-r--r-- | include/salticidae/conn.h | 3 | ||||
-rw-r--r-- | include/salticidae/network.h | 91 | ||||
-rw-r--r-- | test/test_p2p.cpp | 8 |
3 files changed, 78 insertions, 24 deletions
diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h index 62825db..e19ae3a 100644 --- a/include/salticidae/conn.h +++ b/include/salticidae/conn.h @@ -160,7 +160,10 @@ class ConnPool { const size_t queue_capacity; /* owned by user loop */ + protected: BoxObj<ThreadCall> user_tcall; + + private: conn_callback_t conn_cb; error_callback_t error_cb; diff --git a/include/salticidae/network.h b/include/salticidae/network.h index 1095c89..fc33414 100644 --- a/include/salticidae/network.h +++ b/include/salticidae/network.h @@ -259,6 +259,8 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { public: using MsgNet = MsgNetwork<OpcodeType>; using Msg = typename MsgNet::Msg; + using unknown_callback_t = std::function<void(const NetAddr &)>; + enum IdentityMode { IP_BASED, IP_PORT_BASED @@ -323,13 +325,16 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { void reset_conn(conn_t conn); }; - std::unordered_map <NetAddr, BoxObj<Peer>> id2peer; + std::unordered_map<NetAddr, BoxObj<Peer>> id2peer; + std::unordered_map<NetAddr, BoxObj<Peer>> id2upeer; + unknown_callback_t unknown_peer_cb; const IdentityMode id_mode; double retry_conn_delay; double ping_period; double conn_timeout; uint16_t listen_port; + bool allow_unknown_peer; struct MsgPing { static const OpcodeType opcode; @@ -365,6 +370,7 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { void start_active_conn(const NetAddr &paddr); static void tcall_reset_timeout(ConnPool::Worker *worker, const conn_t &conn, double timeout); + Peer *get_peer(const NetAddr &id) const; protected: ConnPool::Conn *create_conn() override { return new Conn(); } @@ -379,6 +385,7 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { double _retry_conn_delay; double _ping_period; double _conn_timeout; + bool _allow_unknown_peer; IdentityMode _id_mode; public: @@ -389,6 +396,7 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { _retry_conn_delay(2), _ping_period(30), _conn_timeout(180), + _allow_unknown_peer(false), _id_mode(IP_PORT_BASED) {} @@ -411,6 +419,11 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { _id_mode = x; return *this; } + + Config &allow_unknown_peer(bool x) { + _allow_unknown_peer = x; + return *this; + } }; PeerNetwork(const EventContext &ec, const Config &config): @@ -418,7 +431,8 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { id_mode(config._id_mode), retry_conn_delay(config._retry_conn_delay), ping_period(config._ping_period), - conn_timeout(config._conn_timeout) { + conn_timeout(config._conn_timeout), + allow_unknown_peer(config._allow_unknown_peer) { this->reg_handler(generic_bind(&PeerNetwork::msg_ping, this, _1, _2)); this->reg_handler(generic_bind(&PeerNetwork::msg_pong, this, _1, _2)); } @@ -439,6 +453,8 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { void listen(NetAddr listen_addr); conn_t connect(const NetAddr &addr) = delete; + template<typename Func> + void reg_unknown_peer_handler(Func cb) { unknown_peer_cb = cb; } }; /* this callback is run by a worker */ @@ -529,9 +545,8 @@ template<typename O, O _, O __> void PeerNetwork<O, _, __>::Conn::on_teardown() { MsgNet::Conn::on_teardown(); auto pn = get_net(); - auto it = pn->id2peer.find(peer_id); - if (it == pn->id2peer.end()) return; - auto p = it->second.get(); + auto p = pn->get_peer(peer_id); + if (!p) return; if (this != p->conn.get()) return; p->ev_ping_timer.del(); p->connected = false; @@ -599,11 +614,25 @@ bool PeerNetwork<O, _, __>::check_new_conn(const conn_t &conn, uint16_t port) { conn->peer_id.ip = conn->get_addr().ip; conn->peer_id.port = port; } - auto it = id2peer.find(conn->peer_id); + const auto &id = conn->peer_id; + auto it = id2peer.find(id); if (it == id2peer.end()) - { - conn->disp_terminate(); - return true; + { /* found an unknown peer */ + const auto &addr = conn->get_addr(); + this->user_tcall->async_call([this, id](ThreadCall::Handle &) { + unknown_peer_cb(id); + }); + if (allow_unknown_peer) + { + auto it2 = id2upeer.find(id); + if (it2 == id2upeer.end()) + it = id2upeer.insert(std::make_pair(id, new Peer(addr, nullptr, this->disp_ec))).first; + } + else + { + conn->disp_terminate(); + return true; + } } auto p = it->second.get(); if (p->connected) @@ -638,7 +667,7 @@ bool PeerNetwork<O, _, __>::check_new_conn(const conn_t &conn, uint16_t port) { template<typename O, O _, O __> void PeerNetwork<O, _, __>::start_active_conn(const NetAddr &addr) { - auto p = id2peer.find(addr)->second.get(); + auto p = get_peer(addr); if (p->connected) return; auto conn = static_pointer_cast<Conn>(MsgNet::_connect(addr)); //assert(p->conn == nullptr); @@ -648,6 +677,15 @@ void PeerNetwork<O, _, __>::start_active_conn(const NetAddr &addr) { conn->peer_id.port = 0; } +template<typename O, O _, O __> +typename PeerNetwork<O, _, __>::Peer *PeerNetwork<O, _, __>::get_peer(const NetAddr &addr) const { + auto it = id2peer.find(addr); + if (it != id2peer.end()) return it->second.get(); + it = id2upeer.find(addr); + if (it != id2upeer.end()) return it->second.get(); + return nullptr; +} + template<typename OpcodeType> inline void MsgNetwork<OpcodeType>::_send_msg_dispatcher(const Msg &msg, const conn_t &conn) { bytearray_t msg_data = msg.serialize(); @@ -683,14 +721,13 @@ void PeerNetwork<O, _, __>::msg_pong(MsgPong &&msg, const conn_t &conn) { this->disp_tcall->async_call([this, conn, port](ThreadCall::Handle &) { try { if (conn->get_mode() == ConnPool::Conn::DEAD) return; - auto it = id2peer.find(conn->peer_id); - if (it == id2peer.end()) + auto p = get_peer(conn->peer_id); + if (!p) { SALTICIDAE_LOG_WARN("pong message discarded"); return; } if (check_new_conn(conn, port)) return; - auto p = it->second.get(); p->pong_msg_ok = true; if (p->ping_timer_ok) { @@ -724,7 +761,15 @@ void PeerNetwork<O, _, __>::add_peer(const NetAddr &addr) { auto it = id2peer.find(addr); if (it != id2peer.end()) throw PeerNetworkError(SALTI_ERROR_PEER_ALREADY_EXISTS); - id2peer.insert(std::make_pair(addr, new Peer(addr, nullptr, this->disp_ec))); + auto it2 = id2upeer.find(addr); + if (it2 != id2peer.end()) + { /* move to the known peer set */ + auto p = std::move(it2->second); + id2upeer.erase(it2); + id2peer.insert(std::make_pair(addr, std::move(p))); + } + else + id2peer.insert(std::make_pair(addr, new Peer(addr, nullptr, this->disp_ec))); start_active_conn(addr); } catch (const PeerNetworkError &) { this->recoverable_error(std::current_exception()); @@ -755,10 +800,10 @@ PeerNetwork<O, _, __>::get_peer_conn(const NetAddr &paddr) const { conn_t conn; std::exception_ptr err = nullptr; try { - auto it = id2peer.find(paddr); - if (it == id2peer.end()) + auto p = get_peer(paddr); + if (!p) throw PeerNetworkError(SALTI_ERROR_PEER_NOT_EXIST); - conn = it->second->conn; + conn = p->conn; } catch (const PeerNetworkError &) { this->recoverable_error(std::current_exception()); } catch (...) { @@ -789,10 +834,10 @@ void PeerNetwork<O, _, __>::_send_msg(Msg &&msg, const NetAddr &paddr) { this->disp_tcall->async_call( [this, msg=std::move(msg), paddr](ThreadCall::Handle &) { try { - auto it = id2peer.find(paddr); - if (it == id2peer.end()) + auto p = get_peer(paddr); + if (!p) throw PeerNetworkError(SALTI_ERROR_PEER_NOT_EXIST); - this->_send_msg_dispatcher(msg, it->second->conn); + this->_send_msg_dispatcher(msg, p->conn); } catch (const PeerNetworkError &) { this->recoverable_error(std::current_exception()); } catch (...) { this->recoverable_error(std::current_exception()); } @@ -812,10 +857,10 @@ void PeerNetwork<O, _, __>::_multicast_msg(Msg &&msg, const std::vector<NetAddr> try { for (auto &addr: paddrs) { - auto it = id2peer.find(addr); - if (it == id2peer.end()) + auto p = get_peer(addr); + if (!p) throw PeerNetworkError(SALTI_ERROR_PEER_NOT_EXIST); - this->_send_msg_dispatcher(msg, it->second->conn); + this->_send_msg_dispatcher(msg, p->conn); } } catch (const PeerNetworkError &) { this->recoverable_error(std::current_exception()); diff --git a/test/test_p2p.cpp b/test/test_p2p.cpp index 28bab10..85aeca1 100644 --- a/test/test_p2p.cpp +++ b/test/test_p2p.cpp @@ -58,6 +58,9 @@ struct Net { net->reg_error_handler([this](const std::exception &err, bool fatal) { fprintf(stdout, "net %lu: captured %s error during an async call: %s\n", this->id, fatal ? "fatal" : "recoverable", err.what()); }); + net->reg_unknown_peer_handler([this](const NetAddr &addr) { + fprintf(stdout, "net %lu: unknown peer attempts to connnect %s\n", this->id, std::string(addr).c_str()); + }); th = std::thread([=](){ try { net->start(); @@ -117,7 +120,10 @@ int main(int argc, char **argv) { auto cmd_exit = [](char *) { for (auto &p: nets) + { p.second->stop_join(); + delete p.second; + } exit(0); }; @@ -141,7 +147,7 @@ int main(int argc, char **argv) { auto cmd_ls = [](char *) { for (auto &p: nets) - fprintf(stdout, "%d\n", p.first); + fprintf(stdout, "%d -> %s\n", p.first, p.second->listen_addr.c_str()); }; auto cmd_del = [](char *buff) { |