diff options
Diffstat (limited to 'include')
-rw-r--r-- | include/salticidae/conn.h | 24 | ||||
-rw-r--r-- | include/salticidae/crypto.h | 166 | ||||
-rw-r--r-- | include/salticidae/util.h | 8 |
3 files changed, 179 insertions, 19 deletions
diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h index 076d64a..a791057 100644 --- a/include/salticidae/conn.h +++ b/include/salticidae/conn.h @@ -93,6 +93,7 @@ class ConnPool { socket_io_func *send_data_func; socket_io_func *recv_data_func; BoxObj<TLS> tls; + BoxObj<X509> peer_cert; static socket_io_func _recv_data; static socket_io_func _send_data; @@ -131,6 +132,7 @@ class ConnPool { operator std::string() const; const NetAddr &get_addr() const { return addr; } + const X509 &get_peer_cert() const { return *peer_cert; } ConnMode get_mode() const { return mode; } ConnPool *get_pool() const { return cpool; } MPSCWriteBuffer &get_send_buffer() { return send_buffer; } @@ -332,6 +334,8 @@ class ConnPool { bool _enable_tls; std::string _tls_cert_file; std::string _tls_key_file; + RcObj<X509> _tls_cert; + RcObj<PKey> _tls_key; public: Config(): @@ -341,8 +345,10 @@ class ConnPool { _nworker(1), _queue_capacity(0), _enable_tls(true), - _tls_cert_file("./server.pem"), - _tls_key_file("./server.pem") {} + _tls_cert_file("./all.pem"), + _tls_key_file("./all.pem"), + _tls_cert(nullptr), + _tls_key(nullptr) {} Config &max_listen_backlog(int x) { _max_listen_backlog = x; @@ -389,10 +395,16 @@ class ConnPool { if (enable_tls) { tls_ctx = new TLSContext(); - tls_ctx->use_cert_file(config._tls_cert_file); - tls_ctx->use_priv_key_file(config._tls_key_file); - if (!tls_ctx->check_priv_key()) - throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR); + if (config._tls_cert) + tls_ctx->use_cert(*config._tls_cert); + else + tls_ctx->use_cert_file(config._tls_cert_file); + if (config._tls_key) + tls_ctx->use_privkey(*config._tls_key); + else + tls_ctx->use_privkey_file(config._tls_key_file); + if (!tls_ctx->check_privkey()) + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC); } workers = new Worker[nworker]; user_tcall = new ThreadCall(ec); diff --git a/include/salticidae/crypto.h b/include/salticidae/crypto.h index 1e6daa1..1d20b22 100644 --- a/include/salticidae/crypto.h +++ b/include/salticidae/crypto.h @@ -116,33 +116,166 @@ class SHA1 { } }; +static thread_local const char *_password; +static inline int _tls_pem_no_passswd(char *, int, int, void *) { + return -1; +} +static inline int _tls_pem_with_passwd(char *buf, int size, int, void *) { + size_t _size = strlen(_password) + 1; + if (_size > (size_t)size) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + memmove(buf, _password, _size); + return _size - 1; +} + +class PKey { + EVP_PKEY *key; + friend class TLSContext; + public: + PKey(EVP_PKEY *key): key(key) {} + PKey(const PKey &) = delete; + PKey(PKey &&other): key(other.key) { other.key = nullptr; } + + PKey create_privkey_from_pem_file(std::string pem_fname, std::string *password = nullptr) { + FILE *fp = fopen(pem_fname.c_str(), "r"); + EVP_PKEY *key; + if (fp == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + if (password) + { + _password = password->c_str(); + key = PEM_read_PrivateKey(fp, NULL, _tls_pem_with_passwd, NULL); + } + else + { + key = PEM_read_PrivateKey(fp, NULL, _tls_pem_no_passswd, NULL); + } + if (key == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + fclose(fp); + return PKey(key); + } + + PKey create_privkey_from_der(const uint8_t *der, size_t size) { + EVP_PKEY *key; + key = d2i_AutoPrivateKey(NULL, (const unsigned char **)&der, size); + if (key == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + return PKey(key); + } + + bytearray_t get_pubkey_der() { + uint8_t *der; + auto ret = i2d_PublicKey(key, &der); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + bytearray_t res(der, der + ret); + OPENSSL_cleanse(der, ret); + OPENSSL_free(der); + return std::move(res); + } + + bytearray_t get_privkey_der() { + uint8_t *der; + auto ret = i2d_PrivateKey(key, &der); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + bytearray_t res(der, der + ret); + OPENSSL_cleanse(der, ret); + OPENSSL_free(der); + return std::move(res); + } + + ~PKey() { if (key) EVP_PKEY_free(key); } +}; + +class X509 { + ::X509 *x509; + friend class TLSContext; + public: + X509(::X509 *x509): x509(x509) {} + X509(const X509 &) = delete; + X509(X509 &&other): x509(other.x509) { other.x509 = nullptr; } + + X509 create_from_pem_file(std::string pem_fname, std::string *password = nullptr) { + FILE *fp = fopen(pem_fname.c_str(), "r"); + ::X509 *x509; + if (fp == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + if (password) + { + _password = password->c_str(); + x509 = PEM_read_X509(fp, NULL, _tls_pem_with_passwd, NULL); + } + else + { + x509 = PEM_read_X509(fp, NULL, _tls_pem_no_passswd, NULL); + } + if (x509 == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + fclose(fp); + return X509(x509); + } + + X509 create_from_der(const uint8_t *der, size_t size) { + ::X509 *x509; + x509 = d2i_X509(NULL, (const unsigned char **)&der, size); + if (x509 == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + return X509(x509); + } + + PKey get_pubkey() { + auto key = X509_get_pubkey(x509); + if (key == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + return PKey(key); + } + + ~X509() { if (x509) X509_free(x509); } +}; + class TLSContext { SSL_CTX *ctx; friend class TLS; public: - static void init_tls() { SSL_library_init(); } TLSContext(): ctx(SSL_CTX_new(TLS_method())) { if (ctx == nullptr) throw std::runtime_error("TLSContext init error"); } + TLSContext(const TLSContext &) = delete; + TLSContext(TLSContext &&other): ctx(other.ctx) { other.ctx = nullptr; } + void use_cert_file(const std::string &fname) { auto ret = SSL_CTX_use_certificate_file(ctx, fname.c_str(), SSL_FILETYPE_PEM); if (ret <= 0) - throw SalticidaeError(SALTI_ERROR_TLS_CERT_ERROR); + throw SalticidaeError(SALTI_ERROR_TLS_LOAD_CERT); } - void use_priv_key_file(const std::string &fname) { + void use_cert(const X509 &x509) { + auto ret = SSL_CTX_use_certificate(ctx, x509.x509); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_LOAD_CERT); + } + + void use_privkey_file(const std::string &fname) { auto ret = SSL_CTX_use_PrivateKey_file(ctx, fname.c_str(), SSL_FILETYPE_PEM); if (ret <= 0) - throw SalticidaeError(SALTI_ERROR_TLS_KEY_ERROR); + throw SalticidaeError(SALTI_ERROR_TLS_LOAD_KEY); } - bool check_priv_key() { + void use_privkey(const PKey &key) { + auto ret = SSL_CTX_use_PrivateKey(ctx, key.key); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_LOAD_KEY); + } + + bool check_privkey() { return SSL_CTX_check_private_key(ctx) > 0; } - ~TLSContext() { SSL_CTX_free(ctx); } + ~TLSContext() { if (ctx) SSL_CTX_free(ctx); } }; using tls_context_t = ArcObj<TLSContext>; @@ -154,13 +287,16 @@ class TLS { if (ssl == nullptr) throw std::runtime_error("TLS init error"); if (!SSL_set_fd(ssl, fd)) - throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR); + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC); if (accept) SSL_set_accept_state(ssl); else SSL_set_connect_state(ssl); } + TLS(const TLS &) = delete; + TLS(TLS &&other): ssl(other.ssl) { other.ssl = nullptr; } + bool do_handshake(int &want_io_type) { /* 0 for read, 1 for write */ auto ret = SSL_do_handshake(ssl); if (ret == 1) return true; @@ -170,10 +306,17 @@ class TLS { else if (err == SSL_ERROR_WANT_READ) want_io_type = 0; else - throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR); + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC); return false; } + X509 get_peer_cert() { + ::X509 *x509 = SSL_get_peer_certificate(ssl); + if (x509 == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC); + return X509(x509); + } + inline int send(const void *buff, size_t size) { return SSL_write(ssl, buff, size); } @@ -187,8 +330,11 @@ class TLS { } ~TLS() { - SSL_shutdown(ssl); - SSL_free(ssl); + if (ssl) + { + SSL_shutdown(ssl); + SSL_free(ssl); + } } }; diff --git a/include/salticidae/util.h b/include/salticidae/util.h index 320c78f..dec498c 100644 --- a/include/salticidae/util.h +++ b/include/salticidae/util.h @@ -83,9 +83,11 @@ enum SalticidaeErrorCode { SALTI_ERROR_OPT_UNKNOWN_ACTION, SALTI_ERROR_CONFIG_LINE_TOO_LONG, SALTI_ERROR_OPT_INVALID, - SALTI_ERROR_TLS_CERT_ERROR, - SALTI_ERROR_TLS_KEY_ERROR, - SALTI_ERROR_TLS_GENERIC_ERROR, + SALTI_ERROR_TLS_LOAD_CERT, + SALTI_ERROR_TLS_LOAD_KEY, + SALTI_ERROR_TLS_GENERIC, + SALTI_ERROR_TLS_X509, + SALTI_ERROR_TLS_KEY, SALTI_ERROR_UNKNOWN }; |