aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt2
-rw-r--r--include/salticidae/conn.h24
-rw-r--r--include/salticidae/crypto.h166
-rw-r--r--include/salticidae/util.h8
-rw-r--r--src/conn.cpp6
-rw-r--r--src/util.cpp7
-rw-r--r--test/bench_network.cpp3
7 files changed, 191 insertions, 25 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index c3ad2ef..9c5c6e8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -27,7 +27,7 @@ set(CMAKE_C_STANDARD 11)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/Modules/")
find_package(Libuv REQUIRED)
-find_package(OpenSSL REQUIRED)
+find_package(OpenSSL 1.1.0 REQUIRED)
include_directories(include)
add_library(salticidae
diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h
index 076d64a..a791057 100644
--- a/include/salticidae/conn.h
+++ b/include/salticidae/conn.h
@@ -93,6 +93,7 @@ class ConnPool {
socket_io_func *send_data_func;
socket_io_func *recv_data_func;
BoxObj<TLS> tls;
+ BoxObj<X509> peer_cert;
static socket_io_func _recv_data;
static socket_io_func _send_data;
@@ -131,6 +132,7 @@ class ConnPool {
operator std::string() const;
const NetAddr &get_addr() const { return addr; }
+ const X509 &get_peer_cert() const { return *peer_cert; }
ConnMode get_mode() const { return mode; }
ConnPool *get_pool() const { return cpool; }
MPSCWriteBuffer &get_send_buffer() { return send_buffer; }
@@ -332,6 +334,8 @@ class ConnPool {
bool _enable_tls;
std::string _tls_cert_file;
std::string _tls_key_file;
+ RcObj<X509> _tls_cert;
+ RcObj<PKey> _tls_key;
public:
Config():
@@ -341,8 +345,10 @@ class ConnPool {
_nworker(1),
_queue_capacity(0),
_enable_tls(true),
- _tls_cert_file("./server.pem"),
- _tls_key_file("./server.pem") {}
+ _tls_cert_file("./all.pem"),
+ _tls_key_file("./all.pem"),
+ _tls_cert(nullptr),
+ _tls_key(nullptr) {}
Config &max_listen_backlog(int x) {
_max_listen_backlog = x;
@@ -389,10 +395,16 @@ class ConnPool {
if (enable_tls)
{
tls_ctx = new TLSContext();
- tls_ctx->use_cert_file(config._tls_cert_file);
- tls_ctx->use_priv_key_file(config._tls_key_file);
- if (!tls_ctx->check_priv_key())
- throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR);
+ if (config._tls_cert)
+ tls_ctx->use_cert(*config._tls_cert);
+ else
+ tls_ctx->use_cert_file(config._tls_cert_file);
+ if (config._tls_key)
+ tls_ctx->use_privkey(*config._tls_key);
+ else
+ tls_ctx->use_privkey_file(config._tls_key_file);
+ if (!tls_ctx->check_privkey())
+ throw SalticidaeError(SALTI_ERROR_TLS_GENERIC);
}
workers = new Worker[nworker];
user_tcall = new ThreadCall(ec);
diff --git a/include/salticidae/crypto.h b/include/salticidae/crypto.h
index 1e6daa1..1d20b22 100644
--- a/include/salticidae/crypto.h
+++ b/include/salticidae/crypto.h
@@ -116,33 +116,166 @@ class SHA1 {
}
};
+static thread_local const char *_password;
+static inline int _tls_pem_no_passswd(char *, int, int, void *) {
+ return -1;
+}
+static inline int _tls_pem_with_passwd(char *buf, int size, int, void *) {
+ size_t _size = strlen(_password) + 1;
+ if (_size > (size_t)size)
+ throw SalticidaeError(SALTI_ERROR_TLS_X509);
+ memmove(buf, _password, _size);
+ return _size - 1;
+}
+
+class PKey {
+ EVP_PKEY *key;
+ friend class TLSContext;
+ public:
+ PKey(EVP_PKEY *key): key(key) {}
+ PKey(const PKey &) = delete;
+ PKey(PKey &&other): key(other.key) { other.key = nullptr; }
+
+ PKey create_privkey_from_pem_file(std::string pem_fname, std::string *password = nullptr) {
+ FILE *fp = fopen(pem_fname.c_str(), "r");
+ EVP_PKEY *key;
+ if (fp == nullptr)
+ throw SalticidaeError(SALTI_ERROR_TLS_KEY);
+ if (password)
+ {
+ _password = password->c_str();
+ key = PEM_read_PrivateKey(fp, NULL, _tls_pem_with_passwd, NULL);
+ }
+ else
+ {
+ key = PEM_read_PrivateKey(fp, NULL, _tls_pem_no_passswd, NULL);
+ }
+ if (key == nullptr)
+ throw SalticidaeError(SALTI_ERROR_TLS_KEY);
+ fclose(fp);
+ return PKey(key);
+ }
+
+ PKey create_privkey_from_der(const uint8_t *der, size_t size) {
+ EVP_PKEY *key;
+ key = d2i_AutoPrivateKey(NULL, (const unsigned char **)&der, size);
+ if (key == nullptr)
+ throw SalticidaeError(SALTI_ERROR_TLS_KEY);
+ return PKey(key);
+ }
+
+ bytearray_t get_pubkey_der() {
+ uint8_t *der;
+ auto ret = i2d_PublicKey(key, &der);
+ if (ret <= 0)
+ throw SalticidaeError(SALTI_ERROR_TLS_KEY);
+ bytearray_t res(der, der + ret);
+ OPENSSL_cleanse(der, ret);
+ OPENSSL_free(der);
+ return std::move(res);
+ }
+
+ bytearray_t get_privkey_der() {
+ uint8_t *der;
+ auto ret = i2d_PrivateKey(key, &der);
+ if (ret <= 0)
+ throw SalticidaeError(SALTI_ERROR_TLS_KEY);
+ bytearray_t res(der, der + ret);
+ OPENSSL_cleanse(der, ret);
+ OPENSSL_free(der);
+ return std::move(res);
+ }
+
+ ~PKey() { if (key) EVP_PKEY_free(key); }
+};
+
+class X509 {
+ ::X509 *x509;
+ friend class TLSContext;
+ public:
+ X509(::X509 *x509): x509(x509) {}
+ X509(const X509 &) = delete;
+ X509(X509 &&other): x509(other.x509) { other.x509 = nullptr; }
+
+ X509 create_from_pem_file(std::string pem_fname, std::string *password = nullptr) {
+ FILE *fp = fopen(pem_fname.c_str(), "r");
+ ::X509 *x509;
+ if (fp == nullptr)
+ throw SalticidaeError(SALTI_ERROR_TLS_X509);
+ if (password)
+ {
+ _password = password->c_str();
+ x509 = PEM_read_X509(fp, NULL, _tls_pem_with_passwd, NULL);
+ }
+ else
+ {
+ x509 = PEM_read_X509(fp, NULL, _tls_pem_no_passswd, NULL);
+ }
+ if (x509 == nullptr)
+ throw SalticidaeError(SALTI_ERROR_TLS_X509);
+ fclose(fp);
+ return X509(x509);
+ }
+
+ X509 create_from_der(const uint8_t *der, size_t size) {
+ ::X509 *x509;
+ x509 = d2i_X509(NULL, (const unsigned char **)&der, size);
+ if (x509 == nullptr)
+ throw SalticidaeError(SALTI_ERROR_TLS_X509);
+ return X509(x509);
+ }
+
+ PKey get_pubkey() {
+ auto key = X509_get_pubkey(x509);
+ if (key == nullptr)
+ throw SalticidaeError(SALTI_ERROR_TLS_X509);
+ return PKey(key);
+ }
+
+ ~X509() { if (x509) X509_free(x509); }
+};
+
class TLSContext {
SSL_CTX *ctx;
friend class TLS;
public:
- static void init_tls() { SSL_library_init(); }
TLSContext(): ctx(SSL_CTX_new(TLS_method())) {
if (ctx == nullptr)
throw std::runtime_error("TLSContext init error");
}
+ TLSContext(const TLSContext &) = delete;
+ TLSContext(TLSContext &&other): ctx(other.ctx) { other.ctx = nullptr; }
+
void use_cert_file(const std::string &fname) {
auto ret = SSL_CTX_use_certificate_file(ctx, fname.c_str(), SSL_FILETYPE_PEM);
if (ret <= 0)
- throw SalticidaeError(SALTI_ERROR_TLS_CERT_ERROR);
+ throw SalticidaeError(SALTI_ERROR_TLS_LOAD_CERT);
}
- void use_priv_key_file(const std::string &fname) {
+ void use_cert(const X509 &x509) {
+ auto ret = SSL_CTX_use_certificate(ctx, x509.x509);
+ if (ret <= 0)
+ throw SalticidaeError(SALTI_ERROR_TLS_LOAD_CERT);
+ }
+
+ void use_privkey_file(const std::string &fname) {
auto ret = SSL_CTX_use_PrivateKey_file(ctx, fname.c_str(), SSL_FILETYPE_PEM);
if (ret <= 0)
- throw SalticidaeError(SALTI_ERROR_TLS_KEY_ERROR);
+ throw SalticidaeError(SALTI_ERROR_TLS_LOAD_KEY);
}
- bool check_priv_key() {
+ void use_privkey(const PKey &key) {
+ auto ret = SSL_CTX_use_PrivateKey(ctx, key.key);
+ if (ret <= 0)
+ throw SalticidaeError(SALTI_ERROR_TLS_LOAD_KEY);
+ }
+
+ bool check_privkey() {
return SSL_CTX_check_private_key(ctx) > 0;
}
- ~TLSContext() { SSL_CTX_free(ctx); }
+ ~TLSContext() { if (ctx) SSL_CTX_free(ctx); }
};
using tls_context_t = ArcObj<TLSContext>;
@@ -154,13 +287,16 @@ class TLS {
if (ssl == nullptr)
throw std::runtime_error("TLS init error");
if (!SSL_set_fd(ssl, fd))
- throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR);
+ throw SalticidaeError(SALTI_ERROR_TLS_GENERIC);
if (accept)
SSL_set_accept_state(ssl);
else
SSL_set_connect_state(ssl);
}
+ TLS(const TLS &) = delete;
+ TLS(TLS &&other): ssl(other.ssl) { other.ssl = nullptr; }
+
bool do_handshake(int &want_io_type) { /* 0 for read, 1 for write */
auto ret = SSL_do_handshake(ssl);
if (ret == 1) return true;
@@ -170,10 +306,17 @@ class TLS {
else if (err == SSL_ERROR_WANT_READ)
want_io_type = 0;
else
- throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR);
+ throw SalticidaeError(SALTI_ERROR_TLS_GENERIC);
return false;
}
+ X509 get_peer_cert() {
+ ::X509 *x509 = SSL_get_peer_certificate(ssl);
+ if (x509 == nullptr)
+ throw SalticidaeError(SALTI_ERROR_TLS_GENERIC);
+ return X509(x509);
+ }
+
inline int send(const void *buff, size_t size) {
return SSL_write(ssl, buff, size);
}
@@ -187,8 +330,11 @@ class TLS {
}
~TLS() {
- SSL_shutdown(ssl);
- SSL_free(ssl);
+ if (ssl)
+ {
+ SSL_shutdown(ssl);
+ SSL_free(ssl);
+ }
}
};
diff --git a/include/salticidae/util.h b/include/salticidae/util.h
index 320c78f..dec498c 100644
--- a/include/salticidae/util.h
+++ b/include/salticidae/util.h
@@ -83,9 +83,11 @@ enum SalticidaeErrorCode {
SALTI_ERROR_OPT_UNKNOWN_ACTION,
SALTI_ERROR_CONFIG_LINE_TOO_LONG,
SALTI_ERROR_OPT_INVALID,
- SALTI_ERROR_TLS_CERT_ERROR,
- SALTI_ERROR_TLS_KEY_ERROR,
- SALTI_ERROR_TLS_GENERIC_ERROR,
+ SALTI_ERROR_TLS_LOAD_CERT,
+ SALTI_ERROR_TLS_LOAD_KEY,
+ SALTI_ERROR_TLS_GENERIC,
+ SALTI_ERROR_TLS_X509,
+ SALTI_ERROR_TLS_KEY,
SALTI_ERROR_UNKNOWN
};
diff --git a/src/conn.cpp b/src/conn.cpp
index 3ec4284..60d5835 100644
--- a/src/conn.cpp
+++ b/src/conn.cpp
@@ -211,19 +211,21 @@ void ConnPool::Conn::_recv_data_tls(const ConnPool::conn_t &conn, int fd, int ev
conn->on_read();
}
-void ConnPool::Conn::_send_data_tls_handshake(const ConnPool::conn_t &conn, int fd, int events) {
+void ConnPool::Conn::_send_data_tls_handshake(const ConnPool::conn_t &conn, int, int) {
int ret;
if (conn->tls->do_handshake(ret))
{
+ /* finishing TLS handshake */
conn->send_data_func = _send_data_tls;
conn->recv_data_func = _recv_data_tls;
+ conn->peer_cert = new X509(conn->tls->get_peer_cert());
conn->cpool->update_conn(conn, true);
}
else
{
conn->ev_socket.del();
conn->ev_socket.add(ret == 0 ? FdEvent::READ : FdEvent::WRITE);
- SALTICIDAE_LOG_INFO("tls handshake %d", ret);
+ SALTICIDAE_LOG_DEBUG("tls handshake %s", ret == 0 ? "read" : "write");
}
}
diff --git a/src/util.cpp b/src/util.cpp
index fde326a..66bcd12 100644
--- a/src/util.cpp
+++ b/src/util.cpp
@@ -49,8 +49,11 @@ const char *SALTICIDAE_ERROR_STRINGS[] = {
"configuration file line too long",
"invalid option format",
"unable to load cert",
- "uable to load key",
- "tls generic error"
+ "unable to load key",
+ "tls generic error",
+ "x509 cert error",
+ "EVP_PKEY error",
+ "unknown error"
};
const char *TTY_COLOR_RED = "\x1b[31m";
diff --git a/test/bench_network.cpp b/test/bench_network.cpp
index a498954..ca22db4 100644
--- a/test/bench_network.cpp
+++ b/test/bench_network.cpp
@@ -81,7 +81,8 @@ struct MyNet: public MsgNetworkByteOp {
const std::string name,
const NetAddr &peer,
double stat_timeout = -1):
- MsgNetworkByteOp(ec, MsgNetworkByteOp::Config().burst_size(1000).queue_capacity(65536)),
+ MsgNetworkByteOp(ec, MsgNetworkByteOp::Config(
+ ConnPool::Config().queue_capacity(65536)).burst_size(1000)),
name(name),
peer(peer),
ev_period_stat(ec, [this, stat_timeout](TimerEvent &) {