diff options
-rw-r--r-- | CMakeLists.txt | 2 | ||||
-rw-r--r-- | include/salticidae/conn.h | 24 | ||||
-rw-r--r-- | include/salticidae/crypto.h | 166 | ||||
-rw-r--r-- | include/salticidae/util.h | 8 | ||||
-rw-r--r-- | src/conn.cpp | 6 | ||||
-rw-r--r-- | src/util.cpp | 7 | ||||
-rw-r--r-- | test/bench_network.cpp | 3 |
7 files changed, 191 insertions, 25 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index c3ad2ef..9c5c6e8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,7 +27,7 @@ set(CMAKE_C_STANDARD 11) set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/Modules/") find_package(Libuv REQUIRED) -find_package(OpenSSL REQUIRED) +find_package(OpenSSL 1.1.0 REQUIRED) include_directories(include) add_library(salticidae diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h index 076d64a..a791057 100644 --- a/include/salticidae/conn.h +++ b/include/salticidae/conn.h @@ -93,6 +93,7 @@ class ConnPool { socket_io_func *send_data_func; socket_io_func *recv_data_func; BoxObj<TLS> tls; + BoxObj<X509> peer_cert; static socket_io_func _recv_data; static socket_io_func _send_data; @@ -131,6 +132,7 @@ class ConnPool { operator std::string() const; const NetAddr &get_addr() const { return addr; } + const X509 &get_peer_cert() const { return *peer_cert; } ConnMode get_mode() const { return mode; } ConnPool *get_pool() const { return cpool; } MPSCWriteBuffer &get_send_buffer() { return send_buffer; } @@ -332,6 +334,8 @@ class ConnPool { bool _enable_tls; std::string _tls_cert_file; std::string _tls_key_file; + RcObj<X509> _tls_cert; + RcObj<PKey> _tls_key; public: Config(): @@ -341,8 +345,10 @@ class ConnPool { _nworker(1), _queue_capacity(0), _enable_tls(true), - _tls_cert_file("./server.pem"), - _tls_key_file("./server.pem") {} + _tls_cert_file("./all.pem"), + _tls_key_file("./all.pem"), + _tls_cert(nullptr), + _tls_key(nullptr) {} Config &max_listen_backlog(int x) { _max_listen_backlog = x; @@ -389,10 +395,16 @@ class ConnPool { if (enable_tls) { tls_ctx = new TLSContext(); - tls_ctx->use_cert_file(config._tls_cert_file); - tls_ctx->use_priv_key_file(config._tls_key_file); - if (!tls_ctx->check_priv_key()) - throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR); + if (config._tls_cert) + tls_ctx->use_cert(*config._tls_cert); + else + tls_ctx->use_cert_file(config._tls_cert_file); + if (config._tls_key) + tls_ctx->use_privkey(*config._tls_key); + else + tls_ctx->use_privkey_file(config._tls_key_file); + if (!tls_ctx->check_privkey()) + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC); } workers = new Worker[nworker]; user_tcall = new ThreadCall(ec); diff --git a/include/salticidae/crypto.h b/include/salticidae/crypto.h index 1e6daa1..1d20b22 100644 --- a/include/salticidae/crypto.h +++ b/include/salticidae/crypto.h @@ -116,33 +116,166 @@ class SHA1 { } }; +static thread_local const char *_password; +static inline int _tls_pem_no_passswd(char *, int, int, void *) { + return -1; +} +static inline int _tls_pem_with_passwd(char *buf, int size, int, void *) { + size_t _size = strlen(_password) + 1; + if (_size > (size_t)size) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + memmove(buf, _password, _size); + return _size - 1; +} + +class PKey { + EVP_PKEY *key; + friend class TLSContext; + public: + PKey(EVP_PKEY *key): key(key) {} + PKey(const PKey &) = delete; + PKey(PKey &&other): key(other.key) { other.key = nullptr; } + + PKey create_privkey_from_pem_file(std::string pem_fname, std::string *password = nullptr) { + FILE *fp = fopen(pem_fname.c_str(), "r"); + EVP_PKEY *key; + if (fp == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + if (password) + { + _password = password->c_str(); + key = PEM_read_PrivateKey(fp, NULL, _tls_pem_with_passwd, NULL); + } + else + { + key = PEM_read_PrivateKey(fp, NULL, _tls_pem_no_passswd, NULL); + } + if (key == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + fclose(fp); + return PKey(key); + } + + PKey create_privkey_from_der(const uint8_t *der, size_t size) { + EVP_PKEY *key; + key = d2i_AutoPrivateKey(NULL, (const unsigned char **)&der, size); + if (key == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + return PKey(key); + } + + bytearray_t get_pubkey_der() { + uint8_t *der; + auto ret = i2d_PublicKey(key, &der); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + bytearray_t res(der, der + ret); + OPENSSL_cleanse(der, ret); + OPENSSL_free(der); + return std::move(res); + } + + bytearray_t get_privkey_der() { + uint8_t *der; + auto ret = i2d_PrivateKey(key, &der); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + bytearray_t res(der, der + ret); + OPENSSL_cleanse(der, ret); + OPENSSL_free(der); + return std::move(res); + } + + ~PKey() { if (key) EVP_PKEY_free(key); } +}; + +class X509 { + ::X509 *x509; + friend class TLSContext; + public: + X509(::X509 *x509): x509(x509) {} + X509(const X509 &) = delete; + X509(X509 &&other): x509(other.x509) { other.x509 = nullptr; } + + X509 create_from_pem_file(std::string pem_fname, std::string *password = nullptr) { + FILE *fp = fopen(pem_fname.c_str(), "r"); + ::X509 *x509; + if (fp == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + if (password) + { + _password = password->c_str(); + x509 = PEM_read_X509(fp, NULL, _tls_pem_with_passwd, NULL); + } + else + { + x509 = PEM_read_X509(fp, NULL, _tls_pem_no_passswd, NULL); + } + if (x509 == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + fclose(fp); + return X509(x509); + } + + X509 create_from_der(const uint8_t *der, size_t size) { + ::X509 *x509; + x509 = d2i_X509(NULL, (const unsigned char **)&der, size); + if (x509 == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + return X509(x509); + } + + PKey get_pubkey() { + auto key = X509_get_pubkey(x509); + if (key == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + return PKey(key); + } + + ~X509() { if (x509) X509_free(x509); } +}; + class TLSContext { SSL_CTX *ctx; friend class TLS; public: - static void init_tls() { SSL_library_init(); } TLSContext(): ctx(SSL_CTX_new(TLS_method())) { if (ctx == nullptr) throw std::runtime_error("TLSContext init error"); } + TLSContext(const TLSContext &) = delete; + TLSContext(TLSContext &&other): ctx(other.ctx) { other.ctx = nullptr; } + void use_cert_file(const std::string &fname) { auto ret = SSL_CTX_use_certificate_file(ctx, fname.c_str(), SSL_FILETYPE_PEM); if (ret <= 0) - throw SalticidaeError(SALTI_ERROR_TLS_CERT_ERROR); + throw SalticidaeError(SALTI_ERROR_TLS_LOAD_CERT); } - void use_priv_key_file(const std::string &fname) { + void use_cert(const X509 &x509) { + auto ret = SSL_CTX_use_certificate(ctx, x509.x509); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_LOAD_CERT); + } + + void use_privkey_file(const std::string &fname) { auto ret = SSL_CTX_use_PrivateKey_file(ctx, fname.c_str(), SSL_FILETYPE_PEM); if (ret <= 0) - throw SalticidaeError(SALTI_ERROR_TLS_KEY_ERROR); + throw SalticidaeError(SALTI_ERROR_TLS_LOAD_KEY); } - bool check_priv_key() { + void use_privkey(const PKey &key) { + auto ret = SSL_CTX_use_PrivateKey(ctx, key.key); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_LOAD_KEY); + } + + bool check_privkey() { return SSL_CTX_check_private_key(ctx) > 0; } - ~TLSContext() { SSL_CTX_free(ctx); } + ~TLSContext() { if (ctx) SSL_CTX_free(ctx); } }; using tls_context_t = ArcObj<TLSContext>; @@ -154,13 +287,16 @@ class TLS { if (ssl == nullptr) throw std::runtime_error("TLS init error"); if (!SSL_set_fd(ssl, fd)) - throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR); + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC); if (accept) SSL_set_accept_state(ssl); else SSL_set_connect_state(ssl); } + TLS(const TLS &) = delete; + TLS(TLS &&other): ssl(other.ssl) { other.ssl = nullptr; } + bool do_handshake(int &want_io_type) { /* 0 for read, 1 for write */ auto ret = SSL_do_handshake(ssl); if (ret == 1) return true; @@ -170,10 +306,17 @@ class TLS { else if (err == SSL_ERROR_WANT_READ) want_io_type = 0; else - throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR); + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC); return false; } + X509 get_peer_cert() { + ::X509 *x509 = SSL_get_peer_certificate(ssl); + if (x509 == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC); + return X509(x509); + } + inline int send(const void *buff, size_t size) { return SSL_write(ssl, buff, size); } @@ -187,8 +330,11 @@ class TLS { } ~TLS() { - SSL_shutdown(ssl); - SSL_free(ssl); + if (ssl) + { + SSL_shutdown(ssl); + SSL_free(ssl); + } } }; diff --git a/include/salticidae/util.h b/include/salticidae/util.h index 320c78f..dec498c 100644 --- a/include/salticidae/util.h +++ b/include/salticidae/util.h @@ -83,9 +83,11 @@ enum SalticidaeErrorCode { SALTI_ERROR_OPT_UNKNOWN_ACTION, SALTI_ERROR_CONFIG_LINE_TOO_LONG, SALTI_ERROR_OPT_INVALID, - SALTI_ERROR_TLS_CERT_ERROR, - SALTI_ERROR_TLS_KEY_ERROR, - SALTI_ERROR_TLS_GENERIC_ERROR, + SALTI_ERROR_TLS_LOAD_CERT, + SALTI_ERROR_TLS_LOAD_KEY, + SALTI_ERROR_TLS_GENERIC, + SALTI_ERROR_TLS_X509, + SALTI_ERROR_TLS_KEY, SALTI_ERROR_UNKNOWN }; diff --git a/src/conn.cpp b/src/conn.cpp index 3ec4284..60d5835 100644 --- a/src/conn.cpp +++ b/src/conn.cpp @@ -211,19 +211,21 @@ void ConnPool::Conn::_recv_data_tls(const ConnPool::conn_t &conn, int fd, int ev conn->on_read(); } -void ConnPool::Conn::_send_data_tls_handshake(const ConnPool::conn_t &conn, int fd, int events) { +void ConnPool::Conn::_send_data_tls_handshake(const ConnPool::conn_t &conn, int, int) { int ret; if (conn->tls->do_handshake(ret)) { + /* finishing TLS handshake */ conn->send_data_func = _send_data_tls; conn->recv_data_func = _recv_data_tls; + conn->peer_cert = new X509(conn->tls->get_peer_cert()); conn->cpool->update_conn(conn, true); } else { conn->ev_socket.del(); conn->ev_socket.add(ret == 0 ? FdEvent::READ : FdEvent::WRITE); - SALTICIDAE_LOG_INFO("tls handshake %d", ret); + SALTICIDAE_LOG_DEBUG("tls handshake %s", ret == 0 ? "read" : "write"); } } diff --git a/src/util.cpp b/src/util.cpp index fde326a..66bcd12 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -49,8 +49,11 @@ const char *SALTICIDAE_ERROR_STRINGS[] = { "configuration file line too long", "invalid option format", "unable to load cert", - "uable to load key", - "tls generic error" + "unable to load key", + "tls generic error", + "x509 cert error", + "EVP_PKEY error", + "unknown error" }; const char *TTY_COLOR_RED = "\x1b[31m"; diff --git a/test/bench_network.cpp b/test/bench_network.cpp index a498954..ca22db4 100644 --- a/test/bench_network.cpp +++ b/test/bench_network.cpp @@ -81,7 +81,8 @@ struct MyNet: public MsgNetworkByteOp { const std::string name, const NetAddr &peer, double stat_timeout = -1): - MsgNetworkByteOp(ec, MsgNetworkByteOp::Config().burst_size(1000).queue_capacity(65536)), + MsgNetworkByteOp(ec, MsgNetworkByteOp::Config( + ConnPool::Config().queue_capacity(65536)).burst_size(1000)), name(name), peer(peer), ev_period_stat(ec, [this, stat_timeout](TimerEvent &) { |