From 5eaa90dba2e0720abb2f8b0f858d54ae544ff4d0 Mon Sep 17 00:00:00 2001
From: Determinant <tederminant@gmail.com>
Date: Thu, 13 Jun 2019 17:59:59 -0400
Subject: support unknown peer callback

---
 include/salticidae/conn.h    |  3 ++
 include/salticidae/network.h | 91 +++++++++++++++++++++++++++++++++-----------
 2 files changed, 71 insertions(+), 23 deletions(-)

(limited to 'include')

diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h
index 62825db..e19ae3a 100644
--- a/include/salticidae/conn.h
+++ b/include/salticidae/conn.h
@@ -160,7 +160,10 @@ class ConnPool {
     const size_t queue_capacity;
 
     /* owned by user loop */
+    protected:
     BoxObj<ThreadCall> user_tcall;
+
+    private:
     conn_callback_t conn_cb;
     error_callback_t error_cb;
 
diff --git a/include/salticidae/network.h b/include/salticidae/network.h
index 1095c89..fc33414 100644
--- a/include/salticidae/network.h
+++ b/include/salticidae/network.h
@@ -259,6 +259,8 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
     public:
     using MsgNet = MsgNetwork<OpcodeType>;
     using Msg = typename MsgNet::Msg;
+    using unknown_callback_t = std::function<void(const NetAddr &)>;
+
     enum IdentityMode {
         IP_BASED,
         IP_PORT_BASED
@@ -323,13 +325,16 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
         void reset_conn(conn_t conn);
     };
 
-    std::unordered_map <NetAddr, BoxObj<Peer>> id2peer;
+    std::unordered_map<NetAddr, BoxObj<Peer>> id2peer;
+    std::unordered_map<NetAddr, BoxObj<Peer>> id2upeer;
+    unknown_callback_t unknown_peer_cb;
 
     const IdentityMode id_mode;
     double retry_conn_delay;
     double ping_period;
     double conn_timeout;
     uint16_t listen_port;
+    bool allow_unknown_peer;
 
     struct MsgPing {
         static const OpcodeType opcode;
@@ -365,6 +370,7 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
     void start_active_conn(const NetAddr &paddr);
     static void tcall_reset_timeout(ConnPool::Worker *worker,
                                     const conn_t &conn, double timeout);
+    Peer *get_peer(const NetAddr &id) const;
 
     protected:
     ConnPool::Conn *create_conn() override { return new Conn(); }
@@ -379,6 +385,7 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
         double _retry_conn_delay;
         double _ping_period;
         double _conn_timeout;
+        bool _allow_unknown_peer;
         IdentityMode _id_mode;
 
         public:
@@ -389,6 +396,7 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
             _retry_conn_delay(2),
             _ping_period(30),
             _conn_timeout(180),
+            _allow_unknown_peer(false),
             _id_mode(IP_PORT_BASED) {}
 
 
@@ -411,6 +419,11 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
             _id_mode = x;
             return *this;
         }
+
+        Config &allow_unknown_peer(bool x) {
+            _allow_unknown_peer = x;
+            return *this;
+        }
     };
 
     PeerNetwork(const EventContext &ec, const Config &config):
@@ -418,7 +431,8 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
         id_mode(config._id_mode),
         retry_conn_delay(config._retry_conn_delay),
         ping_period(config._ping_period),
-        conn_timeout(config._conn_timeout) {
+        conn_timeout(config._conn_timeout),
+        allow_unknown_peer(config._allow_unknown_peer) {
         this->reg_handler(generic_bind(&PeerNetwork::msg_ping, this, _1, _2));
         this->reg_handler(generic_bind(&PeerNetwork::msg_pong, this, _1, _2));
     }
@@ -439,6 +453,8 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
 
     void listen(NetAddr listen_addr);
     conn_t connect(const NetAddr &addr) = delete;
+    template<typename Func>
+    void reg_unknown_peer_handler(Func cb) { unknown_peer_cb = cb; }
 };
 
 /* this callback is run by a worker */
