diff options
Diffstat (limited to 'include/salticidae/crypto.h')
-rw-r--r-- | include/salticidae/crypto.h | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/include/salticidae/crypto.h b/include/salticidae/crypto.h index 772cce1..1e6daa1 100644 --- a/include/salticidae/crypto.h +++ b/include/salticidae/crypto.h @@ -26,7 +26,9 @@ #define _SALTICIDAE_CRYPTO_H #include "salticidae/type.h" +#include "salticidae/util.h" #include <openssl/sha.h> +#include <openssl/ssl.h> namespace salticidae { @@ -114,6 +116,82 @@ class SHA1 { } }; +class TLSContext { + SSL_CTX *ctx; + friend class TLS; + public: + static void init_tls() { SSL_library_init(); } + TLSContext(): ctx(SSL_CTX_new(TLS_method())) { + if (ctx == nullptr) + throw std::runtime_error("TLSContext init error"); + } + + void use_cert_file(const std::string &fname) { + auto ret = SSL_CTX_use_certificate_file(ctx, fname.c_str(), SSL_FILETYPE_PEM); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_CERT_ERROR); + } + + void use_priv_key_file(const std::string &fname) { + auto ret = SSL_CTX_use_PrivateKey_file(ctx, fname.c_str(), SSL_FILETYPE_PEM); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_KEY_ERROR); + } + + bool check_priv_key() { + return SSL_CTX_check_private_key(ctx) > 0; + } + + ~TLSContext() { SSL_CTX_free(ctx); } +}; + +using tls_context_t = ArcObj<TLSContext>; + +class TLS { + SSL *ssl; + public: + TLS(const tls_context_t &ctx, int fd, bool accept): ssl(SSL_new(ctx->ctx)) { + if (ssl == nullptr) + throw std::runtime_error("TLS init error"); + if (!SSL_set_fd(ssl, fd)) + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR); + if (accept) + SSL_set_accept_state(ssl); + else + SSL_set_connect_state(ssl); + } + + bool do_handshake(int &want_io_type) { /* 0 for read, 1 for write */ + auto ret = SSL_do_handshake(ssl); + if (ret == 1) return true; + auto err = SSL_get_error(ssl, ret); + if (err == SSL_ERROR_WANT_WRITE) + want_io_type = 1; + else if (err == SSL_ERROR_WANT_READ) + want_io_type = 0; + else + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR); + return false; + } + + inline int send(const void *buff, size_t size) { + return SSL_write(ssl, buff, size); + } + + inline int recv(void *buff, size_t size) { + return SSL_read(ssl, buff, size); + } + + int get_error(int ret) { + return SSL_get_error(ssl, ret); + } + + ~TLS() { + SSL_shutdown(ssl); + SSL_free(ssl); + } +}; + } #endif |