From 8f42d0581a8e0cd77bde459db6b61fd957e19c1b Mon Sep 17 00:00:00 2001 From: Determinant Date: Tue, 18 Jun 2019 18:19:39 -0400 Subject: WIP: TLS support --- include/salticidae/conn.h | 66 +++++++++++++++++++++++++++++++++----- include/salticidae/crypto.h | 78 +++++++++++++++++++++++++++++++++++++++++++++ include/salticidae/util.h | 3 ++ 3 files changed, 139 insertions(+), 8 deletions(-) (limited to 'include') diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h index 48902d4..076d64a 100644 --- a/include/salticidae/conn.h +++ b/include/salticidae/conn.h @@ -88,9 +88,20 @@ class ConnPool { TimerEvent ev_send_wait; /** does not need to wait if true */ bool ready_send; + + typedef void (socket_io_func)(const conn_t &, int, int); + socket_io_func *send_data_func; + socket_io_func *recv_data_func; + BoxObj tls; - void recv_data(int, int); - void send_data(int, int); + static socket_io_func _recv_data; + static socket_io_func _send_data; + + static socket_io_func _recv_data_tls; + static socket_io_func _send_data_tls; + static socket_io_func _recv_data_tls_handshake; + static socket_io_func _send_data_tls_handshake; + void conn_server(int, int); /** Terminate the connection (from the worker thread). */ @@ -99,7 +110,7 @@ class ConnPool { void disp_terminate(); public: - Conn(): ready_send(false) {} + Conn(): ready_send(false), send_data_func(nullptr), recv_data_func(nullptr) {} Conn(const Conn &) = delete; Conn(Conn &&other) = delete; @@ -158,6 +169,8 @@ class ConnPool { const double conn_server_timeout; const size_t seg_buff_size; const size_t queue_capacity; + const bool enable_tls; + tls_context_t tls_ctx; /* owned by user loop */ protected: @@ -212,6 +225,21 @@ class ConnPool { std::this_thread::get_id()); return; } + auto cpool = conn->cpool; + if (cpool->enable_tls) + { + conn->tls = new TLS( + cpool->tls_ctx, client_fd, + conn->mode == Conn::ConnMode::PASSIVE); + conn->send_data_func = Conn::_send_data_tls_handshake; + conn->recv_data_func = Conn::_recv_data_tls_handshake; + } + else + { + conn->send_data_func = Conn::_send_data; + conn->recv_data_func = Conn::_recv_data; + cpool->update_conn(conn, true); + } assert(conn->fd != -1); SALTICIDAE_LOG_INFO("worker %x got %s", std::this_thread::get_id(), @@ -224,16 +252,16 @@ class ConnPool { { conn->ev_socket.del(); conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE); - conn->send_data(client_fd, FdEvent::WRITE); + conn->send_data_func(conn, client_fd, FdEvent::WRITE); } return false; }); conn->ev_socket = FdEvent(ec, client_fd, [this, conn=conn](int fd, int what) { try { if (what & FdEvent::READ) - conn->recv_data(fd, what); + conn->recv_data_func(conn, fd, what); else - conn->send_data(fd, what); + conn->send_data_func(conn, fd, what); } catch (...) { on_fatal_error(std::current_exception()); } }); conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE); @@ -301,6 +329,9 @@ class ConnPool { size_t _seg_buff_size; size_t _nworker; size_t _queue_capacity; + bool _enable_tls; + std::string _tls_cert_file; + std::string _tls_key_file; public: Config(): @@ -308,7 +339,10 @@ class ConnPool { _conn_server_timeout(2), _seg_buff_size(4096), _nworker(1), - _queue_capacity(0) {} + _queue_capacity(0), + _enable_tls(true), + _tls_cert_file("./server.pem"), + _tls_key_file("./server.pem") {} Config &max_listen_backlog(int x) { _max_listen_backlog = x; @@ -334,6 +368,11 @@ class ConnPool { _queue_capacity = x; return *this; } + + Config &enable_tls(bool x) { + _enable_tls = x; + return *this; + } }; ConnPool(const EventContext &ec, const Config &config): @@ -342,9 +381,19 @@ class ConnPool { conn_server_timeout(config._conn_server_timeout), seg_buff_size(config._seg_buff_size), queue_capacity(config._queue_capacity), + enable_tls(config._enable_tls), + tls_ctx(nullptr), listen_fd(-1), nworker(config._nworker), system_state(0) { + if (enable_tls) + { + tls_ctx = new TLSContext(); + tls_ctx->use_cert_file(config._tls_cert_file); + tls_ctx->use_priv_key_file(config._tls_key_file); + if (!tls_ctx->check_priv_key()) + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR); + } workers = new Worker[nworker]; user_tcall = new ThreadCall(ec); disp_ec = workers[0].get_ec(); @@ -353,7 +402,8 @@ class ConnPool { disp_error_cb = [this](const std::exception_ptr err) { user_tcall->async_call([this, err](ThreadCall::Handle &) { stop_workers(); - if (error_cb) error_cb(err, true); + std::rethrow_exception(err); + //if (error_cb) error_cb(err, true); }); disp_ec.stop(); workers[0].stop_tcall(); diff --git a/include/salticidae/crypto.h b/include/salticidae/crypto.h index 772cce1..1e6daa1 100644 --- a/include/salticidae/crypto.h +++ b/include/salticidae/crypto.h @@ -26,7 +26,9 @@ #define _SALTICIDAE_CRYPTO_H #include "salticidae/type.h" +#include "salticidae/util.h" #include +#include namespace salticidae { @@ -114,6 +116,82 @@ class SHA1 { } }; +class TLSContext { + SSL_CTX *ctx; + friend class TLS; + public: + static void init_tls() { SSL_library_init(); } + TLSContext(): ctx(SSL_CTX_new(TLS_method())) { + if (ctx == nullptr) + throw std::runtime_error("TLSContext init error"); + } + + void use_cert_file(const std::string &fname) { + auto ret = SSL_CTX_use_certificate_file(ctx, fname.c_str(), SSL_FILETYPE_PEM); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_CERT_ERROR); + } + + void use_priv_key_file(const std::string &fname) { + auto ret = SSL_CTX_use_PrivateKey_file(ctx, fname.c_str(), SSL_FILETYPE_PEM); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_KEY_ERROR); + } + + bool check_priv_key() { + return SSL_CTX_check_private_key(ctx) > 0; + } + + ~TLSContext() { SSL_CTX_free(ctx); } +}; + +using tls_context_t = ArcObj; + +class TLS { + SSL *ssl; + public: + TLS(const tls_context_t &ctx, int fd, bool accept): ssl(SSL_new(ctx->ctx)) { + if (ssl == nullptr) + throw std::runtime_error("TLS init error"); + if (!SSL_set_fd(ssl, fd)) + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR); + if (accept) + SSL_set_accept_state(ssl); + else + SSL_set_connect_state(ssl); + } + + bool do_handshake(int &want_io_type) { /* 0 for read, 1 for write */ + auto ret = SSL_do_handshake(ssl); + if (ret == 1) return true; + auto err = SSL_get_error(ssl, ret); + if (err == SSL_ERROR_WANT_WRITE) + want_io_type = 1; + else if (err == SSL_ERROR_WANT_READ) + want_io_type = 0; + else + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR); + return false; + } + + inline int send(const void *buff, size_t size) { + return SSL_write(ssl, buff, size); + } + + inline int recv(void *buff, size_t size) { + return SSL_read(ssl, buff, size); + } + + int get_error(int ret) { + return SSL_get_error(ssl, ret); + } + + ~TLS() { + SSL_shutdown(ssl); + SSL_free(ssl); + } +}; + } #endif diff --git a/include/salticidae/util.h b/include/salticidae/util.h index 007fcc4..320c78f 100644 --- a/include/salticidae/util.h +++ b/include/salticidae/util.h @@ -83,6 +83,9 @@ enum SalticidaeErrorCode { SALTI_ERROR_OPT_UNKNOWN_ACTION, SALTI_ERROR_CONFIG_LINE_TOO_LONG, SALTI_ERROR_OPT_INVALID, + SALTI_ERROR_TLS_CERT_ERROR, + SALTI_ERROR_TLS_KEY_ERROR, + SALTI_ERROR_TLS_GENERIC_ERROR, SALTI_ERROR_UNKNOWN }; -- cgit v1.2.3