From 4f41e23016dc316334e7d6cc8765bdf334b96f3e Mon Sep 17 00:00:00 2001 From: Determinant Date: Wed, 19 Jun 2019 19:11:58 -0400 Subject: more openssl wrappers --- include/salticidae/crypto.h | 166 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 156 insertions(+), 10 deletions(-) (limited to 'include/salticidae/crypto.h') diff --git a/include/salticidae/crypto.h b/include/salticidae/crypto.h index 1e6daa1..1d20b22 100644 --- a/include/salticidae/crypto.h +++ b/include/salticidae/crypto.h @@ -116,33 +116,166 @@ class SHA1 { } }; +static thread_local const char *_password; +static inline int _tls_pem_no_passswd(char *, int, int, void *) { + return -1; +} +static inline int _tls_pem_with_passwd(char *buf, int size, int, void *) { + size_t _size = strlen(_password) + 1; + if (_size > (size_t)size) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + memmove(buf, _password, _size); + return _size - 1; +} + +class PKey { + EVP_PKEY *key; + friend class TLSContext; + public: + PKey(EVP_PKEY *key): key(key) {} + PKey(const PKey &) = delete; + PKey(PKey &&other): key(other.key) { other.key = nullptr; } + + PKey create_privkey_from_pem_file(std::string pem_fname, std::string *password = nullptr) { + FILE *fp = fopen(pem_fname.c_str(), "r"); + EVP_PKEY *key; + if (fp == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + if (password) + { + _password = password->c_str(); + key = PEM_read_PrivateKey(fp, NULL, _tls_pem_with_passwd, NULL); + } + else + { + key = PEM_read_PrivateKey(fp, NULL, _tls_pem_no_passswd, NULL); + } + if (key == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + fclose(fp); + return PKey(key); + } + + PKey create_privkey_from_der(const uint8_t *der, size_t size) { + EVP_PKEY *key; + key = d2i_AutoPrivateKey(NULL, (const unsigned char **)&der, size); + if (key == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + return PKey(key); + } + + bytearray_t get_pubkey_der() { + uint8_t *der; + auto ret = i2d_PublicKey(key, &der); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + bytearray_t res(der, der + ret); + OPENSSL_cleanse(der, ret); + OPENSSL_free(der); + return std::move(res); + } + + bytearray_t get_privkey_der() { + uint8_t *der; + auto ret = i2d_PrivateKey(key, &der); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_KEY); + bytearray_t res(der, der + ret); + OPENSSL_cleanse(der, ret); + OPENSSL_free(der); + return std::move(res); + } + + ~PKey() { if (key) EVP_PKEY_free(key); } +}; + +class X509 { + ::X509 *x509; + friend class TLSContext; + public: + X509(::X509 *x509): x509(x509) {} + X509(const X509 &) = delete; + X509(X509 &&other): x509(other.x509) { other.x509 = nullptr; } + + X509 create_from_pem_file(std::string pem_fname, std::string *password = nullptr) { + FILE *fp = fopen(pem_fname.c_str(), "r"); + ::X509 *x509; + if (fp == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + if (password) + { + _password = password->c_str(); + x509 = PEM_read_X509(fp, NULL, _tls_pem_with_passwd, NULL); + } + else + { + x509 = PEM_read_X509(fp, NULL, _tls_pem_no_passswd, NULL); + } + if (x509 == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + fclose(fp); + return X509(x509); + } + + X509 create_from_der(const uint8_t *der, size_t size) { + ::X509 *x509; + x509 = d2i_X509(NULL, (const unsigned char **)&der, size); + if (x509 == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + return X509(x509); + } + + PKey get_pubkey() { + auto key = X509_get_pubkey(x509); + if (key == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + return PKey(key); + } + + ~X509() { if (x509) X509_free(x509); } +}; + class TLSContext { SSL_CTX *ctx; friend class TLS; public: - static void init_tls() { SSL_library_init(); } TLSContext(): ctx(SSL_CTX_new(TLS_method())) { if (ctx == nullptr) throw std::runtime_error("TLSContext init error"); } + TLSContext(const TLSContext &) = delete; + TLSContext(TLSContext &&other): ctx(other.ctx) { other.ctx = nullptr; } + void use_cert_file(const std::string &fname) { auto ret = SSL_CTX_use_certificate_file(ctx, fname.c_str(), SSL_FILETYPE_PEM); if (ret <= 0) - throw SalticidaeError(SALTI_ERROR_TLS_CERT_ERROR); + throw SalticidaeError(SALTI_ERROR_TLS_LOAD_CERT); } - void use_priv_key_file(const std::string &fname) { + void use_cert(const X509 &x509) { + auto ret = SSL_CTX_use_certificate(ctx, x509.x509); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_LOAD_CERT); + } + + void use_privkey_file(const std::string &fname) { auto ret = SSL_CTX_use_PrivateKey_file(ctx, fname.c_str(), SSL_FILETYPE_PEM); if (ret <= 0) - throw SalticidaeError(SALTI_ERROR_TLS_KEY_ERROR); + throw SalticidaeError(SALTI_ERROR_TLS_LOAD_KEY); } - bool check_priv_key() { + void use_privkey(const PKey &key) { + auto ret = SSL_CTX_use_PrivateKey(ctx, key.key); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_LOAD_KEY); + } + + bool check_privkey() { return SSL_CTX_check_private_key(ctx) > 0; } - ~TLSContext() { SSL_CTX_free(ctx); } + ~TLSContext() { if (ctx) SSL_CTX_free(ctx); } }; using tls_context_t = ArcObj; @@ -154,13 +287,16 @@ class TLS { if (ssl == nullptr) throw std::runtime_error("TLS init error"); if (!SSL_set_fd(ssl, fd)) - throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR); + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC); if (accept) SSL_set_accept_state(ssl); else SSL_set_connect_state(ssl); } + TLS(const TLS &) = delete; + TLS(TLS &&other): ssl(other.ssl) { other.ssl = nullptr; } + bool do_handshake(int &want_io_type) { /* 0 for read, 1 for write */ auto ret = SSL_do_handshake(ssl); if (ret == 1) return true; @@ -170,10 +306,17 @@ class TLS { else if (err == SSL_ERROR_WANT_READ) want_io_type = 0; else - throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR); + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC); return false; } + X509 get_peer_cert() { + ::X509 *x509 = SSL_get_peer_certificate(ssl); + if (x509 == nullptr) + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC); + return X509(x509); + } + inline int send(const void *buff, size_t size) { return SSL_write(ssl, buff, size); } @@ -187,8 +330,11 @@ class TLS { } ~TLS() { - SSL_shutdown(ssl); - SSL_free(ssl); + if (ssl) + { + SSL_shutdown(ssl); + SSL_free(ssl); + } } }; -- cgit v1.2.3