diff options
Diffstat (limited to 'include')
-rw-r--r-- | include/salticidae/conn.h | 63 | ||||
-rw-r--r-- | include/salticidae/crypto.h | 19 | ||||
-rw-r--r-- | include/salticidae/network.h | 2 | ||||
-rw-r--r-- | include/salticidae/util.h | 2 |
4 files changed, 69 insertions, 17 deletions
diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h index a791057..59d93fc 100644 --- a/include/salticidae/conn.h +++ b/include/salticidae/conn.h @@ -58,7 +58,7 @@ class ConnPool { /** The handle to a bi-directional connection. */ using conn_t = ArcObj<Conn>; /** The type of callback invoked when connection status is changed. */ - using conn_callback_t = std::function<void(const conn_t &, bool)>; + using conn_callback_t = std::function<bool(const conn_t &, bool)>; using error_callback_t = std::function<void(const std::exception_ptr, bool)>; /** Abstraction for a bi-directional connection. */ class Conn { @@ -93,7 +93,7 @@ class ConnPool { socket_io_func *send_data_func; socket_io_func *recv_data_func; BoxObj<TLS> tls; - BoxObj<X509> peer_cert; + BoxObj<const X509> peer_cert; static socket_io_func _recv_data; static socket_io_func _send_data; @@ -102,6 +102,7 @@ class ConnPool { static socket_io_func _send_data_tls; static socket_io_func _recv_data_tls_handshake; static socket_io_func _send_data_tls_handshake; + static socket_io_func _recv_data_dummy; void conn_server(int, int); @@ -189,7 +190,12 @@ class ConnPool { void update_conn(const conn_t &conn, bool connected) { user_tcall->async_call([this, conn, connected](ThreadCall::Handle &) { - if (conn_cb) conn_cb(conn, connected); + if ((!conn_cb || + conn_cb(conn, connected)) && + enable_tls && connected) + conn->worker->get_tcall()->async_call([conn](ThreadCall::Handle &) { + conn->recv_data_func = Conn::_recv_data_tls; + }); }); } @@ -264,7 +270,10 @@ class ConnPool { conn->recv_data_func(conn, fd, what); else conn->send_data_func(conn, fd, what); - } catch (...) { on_fatal_error(std::current_exception()); } + } catch (...) { + conn->cpool->recoverable_error(std::current_exception()); + conn->worker_terminate(); + } }); conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE); nconn++; @@ -336,6 +345,8 @@ class ConnPool { std::string _tls_key_file; RcObj<X509> _tls_cert; RcObj<PKey> _tls_key; + bool _tls_skip_ca_check; + SSL_verify_cb _tls_verify_callback; public: Config(): @@ -344,11 +355,13 @@ class ConnPool { _seg_buff_size(4096), _nworker(1), _queue_capacity(0), - _enable_tls(true), - _tls_cert_file("./all.pem"), - _tls_key_file("./all.pem"), + _enable_tls(false), + _tls_cert_file(""), + _tls_key_file(""), _tls_cert(nullptr), - _tls_key(nullptr) {} + _tls_key(nullptr), + _tls_skip_ca_check(true), + _tls_verify_callback(nullptr) {} Config &max_listen_backlog(int x) { _max_listen_backlog = x; @@ -379,6 +392,36 @@ class ConnPool { _enable_tls = x; return *this; } + + Config &tls_cert_file(const std::string &x) { + _tls_cert_file = x; + return *this; + } + + Config &tls_key_file(const std::string &x) { + _tls_key_file = x; + return *this; + } + + Config &tls_cert(X509 *x) { + _tls_cert = x; + return *this; + } + + Config &tls_key(PKey *x) { + _tls_key = x; + return *this; + } + + Config &tls_skip_ca_check(bool *x) { + _tls_skip_ca_check = x; + return *this; + } + + Config &tls_verify_callback(SSL_verify_cb x) { + _tls_verify_callback = x; + return *this; + } }; ConnPool(const EventContext &ec, const Config &config): @@ -403,9 +446,11 @@ class ConnPool { tls_ctx->use_privkey(*config._tls_key); else tls_ctx->use_privkey_file(config._tls_key_file); + tls_ctx->set_verify(config._tls_skip_ca_check, config._tls_verify_callback); if (!tls_ctx->check_privkey()) - throw SalticidaeError(SALTI_ERROR_TLS_GENERIC); + throw SalticidaeError(SALTI_ERROR_TLS_KEY_NOT_MATCH); } + signal(SIGPIPE, SIG_IGN); workers = new Worker[nworker]; user_tcall = new ThreadCall(ec); disp_ec = workers[0].get_ec(); diff --git a/include/salticidae/crypto.h b/include/salticidae/crypto.h index 1d20b22..bcfd9dc 100644 --- a/include/salticidae/crypto.h +++ b/include/salticidae/crypto.h @@ -128,6 +128,10 @@ static inline int _tls_pem_with_passwd(char *buf, int size, int, void *) { return _size - 1; } +static int _skip_CA_check(int, X509_STORE_CTX *) { + return 1; +} + class PKey { EVP_PKEY *key; friend class TLSContext; @@ -271,6 +275,11 @@ class TLSContext { throw SalticidaeError(SALTI_ERROR_TLS_LOAD_KEY); } + void set_verify(bool skip_ca_check = true, SSL_verify_cb verify_callback = nullptr) { + SSL_CTX_set_verify(ctx, + SSL_VERIFY_PEER, skip_ca_check ? _skip_CA_check : verify_callback); + } + bool check_privkey() { return SSL_CTX_check_private_key(ctx) > 0; } @@ -329,13 +338,9 @@ class TLS { return SSL_get_error(ssl, ret); } - ~TLS() { - if (ssl) - { - SSL_shutdown(ssl); - SSL_free(ssl); - } - } + void shutdown() { SSL_shutdown(ssl); } + + ~TLS() { if (ssl) SSL_free(ssl); } }; } diff --git a/include/salticidae/network.h b/include/salticidae/network.h index e9fdae6..b703c35 100644 --- a/include/salticidae/network.h +++ b/include/salticidae/network.h @@ -996,7 +996,7 @@ void msgnetwork_terminate(msgnetwork_t *self, const msgnetwork_conn_t *conn); typedef void (*msgnetwork_msg_callback_t)(const msg_t *, const msgnetwork_conn_t *, void *userdata); void msgnetwork_reg_handler(msgnetwork_t *self, _opcode_t opcode, msgnetwork_msg_callback_t cb, void *userdata); -typedef void (*msgnetwork_conn_callback_t)(const msgnetwork_conn_t *, bool connected, void *userdata); +typedef bool (*msgnetwork_conn_callback_t)(const msgnetwork_conn_t *, bool connected, void *userdata); void msgnetwork_reg_conn_handler(msgnetwork_t *self, msgnetwork_conn_callback_t cb, void *userdata); diff --git a/include/salticidae/util.h b/include/salticidae/util.h index dec498c..9a57ae8 100644 --- a/include/salticidae/util.h +++ b/include/salticidae/util.h @@ -88,6 +88,8 @@ enum SalticidaeErrorCode { SALTI_ERROR_TLS_GENERIC, SALTI_ERROR_TLS_X509, SALTI_ERROR_TLS_KEY, + SALTI_ERROR_TLS_KEY_NOT_MATCH, + SALTI_ERROR_TLS_NO_PEER_CERT, SALTI_ERROR_UNKNOWN }; |