aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/salticidae/conn.h34
-rw-r--r--include/salticidae/network.h115
-rw-r--r--include/salticidae/util.h1
-rw-r--r--src/conn.cpp25
-rw-r--r--src/util.cpp1
-rw-r--r--test/test_p2p_stress.cpp35
6 files changed, 132 insertions, 79 deletions
diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h
index 44a1bf9..e5890f6 100644
--- a/include/salticidae/conn.h
+++ b/include/salticidae/conn.h
@@ -108,10 +108,6 @@ class ConnPool {
static socket_io_func _send_data_tls_handshake;
static socket_io_func _recv_data_dummy;
- /** Close the IO and clear all on-going or planned events. Remove the
- * connection from a Worker. */
- virtual void stop();
-
public:
Conn(): terminated(false),
// recv_chunk_size initialized later
@@ -184,9 +180,18 @@ class ConnPool {
/** Called when new data is available. */
virtual void on_read(const conn_t &) {}
/** Called when the underlying connection is established. */
- virtual void on_setup(const conn_t &) {}
+ virtual void on_worker_setup(const conn_t &) {}
+ /** Called when the underlying connection is established. */
+ virtual void on_dispatcher_setup(const conn_t &) {}
/** Called when the underlying connection breaks. */
- virtual void on_teardown(const conn_t &) {}
+ virtual void on_worker_teardown(const conn_t &conn) {
+ if (conn->worker) conn->worker->unfeed();
+ if (conn->tls) conn->tls->shutdown();
+ conn->ev_socket.clear();
+ conn->send_buffer.get_queue().unreg_handler();
+ }
+ /** Called when the underlying connection breaks. */
+ virtual void on_dispatcher_teardown(const conn_t &) {}
private:
const int max_listen_backlog;
@@ -212,6 +217,7 @@ class ConnPool {
if (enable_tls)
{
conn->worker->get_tcall()->async_call([this, conn, ret](ThreadCall::Handle &) {
+ if (conn->is_terminated()) return;
if (ret)
{
conn->recv_data_func = Conn::_recv_data_tls;
@@ -223,6 +229,7 @@ class ConnPool {
}
else
conn->worker->get_tcall()->async_call([conn](ThreadCall::Handle &) {
+ if (conn->is_terminated()) return;
conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE);
});
}
@@ -306,9 +313,15 @@ class ConnPool {
conn->send_data_func = Conn::_send_data;
conn->recv_data_func = Conn::_recv_data;
enable_send_buffer(conn, client_fd);
+ cpool->on_worker_setup(conn);
cpool->disp_tcall->async_call([cpool, conn](ThreadCall::Handle &) {
- cpool->on_setup(conn);
- cpool->update_conn(conn, true);
+ try {
+ cpool->on_dispatcher_setup(conn);
+ cpool->update_conn(conn, true);
+ } catch (...) {
+ cpool->recoverable_error(std::current_exception(), -1);
+ cpool->disp_terminate(conn);
+ }
});
}
assert(conn->fd != -1);
@@ -559,7 +572,8 @@ class ConnPool {
for (auto it: pool)
{
auto &conn = it.second;
- conn->stop();
+ on_worker_teardown(conn);
+ //conn->stop();
conn->set_terminated();
release_conn(conn);
}
@@ -623,6 +637,8 @@ class ConnPool {
}
});
}
+
+ const X509 *get_cert() const { return tls_cert.get(); }
};
}
diff --git a/include/salticidae/network.h b/include/salticidae/network.h
index 40f17a1..19d6db0 100644
--- a/include/salticidae/network.h
+++ b/include/salticidae/network.h
@@ -89,10 +89,6 @@ class MsgNetwork: public ConnPool {
mutable std::atomic<size_t> nsentb;
mutable std::atomic<size_t> nrecvb;
#endif
- void stop() override {
- ev_enqueue_poll.clear();
- ConnPool::Conn::stop();
- }
public:
Conn(): msg_state(HEADER), msg_sleep(false)
@@ -138,12 +134,10 @@ class MsgNetwork: public ConnPool {
ConnPool::Conn *create_conn() override { return new Conn(); }
void on_read(const ConnPool::conn_t &) override;
- void on_setup(const ConnPool::conn_t &_conn) override {
+ void on_worker_setup(const ConnPool::conn_t &_conn) override {
auto conn = static_pointer_cast<Conn>(_conn);
- auto worker = conn->worker;
- worker->get_tcall()->async_call([this, conn, worker](ThreadCall::Handle &) {
- conn->ev_enqueue_poll = TimerEvent(worker->get_ec(),
- [this, conn](TimerEvent &) {
+ conn->ev_enqueue_poll = TimerEvent(conn->worker->get_ec(),
+ [this, conn](TimerEvent &) {
if (!incoming_msgs.enqueue(std::make_pair(conn->msg, conn), false))
{
conn->msg_sleep = true;
@@ -153,7 +147,12 @@ class MsgNetwork: public ConnPool {
conn->msg_sleep = false;
on_read(conn);
});
- });
+ }
+
+ void on_worker_teardown(const ConnPool::conn_t &_conn) override {
+ auto conn = static_pointer_cast<Conn>(_conn);
+ conn->ev_enqueue_poll.clear();
+ ConnPool::on_worker_teardown(_conn);
}
public:
@@ -287,8 +286,8 @@ class ClientNetwork: public MsgNetwork<OpcodeType> {
protected:
ConnPool::Conn *create_conn() override { return new Conn(); }
- void on_setup(const ConnPool::conn_t &) override;
- void on_teardown(const ConnPool::conn_t &) override;
+ void on_dispatcher_setup(const ConnPool::conn_t &) override;
+ void on_dispatcher_teardown(const ConnPool::conn_t &) override;
public:
using Config = typename MsgNet::Config;
@@ -376,12 +375,6 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
PeerNetwork *get_net() {
return static_cast<PeerNetwork *>(ConnPool::Conn::get_pool());
}
-
- protected:
- void stop() override {
- ev_timeout.clear();
- MsgNet::Conn::stop();
- }
};
using conn_t = ArcObj<Conn>;
@@ -520,8 +513,10 @@ class PeerNetwork: public MsgNetwork<OpcodeType> {
protected:
ConnPool::Conn *create_conn() override { return new Conn(); }
- void on_setup(const ConnPool::conn_t &) override;
- void on_teardown(const ConnPool::conn_t &) override;
+ void on_worker_setup(const ConnPool::conn_t &) override;
+ void on_worker_teardown(const ConnPool::conn_t &) override;
+ void on_dispatcher_setup(const ConnPool::conn_t &) override;
+ void on_dispatcher_teardown(const ConnPool::conn_t &) override;
PeerId _get_peer_id(const X509 *cert, const NetAddr &addr) {
if (!this->enable_tls || id_mode == ADDR_BASED)
@@ -738,47 +733,65 @@ void PeerNetwork<O, _, __>::tcall_reset_timeout(ConnPool::Worker *worker,
});
}
-/* begin: functions invoked by the dispatcher */
template<typename O, O _, O __>
-void PeerNetwork<O, _, __>::on_setup(const ConnPool::conn_t &_conn) {
- MsgNet::on_setup(_conn);
+void PeerNetwork<O, _, __>::on_worker_setup(const ConnPool::conn_t &_conn) {
+ MsgNet::on_worker_setup(_conn);
auto conn = static_pointer_cast<Conn>(_conn);
auto worker = conn->worker;
+ auto &ev_timeout = conn->ev_timeout;
+ assert(!ev_timeout);
+ ev_timeout = TimerEvent(worker->get_ec(), [=](TimerEvent &) {
+ try {
+ SALTICIDAE_LOG_INFO("%s%s%s: peer ping-pong timeout",
+ tty_secondary_color,
+ id_hex.c_str(),
+ tty_reset_color);
+ this->worker_terminate(conn);
+ } catch (...) { worker->error_callback(std::current_exception()); }
+ });
+}
+
+template<typename O, O _, O __>
+void PeerNetwork<O, _, __>::on_worker_teardown(const ConnPool::conn_t &_conn) {
+ auto conn = static_pointer_cast<Conn>(_conn);
+ conn->ev_timeout.clear();
+ MsgNet::on_worker_teardown(_conn);
+}
+
+/* begin: functions invoked by the dispatcher */
+
+/* the initial ping-pong to set up the connection */
+template<typename O, O _, O __>
+void PeerNetwork<O, _, __>::on_dispatcher_setup(const ConnPool::conn_t &_conn) {
+ MsgNet::on_dispatcher_setup(_conn);
+ auto conn = static_pointer_cast<Conn>(_conn);
SALTICIDAE_LOG_INFO("%s%s%s: setup connection %s",
tty_secondary_color,
id_hex.c_str(),
tty_reset_color,
std::string(*conn).c_str());
- worker->get_tcall()->async_call([this, conn, worker](ThreadCall::Handle &) {
- auto &ev_timeout = conn->ev_timeout;
- assert(!ev_timeout);
- ev_timeout = TimerEvent(worker->get_ec(), [=](TimerEvent &) {
- try {
- SALTICIDAE_LOG_INFO("%s%s%s: peer ping-pong timeout",
- tty_secondary_color,
- id_hex.c_str(),
- tty_reset_color);
- this->worker_terminate(conn);
- } catch (...) { worker->error_callback(std::current_exception()); }
- });
- });
- /* the initial ping-pong to set up the connection */
- tcall_reset_timeout(worker, conn, conn_timeout);
+ tcall_reset_timeout(conn->worker, conn, conn_timeout);
if (conn->get_mode() == Conn::ConnMode::ACTIVE)
{
auto pid = get_peer_id(conn, conn->get_addr());
- pinfo_slock_t _g(known_peers_lock);
- send_msg(MsgPing(
- listen_addr,
- known_peers.find(pid)->second->get_nonce()), conn);
+ auto it = known_peers.find(pid);
+ if (it == known_peers.end())
+ throw PeerNetworkError(SALTI_ERROR_PEER_NOT_MATCH);
+ else
+ {
+ pinfo_slock_t _g(known_peers_lock);
+ send_msg(MsgPing(
+ listen_addr,
+ it->second->get_nonce()), conn);
+ }
}
else
replace_pending_conn(conn);
}
template<typename O, O _, O __>
-void PeerNetwork<O, _, __>::on_teardown(const ConnPool::conn_t &_conn) {
- MsgNet::on_teardown(_conn);
+void PeerNetwork<O, _, __>::on_dispatcher_teardown(const ConnPool::conn_t &_conn) {
+ MsgNet::on_dispatcher_teardown(_conn);
auto conn = static_pointer_cast<Conn>(_conn);
auto addr = conn->get_addr();
pending_peers.erase(addr);
@@ -949,8 +962,7 @@ void PeerNetwork<O, _, __>::ping_handler(MsgPing &&msg, const conn_t &conn) {
this->user_tcall->async_call([this, addr=msg.claimed_addr, conn](ThreadCall::Handle &) {
if (unknown_peer_cb) unknown_peer_cb(addr, conn->get_peer_cert());
});
- this->disp_terminate(conn);
- return;
+ throw PeerNetworkError(SALTI_ERROR_PEER_NOT_MATCH);
}
auto &p = pit->second;
if (p->state != Peer::State::DISCONNECTED ||
@@ -1018,8 +1030,7 @@ void PeerNetwork<O, _, __>::pong_handler(MsgPong &&msg, const conn_t &conn) {
SALTICIDAE_LOG_WARN(
"%s%s%s: unexpected pong from an unknown peer",
tty_secondary_color, id_hex.c_str(), tty_reset_color);
- this->disp_terminate(conn);
- return;
+ throw PeerNetworkError(SALTI_ERROR_PEER_NOT_MATCH);
}
auto &p = pit->second;
assert(!p->addr.is_null() && p->addr == conn->get_addr());
@@ -1290,8 +1301,8 @@ inline int32_t PeerNetwork<O, _, __>::_multicast_msg(Msg &&msg, const std::vecto
/* end: functions invoked by the user loop */
template<typename OpcodeType>
-void ClientNetwork<OpcodeType>::on_setup(const ConnPool::conn_t &_conn) {
- MsgNet::on_setup(_conn);
+void ClientNetwork<OpcodeType>::on_dispatcher_setup(const ConnPool::conn_t &_conn) {
+ MsgNet::on_dispatcher_setup(_conn);
auto conn = static_pointer_cast<Conn>(_conn);
assert(conn->get_mode() == Conn::PASSIVE);
const auto &addr = conn->get_addr();
@@ -1301,8 +1312,8 @@ void ClientNetwork<OpcodeType>::on_setup(const ConnPool::conn_t &_conn) {
}
template<typename OpcodeType>
-void ClientNetwork<OpcodeType>::on_teardown(const ConnPool::conn_t &_conn) {
- MsgNet::on_teardown(_conn);
+void ClientNetwork<OpcodeType>::on_dispatcher_teardown(const ConnPool::conn_t &_conn) {
+ MsgNet::on_dispatcher_teardown(_conn);
auto conn = static_pointer_cast<Conn>(_conn);
conn->get_net()->addr2conn.erase(conn->get_addr());
}
diff --git a/include/salticidae/util.h b/include/salticidae/util.h
index cb09c0e..8c8fcb9 100644
--- a/include/salticidae/util.h
+++ b/include/salticidae/util.h
@@ -91,6 +91,7 @@ enum SalticidaeErrorCode {
SALTI_ERROR_PEER_ALREADY_EXISTS,
SALTI_ERROR_PEER_NOT_EXIST,
SALTI_ERROR_PEER_NOT_READY,
+ SALTI_ERROR_PEER_NOT_MATCH,
SALTI_ERROR_CLIENT_NOT_EXIST,
SALTI_ERROR_NETADDR_INVALID,
SALTI_ERROR_OPTVAL_INVALID,
diff --git a/src/conn.cpp b/src/conn.cpp
index a5d60a7..af15276 100644
--- a/src/conn.cpp
+++ b/src/conn.cpp
@@ -251,9 +251,15 @@ void ConnPool::Conn::_recv_data_tls_handshake(const conn_t &conn, int, int) {
conn->peer_cert = new X509(conn->tls->get_peer_cert());
conn->worker->enable_send_buffer(conn, conn->fd);
auto cpool = conn->cpool;
+ cpool->on_worker_setup(conn);
cpool->disp_tcall->async_call([cpool, conn](ThreadCall::Handle &) {
- cpool->on_setup(conn);
- cpool->update_conn(conn, true);
+ try {
+ cpool->on_dispatcher_setup(conn);
+ cpool->update_conn(conn, true);
+ } catch (...) {
+ cpool->recoverable_error(std::current_exception(), -1);
+ cpool->disp_terminate(conn);
+ }
});
}
else
@@ -266,17 +272,11 @@ void ConnPool::Conn::_recv_data_tls_handshake(const conn_t &conn, int, int) {
void ConnPool::Conn::_recv_data_dummy(const conn_t &, int, int) {}
-void ConnPool::Conn::stop() {
- if (worker) worker->unfeed();
- if (tls) tls->shutdown();
- ev_socket.clear();
- send_buffer.get_queue().unreg_handler();
-}
-
void ConnPool::worker_terminate(const conn_t &conn) {
conn->worker->get_tcall()->async_call([this, conn](ThreadCall::Handle &) {
if (!conn->set_terminated()) return;
- conn->stop();
+ on_worker_teardown(conn);
+ //conn->stop();
disp_tcall->async_call([this, conn](ThreadCall::Handle &) {
del_conn(conn);
});
@@ -292,7 +292,8 @@ void ConnPool::disp_terminate(const conn_t &conn) {
else
disp_tcall->async_call([this, conn](ThreadCall::Handle &) {
if (!conn->set_terminated()) return;
- conn->stop();
+ on_worker_teardown(conn);
+ //conn->stop();
del_conn(conn);
});
}
@@ -440,7 +441,7 @@ void ConnPool::del_conn(const conn_t &conn) {
void ConnPool::release_conn(const conn_t &conn) {
/* inform the upper layer the connection will be destroyed */
conn->ev_connect.clear();
- on_teardown(conn);
+ on_dispatcher_teardown(conn);
::close(conn->fd);
}
diff --git a/src/util.cpp b/src/util.cpp
index 8ca01aa..01f6b06 100644
--- a/src/util.cpp
+++ b/src/util.cpp
@@ -43,6 +43,7 @@ const char *SALTICIDAE_ERROR_STRINGS[] = {
"peer already exists",
"peer does not exist",
"peer is not ready",
+ "peer id does not match the record",
"client does not exist",
"invalid NetAddr format",
"invalid OptVal format",
diff --git a/test/test_p2p_stress.cpp b/test/test_p2p_stress.cpp
index d054a57..7a078eb 100644
--- a/test/test_p2p_stress.cpp
+++ b/test/test_p2p_stress.cpp
@@ -46,6 +46,7 @@ using salticidae::static_pointer_cast;
using salticidae::Config;
using salticidae::ThreadCall;
using salticidae::BoxObj;
+using salticidae::PKey;
using std::placeholders::_1;
using std::placeholders::_2;
@@ -88,6 +89,8 @@ const uint8_t MsgAck::opcode;
using MyNet = salticidae::PeerNetwork<uint8_t>;
+bool use_tls;
+std::unordered_set<uint256_t> valid_certs;
std::vector<NetAddr> addrs;
struct TestContext {
@@ -116,6 +119,11 @@ void install_proto(AppContext &app, const size_t &recv_chunk_size) {
net.send_msg(std::move(msg), conn);
};
net.reg_conn_handler([](const ConnPool::conn_t &conn, bool connected) {
+ if (connected && use_tls)
+ {
+ auto cert_hash = salticidae::get_hash(conn->get_peer_cert()->get_der());
+ return valid_certs.count(cert_hash) > 0;
+ }
return true;
});
net.reg_peer_handler([&, send_rand](const MyNet::conn_t &conn, bool connected) {
@@ -197,6 +205,7 @@ int main(int argc, char **argv) {
auto opt_nworker = Config::OptValInt::create(2);
auto opt_conn_timeout = Config::OptValDouble::create(5);
auto opt_ping_peroid = Config::OptValDouble::create(2);
+ auto opt_tls = Config::OptValFlag::create(false);
auto opt_help = Config::OptValFlag::create(false);
config.add_opt("no-msg", opt_no_msg, Config::SWITCH_ON);
config.add_opt("npeers", opt_npeers, Config::SET_VAL);
@@ -204,6 +213,7 @@ int main(int argc, char **argv) {
config.add_opt("nworker", opt_nworker, Config::SET_VAL);
config.add_opt("conn-timeout", opt_conn_timeout, Config::SET_VAL);
config.add_opt("ping-period", opt_ping_peroid, Config::SET_VAL);
+ config.add_opt("tls", opt_tls, Config::SWITCH_ON, 't');
config.add_opt("help", opt_help, Config::SWITCH_ON, 'h', "show this help info");
config.parse(argc, argv);
if (opt_help->get())
@@ -216,13 +226,24 @@ int main(int argc, char **argv) {
addrs.push_back(NetAddr("127.0.0.1:" + std::to_string(12345 + i)));
std::vector<AppContext> apps;
std::vector<std::thread> threads;
+ use_tls = opt_tls->get();
apps.resize(addrs.size());
for (size_t i = 0; i < apps.size(); i++)
{
auto &a = apps[i];
a.addr = addrs[i];
- a.net = new MyNet(a.ec, MyNet::Config(
- salticidae::ConnPool::Config()
+ salticidae::ConnPool::Config cfg{};
+ if (use_tls)
+ {
+ auto tls_key = new PKey(PKey::create_privkey_rsa(2048));
+ auto tls_cert = new salticidae::X509(salticidae::X509::create_self_signed_from_pubkey(*tls_key));
+ cfg.enable_tls(true)
+ .tls_key(tls_key)
+ .tls_cert(tls_cert);
+ valid_certs.insert(salticidae::get_hash(tls_cert->get_der()));
+ }
+ else cfg.enable_tls(false);
+ a.net = new MyNet(a.ec, MyNet::Config(cfg
.nworker(opt_nworker->get())
.recv_chunk_size(recv_chunk_size))
.conn_timeout(opt_conn_timeout->get())
@@ -238,12 +259,14 @@ int main(int argc, char **argv) {
threads.push_back(std::thread([&]() {
masksigs();
a.net->listen(a.addr);
- for (auto &paddr: addrs)
- if (paddr != a.addr)
+ for (auto &b: apps)
+ if (b.addr != a.addr)
{
- salticidae::PeerId pid{paddr};
+ auto pid = use_tls ?
+ salticidae::PeerId(*b.net->get_cert()) :
+ salticidae::PeerId(b.addr);
a.net->add_peer(pid);
- a.net->set_peer_addr(pid, paddr);
+ a.net->set_peer_addr(pid, b.addr);
a.net->conn_peer(pid);
}
a.ec.dispatch();}));