aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorDeterminant <ted.sybil@gmail.com>2019-06-18 18:19:39 -0400
committerDeterminant <ted.sybil@gmail.com>2019-06-18 18:19:39 -0400
commit8f42d0581a8e0cd77bde459db6b61fd957e19c1b (patch)
tree216d607daa37b0bf2308b43375e8a1eb41dce6f8 /include
parentd91fc3e873d4bddd5cdd69fda7f67bd780a0ac55 (diff)
WIP: TLS support
Diffstat (limited to 'include')
-rw-r--r--include/salticidae/conn.h66
-rw-r--r--include/salticidae/crypto.h78
-rw-r--r--include/salticidae/util.h3
3 files changed, 139 insertions, 8 deletions
diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h
index 48902d4..076d64a 100644
--- a/include/salticidae/conn.h
+++ b/include/salticidae/conn.h
@@ -88,9 +88,20 @@ class ConnPool {
TimerEvent ev_send_wait;
/** does not need to wait if true */
bool ready_send;
+
+ typedef void (socket_io_func)(const conn_t &, int, int);
+ socket_io_func *send_data_func;
+ socket_io_func *recv_data_func;
+ BoxObj<TLS> tls;
- void recv_data(int, int);
- void send_data(int, int);
+ static socket_io_func _recv_data;
+ static socket_io_func _send_data;
+
+ static socket_io_func _recv_data_tls;
+ static socket_io_func _send_data_tls;
+ static socket_io_func _recv_data_tls_handshake;
+ static socket_io_func _send_data_tls_handshake;
+
void conn_server(int, int);
/** Terminate the connection (from the worker thread). */
@@ -99,7 +110,7 @@ class ConnPool {
void disp_terminate();
public:
- Conn(): ready_send(false) {}
+ Conn(): ready_send(false), send_data_func(nullptr), recv_data_func(nullptr) {}
Conn(const Conn &) = delete;
Conn(Conn &&other) = delete;
@@ -158,6 +169,8 @@ class ConnPool {
const double conn_server_timeout;
const size_t seg_buff_size;
const size_t queue_capacity;
+ const bool enable_tls;
+ tls_context_t tls_ctx;
/* owned by user loop */
protected:
@@ -212,6 +225,21 @@ class ConnPool {
std::this_thread::get_id());
return;
}
+ auto cpool = conn->cpool;
+ if (cpool->enable_tls)
+ {
+ conn->tls = new TLS(
+ cpool->tls_ctx, client_fd,
+ conn->mode == Conn::ConnMode::PASSIVE);
+ conn->send_data_func = Conn::_send_data_tls_handshake;
+ conn->recv_data_func = Conn::_recv_data_tls_handshake;
+ }
+ else
+ {
+ conn->send_data_func = Conn::_send_data;
+ conn->recv_data_func = Conn::_recv_data;
+ cpool->update_conn(conn, true);
+ }
assert(conn->fd != -1);
SALTICIDAE_LOG_INFO("worker %x got %s",
std::this_thread::get_id(),
@@ -224,16 +252,16 @@ class ConnPool {
{
conn->ev_socket.del();
conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE);
- conn->send_data(client_fd, FdEvent::WRITE);
+ conn->send_data_func(conn, client_fd, FdEvent::WRITE);
}
return false;
});
conn->ev_socket = FdEvent(ec, client_fd, [this, conn=conn](int fd, int what) {
try {
if (what & FdEvent::READ)
- conn->recv_data(fd, what);
+ conn->recv_data_func(conn, fd, what);
else
- conn->send_data(fd, what);
+ conn->send_data_func(conn, fd, what);
} catch (...) { on_fatal_error(std::current_exception()); }
});
conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE);
@@ -301,6 +329,9 @@ class ConnPool {
size_t _seg_buff_size;
size_t _nworker;
size_t _queue_capacity;
+ bool _enable_tls;
+ std::string _tls_cert_file;
+ std::string _tls_key_file;
public:
Config():
@@ -308,7 +339,10 @@ class ConnPool {
_conn_server_timeout(2),
_seg_buff_size(4096),
_nworker(1),
- _queue_capacity(0) {}
+ _queue_capacity(0),
+ _enable_tls(true),
+ _tls_cert_file("./server.pem"),
+ _tls_key_file("./server.pem") {}
Config &max_listen_backlog(int x) {
_max_listen_backlog = x;
@@ -334,6 +368,11 @@ class ConnPool {
_queue_capacity = x;
return *this;
}
+
+ Config &enable_tls(bool x) {
+ _enable_tls = x;
+ return *this;
+ }
};
ConnPool(const EventContext &ec, const Config &config):
@@ -342,9 +381,19 @@ class ConnPool {
conn_server_timeout(config._conn_server_timeout),
seg_buff_size(config._seg_buff_size),
queue_capacity(config._queue_capacity),
+ enable_tls(config._enable_tls),
+ tls_ctx(nullptr),
listen_fd(-1),
nworker(config._nworker),
system_state(0) {
+ if (enable_tls)
+ {
+ tls_ctx = new TLSContext();
+ tls_ctx->use_cert_file(config._tls_cert_file);
+ tls_ctx->use_priv_key_file(config._tls_key_file);
+ if (!tls_ctx->check_priv_key())
+ throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR);
+ }
workers = new Worker[nworker];
user_tcall = new ThreadCall(ec);
disp_ec = workers[0].get_ec();
@@ -353,7 +402,8 @@ class ConnPool {
disp_error_cb = [this](const std::exception_ptr err) {
user_tcall->async_call([this, err](ThreadCall::Handle &) {
stop_workers();
- if (error_cb) error_cb(err, true);
+ std::rethrow_exception(err);
+ //if (error_cb) error_cb(err, true);
});
disp_ec.stop();
workers[0].stop_tcall();
diff --git a/include/salticidae/crypto.h b/include/salticidae/crypto.h
index 772cce1..1e6daa1 100644
--- a/include/salticidae/crypto.h
+++ b/include/salticidae/crypto.h
@@ -26,7 +26,9 @@
#define _SALTICIDAE_CRYPTO_H
#include "salticidae/type.h"
+#include "salticidae/util.h"
#include <openssl/sha.h>
+#include <openssl/ssl.h>
namespace salticidae {
@@ -114,6 +116,82 @@ class SHA1 {
}
};
+class TLSContext {
+ SSL_CTX *ctx;
+ friend class TLS;
+ public:
+ static void init_tls() { SSL_library_init(); }
+ TLSContext(): ctx(SSL_CTX_new(TLS_method())) {
+ if (ctx == nullptr)
+ throw std::runtime_error("TLSContext init error");
+ }
+
+ void use_cert_file(const std::string &fname) {
+ auto ret = SSL_CTX_use_certificate_file(ctx, fname.c_str(), SSL_FILETYPE_PEM);
+ if (ret <= 0)
+ throw SalticidaeError(SALTI_ERROR_TLS_CERT_ERROR);
+ }
+
+ void use_priv_key_file(const std::string &fname) {
+ auto ret = SSL_CTX_use_PrivateKey_file(ctx, fname.c_str(), SSL_FILETYPE_PEM);
+ if (ret <= 0)
+ throw SalticidaeError(SALTI_ERROR_TLS_KEY_ERROR);
+ }
+
+ bool check_priv_key() {
+ return SSL_CTX_check_private_key(ctx) > 0;
+ }
+
+ ~TLSContext() { SSL_CTX_free(ctx); }
+};
+
+using tls_context_t = ArcObj<TLSContext>;
+
+class TLS {
+ SSL *ssl;
+ public:
+ TLS(const tls_context_t &ctx, int fd, bool accept): ssl(SSL_new(ctx->ctx)) {
+ if (ssl == nullptr)
+ throw std::runtime_error("TLS init error");
+ if (!SSL_set_fd(ssl, fd))
+ throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR);
+ if (accept)
+ SSL_set_accept_state(ssl);
+ else
+ SSL_set_connect_state(ssl);
+ }
+
+ bool do_handshake(int &want_io_type) { /* 0 for read, 1 for write */
+ auto ret = SSL_do_handshake(ssl);
+ if (ret == 1) return true;
+ auto err = SSL_get_error(ssl, ret);
+ if (err == SSL_ERROR_WANT_WRITE)
+ want_io_type = 1;
+ else if (err == SSL_ERROR_WANT_READ)
+ want_io_type = 0;
+ else
+ throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR);
+ return false;
+ }
+
+ inline int send(const void *buff, size_t size) {
+ return SSL_write(ssl, buff, size);
+ }
+
+ inline int recv(void *buff, size_t size) {
+ return SSL_read(ssl, buff, size);
+ }
+
+ int get_error(int ret) {
+ return SSL_get_error(ssl, ret);
+ }
+
+ ~TLS() {
+ SSL_shutdown(ssl);
+ SSL_free(ssl);
+ }
+};
+
}
#endif
diff --git a/include/salticidae/util.h b/include/salticidae/util.h
index 007fcc4..320c78f 100644
--- a/include/salticidae/util.h
+++ b/include/salticidae/util.h
@@ -83,6 +83,9 @@ enum SalticidaeErrorCode {
SALTI_ERROR_OPT_UNKNOWN_ACTION,
SALTI_ERROR_CONFIG_LINE_TOO_LONG,
SALTI_ERROR_OPT_INVALID,
+ SALTI_ERROR_TLS_CERT_ERROR,
+ SALTI_ERROR_TLS_KEY_ERROR,
+ SALTI_ERROR_TLS_GENERIC_ERROR,
SALTI_ERROR_UNKNOWN
};