aboutsummaryrefslogtreecommitdiff
path: root/include/salticidae/network.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/salticidae/network.h')
-rw-r--r--include/salticidae/network.h91
1 files changed, 68 insertions, 23 deletions
diff --git a/include/salticidae/network.h b/include/salticidae/network.h
index 1095c89..fc33414 100644
--- a/include/salticidae/network.h
+++ b/include/salticidae/network.h
@@ -259,6 +259,8 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
public:
using MsgNet = MsgNetwork<OpcodeType>;
using Msg = typename MsgNet::Msg;
+ using unknown_callback_t = std::function<void(const NetAddr &)>;
+
enum IdentityMode {
IP_BASED,
IP_PORT_BASED
@@ -323,13 +325,16 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
void reset_conn(conn_t conn);
};
- std::unordered_map <NetAddr, BoxObj<Peer>> id2peer;
+ std::unordered_map<NetAddr, BoxObj<Peer>> id2peer;
+ std::unordered_map<NetAddr, BoxObj<Peer>> id2upeer;
+ unknown_callback_t unknown_peer_cb;
const IdentityMode id_mode;
double retry_conn_delay;
double ping_period;
double conn_timeout;
uint16_t listen_port;
+ bool allow_unknown_peer;
struct MsgPing {
static const OpcodeType opcode;
@@ -365,6 +370,7 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
void start_active_conn(const NetAddr &paddr);
static void tcall_reset_timeout(ConnPool::Worker *worker,
const conn_t &conn, double timeout);
+ Peer *get_peer(const NetAddr &id) const;
protected:
ConnPool::Conn *create_conn() override { return new Conn(); }
@@ -379,6 +385,7 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
double _retry_conn_delay;
double _ping_period;
double _conn_timeout;
+ bool _allow_unknown_peer;
IdentityMode _id_mode;
public:
@@ -389,6 +396,7 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
_retry_conn_delay(2),
_ping_period(30),
_conn_timeout(180),
+ _allow_unknown_peer(false),
_id_mode(IP_PORT_BASED) {}
@@ -411,6 +419,11 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
_id_mode = x;
return *this;
}
+
+ Config &allow_unknown_peer(bool x) {
+ _allow_unknown_peer = x;
+ return *this;
+ }
};
PeerNetwork(const EventContext &ec, const Config &config):
@@ -418,7 +431,8 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
id_mode(config._id_mode),
retry_conn_delay(config._retry_conn_delay),
ping_period(config._ping_period),
- conn_timeout(config._conn_timeout) {
+ conn_timeout(config._conn_timeout),
+ allow_unknown_peer(config._allow_unknown_peer) {
this->reg_handler(generic_bind(&PeerNetwork::msg_ping, this, _1, _2));
this->reg_handler(generic_bind(&PeerNetwork::msg_pong, this, _1, _2));
}
@@ -439,6 +453,8 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
void listen(NetAddr listen_addr);
conn_t connect(const NetAddr &addr) = delete;
+ template<typename Func>
+ void reg_unknown_peer_handler(Func cb) { unknown_peer_cb = cb; }
};
/* this callback is run by a worker */
@@ -529,9 +545,8 @@ template<typename O, O _, O __>
void PeerNetwork<O, _, __>::Conn::on_teardown() {
MsgNet::Conn::on_teardown();
auto pn = get_net();
- auto it = pn->id2peer.find(peer_id);
- if (it == pn->id2peer.end()) return;
- auto p = it->second.get();
+ auto p = pn->get_peer(peer_id);
+ if (!p) return;
if (this != p->conn.get()) return;
p->ev_ping_timer.del();
p->connected = false;
@@ -599,11 +614,25 @@ bool PeerNetwork<O, _, __>::check_new_conn(const conn_t &conn, uint16_t port) {
conn->peer_id.ip = conn->get_addr().ip;
conn->peer_id.port = port;
}
- auto it = id2peer.find(conn->peer_id);
+ const auto &id = conn->peer_id;
+ auto it = id2peer.find(id);
if (it == id2peer.end())
- {
- conn->disp_terminate();
- return true;
+ { /* found an unknown peer */
+ const auto &addr = conn->get_addr();
+ this->user_tcall->async_call([this, id](ThreadCall::Handle &) {
+ unknown_peer_cb(id);
+ });
+ if (allow_unknown_peer)
+ {
+ auto it2 = id2upeer.find(id);
+ if (it2 == id2upeer.end())
+ it = id2upeer.insert(std::make_pair(id, new Peer(addr, nullptr, this->disp_ec))).first;
+ }
+ else
+ {
+ conn->disp_terminate();
+ return true;
+ }
}
auto p = it->second.get();
if (p->connected)
@@ -638,7 +667,7 @@ bool PeerNetwork<O, _, __>::check_new_conn(const conn_t &conn, uint16_t port) {
template<typename O, O _, O __>
void PeerNetwork<O, _, __>::start_active_conn(const NetAddr &addr) {
- auto p = id2peer.find(addr)->second.get();
+ auto p = get_peer(addr);
if (p->connected) return;
auto conn = static_pointer_cast<Conn>(MsgNet::_connect(addr));
//assert(p->conn == nullptr);
@@ -648,6 +677,15 @@ void PeerNetwork<O, _, __>::start_active_conn(const NetAddr &addr) {
conn->peer_id.port = 0;
}
+template<typename O, O _, O __>
+typename PeerNetwork<O, _, __>::Peer *PeerNetwork<O, _, __>::get_peer(const NetAddr &addr) const {
+ auto it = id2peer.find(addr);
+ if (it != id2peer.end()) return it->second.get();
+ it = id2upeer.find(addr);
+ if (it != id2upeer.end()) return it->second.get();
+ return nullptr;
+}
+
template<typename OpcodeType>
inline void MsgNetwork<OpcodeType>::_send_msg_dispatcher(const Msg &msg, const conn_t &conn) {
bytearray_t msg_data = msg.serialize();
@@ -683,14 +721,13 @@ void PeerNetwork<O, _, __>::msg_pong(MsgPong &&msg, const conn_t &conn) {
this->disp_tcall->async_call([this, conn, port](ThreadCall::Handle &) {
try {
if (conn->get_mode() == ConnPool::Conn::DEAD) return;
- auto it = id2peer.find(conn->peer_id);
- if (it == id2peer.end())
+ auto p = get_peer(conn->peer_id);
+ if (!p)
{
SALTICIDAE_LOG_WARN("pong message discarded");
return;
}
if (check_new_conn(conn, port)) return;
- auto p = it->second.get();
p->pong_msg_ok = true;
if (p->ping_timer_ok)
{
@@ -724,7 +761,15 @@ void PeerNetwork<O, _, __>::add_peer(const NetAddr &addr) {
auto it = id2peer.find(addr);
if (it != id2peer.end())
throw PeerNetworkError(SALTI_ERROR_PEER_ALREADY_EXISTS);
- id2peer.insert(std::make_pair(addr, new Peer(addr, nullptr, this->disp_ec)));
+ auto it2 = id2upeer.find(addr);
+ if (it2 != id2peer.end())
+ { /* move to the known peer set */
+ auto p = std::move(it2->second);
+ id2upeer.erase(it2);
+ id2peer.insert(std::make_pair(addr, std::move(p)));
+ }
+ else
+ id2peer.insert(std::make_pair(addr, new Peer(addr, nullptr, this->disp_ec)));
start_active_conn(addr);
} catch (const PeerNetworkError &) {
this->recoverable_error(std::current_exception());
@@ -755,10 +800,10 @@ PeerNetwork<O, _, __>::get_peer_conn(const NetAddr &paddr) const {
conn_t conn;
std::exception_ptr err = nullptr;
try {
- auto it = id2peer.find(paddr);
- if (it == id2peer.end())
+ auto p = get_peer(paddr);
+ if (!p)
throw PeerNetworkError(SALTI_ERROR_PEER_NOT_EXIST);
- conn = it->second->conn;
+ conn = p->conn;
} catch (const PeerNetworkError &) {
this->recoverable_error(std::current_exception());
} catch (...) {
@@ -789,10 +834,10 @@ void PeerNetwork<O, _, __>::_send_msg(Msg &&msg, const NetAddr &paddr) {
this->disp_tcall->async_call(
[this, msg=std::move(msg), paddr](ThreadCall::Handle &) {
try {
- auto it = id2peer.find(paddr);
- if (it == id2peer.end())
+ auto p = get_peer(paddr);
+ if (!p)
throw PeerNetworkError(SALTI_ERROR_PEER_NOT_EXIST);
- this->_send_msg_dispatcher(msg, it->second->conn);
+ this->_send_msg_dispatcher(msg, p->conn);
} catch (const PeerNetworkError &) {
this->recoverable_error(std::current_exception());
} catch (...) { this->recoverable_error(std::current_exception()); }
@@ -812,10 +857,10 @@ void PeerNetwork<O, _, __>::_multicast_msg(Msg &&msg, const std::vector<NetAddr>
try {
for (auto &addr: paddrs)
{
- auto it = id2peer.find(addr);
- if (it == id2peer.end())
+ auto p = get_peer(addr);
+ if (!p)
throw PeerNetworkError(SALTI_ERROR_PEER_NOT_EXIST);
- this->_send_msg_dispatcher(msg, it->second->conn);
+ this->_send_msg_dispatcher(msg, p->conn);
}
} catch (const PeerNetworkError &) {
this->recoverable_error(std::current_exception());