diff options
-rw-r--r-- | include/salticidae/conn.h | 63 | ||||
-rw-r--r-- | include/salticidae/crypto.h | 19 | ||||
-rw-r--r-- | include/salticidae/network.h | 2 | ||||
-rw-r--r-- | include/salticidae/util.h | 2 | ||||
-rw-r--r-- | src/conn.cpp | 16 | ||||
-rw-r--r-- | src/network.cpp | 2 | ||||
-rw-r--r-- | src/util.cpp | 2 | ||||
-rw-r--r-- | test/CMakeLists.txt | 3 | ||||
-rw-r--r-- | test/bench_network.cpp | 1 | ||||
-rw-r--r-- | test/bench_network_tls.cpp | 165 | ||||
-rw-r--r-- | test/test_msgnet.cpp | 3 | ||||
-rw-r--r-- | test/test_msgnet_c.c | 3 | ||||
-rw-r--r-- | test/test_p2p_stress.cpp | 1 |
13 files changed, 256 insertions, 26 deletions
diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h index a791057..59d93fc 100644 --- a/include/salticidae/conn.h +++ b/include/salticidae/conn.h @@ -58,7 +58,7 @@ class ConnPool { /** The handle to a bi-directional connection. */ using conn_t = ArcObj<Conn>; /** The type of callback invoked when connection status is changed. */ - using conn_callback_t = std::function<void(const conn_t &, bool)>; + using conn_callback_t = std::function<bool(const conn_t &, bool)>; using error_callback_t = std::function<void(const std::exception_ptr, bool)>; /** Abstraction for a bi-directional connection. */ class Conn { @@ -93,7 +93,7 @@ class ConnPool { socket_io_func *send_data_func; socket_io_func *recv_data_func; BoxObj<TLS> tls; - BoxObj<X509> peer_cert; + BoxObj<const X509> peer_cert; static socket_io_func _recv_data; static socket_io_func _send_data; @@ -102,6 +102,7 @@ class ConnPool { static socket_io_func _send_data_tls; static socket_io_func _recv_data_tls_handshake; static socket_io_func _send_data_tls_handshake; + static socket_io_func _recv_data_dummy; void conn_server(int, int); @@ -189,7 +190,12 @@ class ConnPool { void update_conn(const conn_t &conn, bool connected) { user_tcall->async_call([this, conn, connected](ThreadCall::Handle &) { - if (conn_cb) conn_cb(conn, connected); + if ((!conn_cb || + conn_cb(conn, connected)) && + enable_tls && connected) + conn->worker->get_tcall()->async_call([conn](ThreadCall::Handle &) { + conn->recv_data_func = Conn::_recv_data_tls; + }); }); } @@ -264,7 +270,10 @@ class ConnPool { conn->recv_data_func(conn, fd, what); else conn->send_data_func(conn, fd, what); - } catch (...) { on_fatal_error(std::current_exception()); } + } catch (...) { + conn->cpool->recoverable_error(std::current_exception()); + conn->worker_terminate(); + } }); conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE); nconn++; @@ -336,6 +345,8 @@ class ConnPool { std::string _tls_key_file; RcObj<X509> _tls_cert; RcObj<PKey> _tls_key; + bool _tls_skip_ca_check; + SSL_verify_cb _tls_verify_callback; public: Config(): @@ -344,11 +355,13 @@ class ConnPool { _seg_buff_size(4096), _nworker(1), _queue_capacity(0), - _enable_tls(true), - _tls_cert_file("./all.pem"), - _tls_key_file("./all.pem"), + _enable_tls(false), + _tls_cert_file(""), + _tls_key_file(""), _tls_cert(nullptr), - _tls_key(nullptr) {} + _tls_key(nullptr), + _tls_skip_ca_check(true), + _tls_verify_callback(nullptr) {} Config &max_listen_backlog(int x) { _max_listen_backlog = x; @@ -379,6 +392,36 @@ class ConnPool { _enable_tls = x; return *this; } + + Config &tls_cert_file(const std::string &x) { + _tls_cert_file = x; + return *this; + } + + Config &tls_key_file(const std::string &x) { + _tls_key_file = x; + return *this; + } + + Config &tls_cert(X509 *x) { + _tls_cert = x; + return *this; + } + + Config &tls_key(PKey *x) { + _tls_key = x; + return *this; + } + + Config &tls_skip_ca_check(bool *x) { + _tls_skip_ca_check = x; + return *this; + } + + Config &tls_verify_callback(SSL_verify_cb x) { + _tls_verify_callback = x; + return *this; + } }; ConnPool(const EventContext &ec, const Config &config): @@ -403,9 +446,11 @@ class ConnPool { tls_ctx->use_privkey(*config._tls_key); else tls_ctx->use_privkey_file(config._tls_key_file); + tls_ctx->set_verify(config._tls_skip_ca_check, config._tls_verify_callback); if (!tls_ctx->check_privkey()) - throw SalticidaeError(SALTI_ERROR_TLS_GENERIC); + throw SalticidaeError(SALTI_ERROR_TLS_KEY_NOT_MATCH); } + signal(SIGPIPE, SIG_IGN); workers = new Worker[nworker]; user_tcall = new ThreadCall(ec); disp_ec = workers[0].get_ec(); diff --git a/include/salticidae/crypto.h b/include/salticidae/crypto.h index 1d20b22..bcfd9dc 100644 --- a/include/salticidae/crypto.h +++ b/include/salticidae/crypto.h @@ -128,6 +128,10 @@ static inline int _tls_pem_with_passwd(char *buf, int size, int, void *) { return _size - 1; } +static int _skip_CA_check(int, X509_STORE_CTX *) { + return 1; +} + class PKey { EVP_PKEY *key; friend class TLSContext; @@ -271,6 +275,11 @@ class TLSContext { throw SalticidaeError(SALTI_ERROR_TLS_LOAD_KEY); } + void set_verify(bool skip_ca_check = true, SSL_verify_cb verify_callback = nullptr) { + SSL_CTX_set_verify(ctx, + SSL_VERIFY_PEER, skip_ca_check ? _skip_CA_check : verify_callback); + } + bool check_privkey() { return SSL_CTX_check_private_key(ctx) > 0; } @@ -329,13 +338,9 @@ class TLS { return SSL_get_error(ssl, ret); } - ~TLS() { - if (ssl) - { - SSL_shutdown(ssl); - SSL_free(ssl); - } - } + void shutdown() { SSL_shutdown(ssl); } + + ~TLS() { if (ssl) SSL_free(ssl); } }; } diff --git a/include/salticidae/network.h b/include/salticidae/network.h index e9fdae6..b703c35 100644 --- a/include/salticidae/network.h +++ b/include/salticidae/network.h @@ -996,7 +996,7 @@ void msgnetwork_terminate(msgnetwork_t *self, const msgnetwork_conn_t *conn); typedef void (*msgnetwork_msg_callback_t)(const msg_t *, const msgnetwork_conn_t *, void *userdata); void msgnetwork_reg_handler(msgnetwork_t *self, _opcode_t opcode, msgnetwork_msg_callback_t cb, void *userdata); -typedef void (*msgnetwork_conn_callback_t)(const msgnetwork_conn_t *, bool connected, void *userdata); +typedef bool (*msgnetwork_conn_callback_t)(const msgnetwork_conn_t *, bool connected, void *userdata); void msgnetwork_reg_conn_handler(msgnetwork_t *self, msgnetwork_conn_callback_t cb, void *userdata); diff --git a/include/salticidae/util.h b/include/salticidae/util.h index dec498c..9a57ae8 100644 --- a/include/salticidae/util.h +++ b/include/salticidae/util.h @@ -88,6 +88,8 @@ enum SalticidaeErrorCode { SALTI_ERROR_TLS_GENERIC, SALTI_ERROR_TLS_X509, SALTI_ERROR_TLS_KEY, + SALTI_ERROR_TLS_KEY_NOT_MATCH, + SALTI_ERROR_TLS_NO_PEER_CERT, SALTI_ERROR_UNKNOWN }; diff --git a/src/conn.cpp b/src/conn.cpp index 60d5835..535803b 100644 --- a/src/conn.cpp +++ b/src/conn.cpp @@ -211,13 +211,18 @@ void ConnPool::Conn::_recv_data_tls(const ConnPool::conn_t &conn, int fd, int ev conn->on_read(); } -void ConnPool::Conn::_send_data_tls_handshake(const ConnPool::conn_t &conn, int, int) { +void ConnPool::Conn::_send_data_tls_handshake(const ConnPool::conn_t &conn, int fd, int events) { + conn->ready_send = true; + _recv_data_tls_handshake(conn, fd, events); +} + +void ConnPool::Conn::_recv_data_tls_handshake(const ConnPool::conn_t &conn, int, int) { int ret; if (conn->tls->do_handshake(ret)) { /* finishing TLS handshake */ conn->send_data_func = _send_data_tls; - conn->recv_data_func = _recv_data_tls; + conn->recv_data_func = _recv_data_dummy; conn->peer_cert = new X509(conn->tls->get_peer_cert()); conn->cpool->update_conn(conn, true); } @@ -229,9 +234,8 @@ void ConnPool::Conn::_send_data_tls_handshake(const ConnPool::conn_t &conn, int, } } -void ConnPool::Conn::_recv_data_tls_handshake(const ConnPool::conn_t &conn, int fd, int events) { - conn->ready_send = true; - _send_data_tls_handshake(conn, fd, events); + +void ConnPool::Conn::_recv_data_dummy(const ConnPool::conn_t &, int, int) { } /****/ @@ -239,6 +243,7 @@ void ConnPool::Conn::stop() { if (mode != ConnMode::DEAD) { if (worker) worker->unfeed(); + if (tls) tls->shutdown(); ev_connect.clear(); ev_socket.clear(); send_buffer.get_queue().unreg_handler(); @@ -290,7 +295,6 @@ void ConnPool::accept_client(int fd, int) { conn->send_buffer.set_capacity(queue_capacity); conn->seg_buff_size = seg_buff_size; conn->fd = client_fd; - conn->worker = nullptr; conn->cpool = this; conn->mode = Conn::PASSIVE; conn->addr = addr; diff --git a/src/network.cpp b/src/network.cpp index b8d058a..4067531 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -96,7 +96,7 @@ void msgnetwork_reg_conn_handler(msgnetwork_t *self, void *userdata) { self->reg_conn_handler([=](const ConnPool::conn_t &_conn, bool connected) { auto conn = salticidae::static_pointer_cast<msgnetwork_t::Conn>(_conn); - cb(&conn, connected, userdata); + return cb(&conn, connected, userdata); }); } diff --git a/src/util.cpp b/src/util.cpp index 66bcd12..1493b20 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -53,6 +53,8 @@ const char *SALTICIDAE_ERROR_STRINGS[] = { "tls generic error", "x509 cert error", "EVP_PKEY error", + "tls key does not match the cert", + "tls fail to get peer cert", "unknown error" }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 2a1a8f0..e7c5813 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -41,5 +41,8 @@ target_link_libraries(test_queue salticidae_static pthread) add_executable(bench_network bench_network.cpp) target_link_libraries(bench_network salticidae_static pthread) +add_executable(bench_network_tls bench_network_tls.cpp) +target_link_libraries(bench_network_tls salticidae_static pthread) + add_executable(test_msgnet_c test_msgnet_c.c) target_link_libraries(test_msgnet_c salticidae_static pthread) diff --git a/test/bench_network.cpp b/test/bench_network.cpp index ca22db4..f8d3070 100644 --- a/test/bench_network.cpp +++ b/test/bench_network.cpp @@ -120,6 +120,7 @@ struct MyNet: public MsgNetworkByteOp { /* try to reconnect to the same address */ connect(conn->get_addr(), false); } + return true; }); } diff --git a/test/bench_network_tls.cpp b/test/bench_network_tls.cpp new file mode 100644 index 0000000..bb5d0c1 --- /dev/null +++ b/test/bench_network_tls.cpp @@ -0,0 +1,165 @@ +/** + * Copyright (c) 2018 Cornell University. + * + * Author: Ted Yin <tederminant@gmail.com> + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies + * of the Software, and to permit persons to whom the Software is furnished to do + * so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include <cstdio> +#include <string> +#include <functional> +#include <thread> +#include <signal.h> + +/* disable SHA256 checksum */ +#define SALTICIDAE_NOCHECKSUM + +#include "salticidae/msg.h" +#include "salticidae/event.h" +#include "salticidae/network.h" +#include "salticidae/stream.h" + +using salticidae::NetAddr; +using salticidae::DataStream; +using salticidae::MsgNetwork; +using salticidae::htole; +using salticidae::letoh; +using salticidae::bytearray_t; +using salticidae::TimerEvent; +using salticidae::ThreadCall; +using std::placeholders::_1; +using std::placeholders::_2; +using opcode_t = uint8_t; + +struct MsgBytes { + static const opcode_t opcode = 0xa; + DataStream serialized; + bytearray_t bytes; + MsgBytes(size_t size) { + bytes.resize(size); + serialized << htole((uint32_t)size) << bytes; + } + MsgBytes(DataStream &&s) { + uint32_t len; + s >> len; + len = letoh(len); + auto base = s.get_data_inplace(len); + bytes = bytearray_t(base, base + len); + } +}; + +const opcode_t MsgBytes::opcode; + +using MsgNetworkByteOp = MsgNetwork<opcode_t>; + +struct MyNet: public MsgNetworkByteOp { + const std::string name; + const NetAddr peer; + TimerEvent ev_period_stat; + ThreadCall tcall; + size_t nrecv; + std::function<void(ThreadCall::Handle &)> trigger; + + MyNet(const salticidae::EventContext &ec, + const std::string name, + const NetAddr &peer, + double stat_timeout = -1): + MsgNetworkByteOp(ec, MsgNetworkByteOp::Config( + ConnPool::Config().queue_capacity(65536).enable_tls(true).tls_cert_file("all.pem").tls_key_file("all.pem")).burst_size(1000)), + name(name), + peer(peer), + ev_period_stat(ec, [this, stat_timeout](TimerEvent &) { + SALTICIDAE_LOG_INFO("%.2f mps", nrecv / (double)stat_timeout); + fflush(stderr); + nrecv = 0; + ev_period_stat.add(stat_timeout); + }), + tcall(ec), + nrecv(0) { + /* message handler could be a bound method */ + reg_handler(salticidae::generic_bind(&MyNet::on_receive_bytes, this, _1, _2)); + if (stat_timeout > 0) + ev_period_stat.add(0); + reg_conn_handler([this, ec](const ConnPool::conn_t &conn, bool connected) { + if (connected) + { + if (conn->get_mode() == MyNet::Conn::ACTIVE) + { + printf("[%s] Connected, sending hello.\n", this->name.c_str()); + /* send the first message through this connection */ + trigger = [this, conn](ThreadCall::Handle &) { + send_msg(MsgBytes(256), salticidae::static_pointer_cast<Conn>(conn)); + if (conn->get_mode() != MyNet::Conn::DEAD) + tcall.async_call(trigger); + }; + tcall.async_call(trigger); + } + else + printf("[%s] Passively connected, waiting for greetings.\n", this->name.c_str()); + } + else + { + printf("[%s] Disconnected, retrying.\n", this->name.c_str()); + /* try to reconnect to the same address */ + connect(conn->get_addr(), false); + } + return true; + }); + } + + void on_receive_bytes(MsgBytes &&msg, const conn_t &conn) { + nrecv++; + } +}; + +salticidae::EventContext ec; +NetAddr alice_addr("127.0.0.1:1234"); +NetAddr bob_addr("127.0.0.1:1235"); + +int main() { + salticidae::BoxObj<MyNet> alice = new MyNet(ec, "Alice", bob_addr, 10); + alice->start(); + alice->listen(alice_addr); + salticidae::EventContext tec; + salticidae::BoxObj<ThreadCall> tcall = new ThreadCall(tec); + std::thread bob_thread([&tec]() { + MyNet bob(tec, "Bob", alice_addr); + bob.start(); + bob.connect(alice_addr); + try { + tec.dispatch(); + } catch (std::exception &) {} + SALTICIDAE_LOG_INFO("thread exiting"); + }); + auto shutdown = [&](int) { + tcall->async_call([&](salticidae::ThreadCall::Handle &) { + tec.stop(); + }); + alice = nullptr; + ec.stop(); + bob_thread.join(); + }; + salticidae::SigEvent ev_sigint(ec, shutdown); + salticidae::SigEvent ev_sigterm(ec, shutdown); + ev_sigint.add(SIGINT); + ev_sigterm.add(SIGTERM); + ec.dispatch(); + return 0; +} diff --git a/test/test_msgnet.cpp b/test/test_msgnet.cpp index 088e0ff..7635af8 100644 --- a/test/test_msgnet.cpp +++ b/test/test_msgnet.cpp @@ -108,8 +108,9 @@ struct MyNet: public MsgNetworkByteOp { { printf("[%s] Disconnected, retrying.\n", this->name.c_str()); /* try to reconnect to the same address */ - connect(conn->get_addr()); + connect(conn->get_addr(), false); } + return true; }); } diff --git a/test/test_msgnet_c.c b/test/test_msgnet_c.c index e6ebd14..f99c88b 100644 --- a/test/test_msgnet_c.c +++ b/test/test_msgnet_c.c @@ -117,7 +117,7 @@ void on_receive_ack(const msg_t *msg, const msgnetwork_conn_t *conn, void *userd printf("[%s] the peer knows\n", name); } -void conn_handler(const msgnetwork_conn_t *conn, bool connected, void *userdata) { +bool conn_handler(const msgnetwork_conn_t *conn, bool connected, void *userdata) { msgnetwork_t *net = msgnetwork_conn_get_net(conn); MyNet *n = (MyNet *)userdata; const char *name = n->name; @@ -142,6 +142,7 @@ void conn_handler(const msgnetwork_conn_t *conn, bool connected, void *userdata) msgnetwork_connect(net, addr, &err); check_err(&err); } + return true; } void error_handler(const SalticidaeCError *err, bool fatal, void *userdata) { diff --git a/test/test_p2p_stress.cpp b/test/test_p2p_stress.cpp index 92e5bb4..1cb2ca3 100644 --- a/test/test_p2p_stress.cpp +++ b/test/test_p2p_stress.cpp @@ -113,6 +113,7 @@ void install_proto(AppContext &app, const size_t &seg_buff_size) { send_rand(tc.state, static_pointer_cast<MyNet::Conn>(conn)); } } + return true; }); net.reg_error_handler([ec](const std::exception_ptr _err, bool fatal) { try { |