@@ -529,9 +545,8 @@ template<typename O, O _, O __>
 void PeerNetwork<O, _, __>::Conn::on_teardown() {
     MsgNet::Conn::on_teardown();
     auto pn = get_net();
-    auto it = pn->id2peer.find(peer_id);
-    if (it == pn->id2peer.end()) return;
-    auto p = it->second.get();
+    auto p = pn->get_peer(peer_id);
+    if (!p) return;
     if (this != p->conn.get()) return;
     p->ev_ping_timer.del();
     p->connected = false;
@@ -599,11 +614,25 @@ bool PeerNetwork<O, _, __>::check_new_conn(const conn_t &conn, uint16_t port) {
         conn->peer_id.ip = conn->get_addr().ip;
         conn->peer_id.port = port;
     }
-    auto it = id2peer.find(conn->peer_id);
+    const auto &id = conn->peer_id;
+    auto it = id2peer.find(id);
     if (it == id2peer.end())
-    {
-        conn->disp_terminate();
-        return true;
+    {   /* found an unknown peer */
+        const auto &addr = conn->get_addr();
+        this->user_tcall->async_call([this, id](ThreadCall::Handle &) {
+            unknown_peer_cb(id);
+        });
+        if (allow_unknown_peer)
+        {
+            auto it2 = id2upeer.find(id);
+            if (it2 == id2upeer.end())
+                it = id2upeer.insert(std::make_pair(id, new Peer(addr, nullptr, this->disp_ec))).first;
+        }
+        else
+        {
+            conn->disp_terminate();
+            return true;
+        }
     }
     auto p = it->second.get();
     if (p->connected)
@@ -638,7 +667,7 @@ bool PeerNetwork<O, _, __>::check_new_conn(const conn_t &conn, uint16_t port) {
 
 template<typename O, O _, O __>
 void PeerNetwork<O, _, __>::start_active_conn(const NetAddr &addr) {
-    auto p = id2peer.find(addr)->second.get();
+    auto p = get_peer(addr);
     if (p->connected) return;
     auto conn = static_pointer_cast<Conn>(MsgNet::_connect(addr));
     //assert(p->conn == nullptr);
@@ -648,6 +677,15 @@ void PeerNetwork<O, _, __>::start_active_conn(const NetAddr &addr) {
         conn->peer_id.port = 0;
 }
 
+template<typename O, O _, O __>
+typename PeerNetwork<O, _, __>::Peer *PeerNetwork<O, _, __>::get_peer(const NetAddr &addr) const {
+    auto it = id2peer.find(addr);
+    if (it != id2peer.end()) return it->second.get();
+    it = id2upeer.find(addr);
+    if (it != id2upeer.end()) return it->second.get();
+    return nullptr;
+}
+
 template<typename OpcodeType>
 inline void MsgNetwork<OpcodeType>::_send_msg_dispatcher(const Msg &msg, const conn_t &conn) {
     bytearray_t msg_data = msg.serialize();
@@ -683,14 +721,13 @@ void PeerNetwork<O, _, __>::msg_pong(MsgPong &&msg, const conn_t &conn) {
     this->disp_tcall->async_call([this, conn, port](ThreadCall::Handle &) {
         try {
             if (conn->get_mode() == ConnPool::Conn::DEAD) return;
-            auto it = id2peer.find(conn->peer_id);
-            if (it == id2peer.end())
+            auto p = get_peer(conn->peer_id);
+            if (!p)
             {
                 SALTICIDAE_LOG_WARN("pong message discarded");
                 return;
             }
             if (check_new_conn(conn, port)) return;
-            auto p = it->second.get();
             p->pong_msg_ok = true;
             if (p->ping_timer_ok)
             {
@@ -724,7 +761,15 @@ void PeerNetwork<O, _, __>::add_peer(const NetAddr &addr) {
             auto it = id2peer.find(addr);
             if (it != id2peer.end())
                 throw PeerNetworkError(SALTI_ERROR_PEER_ALREADY_EXISTS);
-            id2peer.insert(std::make_pair(addr, new Peer(addr, nullptr, this->disp_ec)));
+            auto it2 = id2upeer.find(addr);
+            if (it2 != id2peer.end())
+            { /* move to the known peer set */
+                auto p = std::move(it2->second);
+                id2upeer.erase(it2);
+                id2peer.insert(std::make_pair(addr, std::move(p)));
+            }
+            else
+                id2peer.insert(std::make_pair(addr, new Peer(addr, nullptr, this->disp_ec)));
             start_active_conn(addr);
         } catch (const PeerNetworkError &) {
             this->recoverable_error(std::current_exception());
@@ -755,10 +800,10 @@ PeerNetwork<O, _, __>::get_peer_conn(const NetAddr &paddr) const {
         conn_t conn;
         std::exception_ptr err = nullptr;
         try {
-            auto it = id2peer.find(paddr);
-            if (it == id2peer.end())
+            auto p = get_peer(paddr);
+            if (!p)
                 throw PeerNetworkError(SALTI_ERROR_PEER_NOT_EXIST);
-            conn = it->second->conn;
+            conn = p->conn;
         } catch (const PeerNetworkError &) {
             this->recoverable_error(std::current_exception());
         } catch (...) {
@@ -789,10 +834,10 @@ void PeerNetwork<O, _, __>::_send_msg(Msg &&msg, const NetAddr &paddr) {
     this->disp_tcall->async_call(
                 [this, msg=std::move(msg), paddr](ThreadCall::Handle &) {
         try {
-            auto it = id2peer.find(paddr);
-            if (it == id2peer.end())
+            auto p = get_peer(paddr);
+            if (!p)
                 throw PeerNetworkError(SALTI_ERROR_PEER_NOT_EXIST);
-            this->_send_msg_dispatcher(msg, it->second->conn);
+            this->_send_msg_dispatcher(msg, p->conn);
         } catch (const PeerNetworkError &) {
             this->recoverable_error(std::current_exception());
         } catch (...) { this->recoverable_error(std::current_exception()); }
@@ -812,10 +857,10 @@ void PeerNetwork<O, _, __>::_multicast_msg(Msg &&msg, const std::vector<NetAddr>
         try {
             for (auto &addr: paddrs)
             {
-                auto it = id2peer.find(addr);
-                if (it == id2peer.end())
+                auto p = get_peer(addr);
+                if (!p)
                     throw PeerNetworkError(SALTI_ERROR_PEER_NOT_EXIST);
-                this->_send_msg_dispatcher(msg, it->second->conn);
+                this->_send_msg_dispatcher(msg, p->conn);
             }
         } catch (const PeerNetworkError &) {
             this->recoverable_error(std::current_exception());
-- 
cgit v1.2.3-70-g09d2