aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt6
-rw-r--r--include/salticidae/conn.h66
-rw-r--r--include/salticidae/crypto.h78
-rw-r--r--include/salticidae/util.h3
-rw-r--r--src/conn.cpp144
-rw-r--r--src/util.cpp5
6 files changed, 268 insertions, 34 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4154d66..c3ad2ef 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -46,11 +46,11 @@ if(BUILD_SHARED)
set_property(TARGET salticidae PROPERTY POSITION_INDEPENDENT_CODE 1)
add_library(salticidae_shared SHARED $<TARGET_OBJECTS:salticidae>)
set_target_properties(salticidae_shared PROPERTIES OUTPUT_NAME "salticidae")
- target_link_libraries(salticidae_shared uv crypto pthread)
+ target_link_libraries(salticidae_shared uv crypto ssl pthread)
endif()
add_library(salticidae_static STATIC $<TARGET_OBJECTS:salticidae>)
set_target_properties(salticidae_static PROPERTIES OUTPUT_NAME "salticidae")
-target_link_libraries(salticidae_static uv crypto pthread)
+target_link_libraries(salticidae_static uv crypto ssl pthread)
option(BUILD_TEST "build test binaries." OFF)
if(BUILD_TEST)
@@ -66,7 +66,7 @@ option(SALTICIDAE_NORMAL_LOG "enable regular log" ON)
option(SALTICIDAE_MSG_STAT "enable message statistics" ON)
option(SALTICIDAE_NOCHECK "disable the sanity check" OFF)
option(SALTICIDAE_NOCHECKSUM " disable checksum in messages" OFF)
-option(SALTICIDAE_CBINDINGS "enable C bindings" OFF)
+option(SALTICIDAE_CBINDINGS "enable C bindings" ON)
configure_file(src/config.h.in include/salticidae/config.h @ONLY)
diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h
index 48902d4..076d64a 100644
--- a/include/salticidae/conn.h
+++ b/include/salticidae/conn.h
@@ -88,9 +88,20 @@ class ConnPool {
TimerEvent ev_send_wait;
/** does not need to wait if true */
bool ready_send;
+
+ typedef void (socket_io_func)(const conn_t &, int, int);
+ socket_io_func *send_data_func;
+ socket_io_func *recv_data_func;
+ BoxObj<TLS> tls;
- void recv_data(int, int);
- void send_data(int, int);
+ static socket_io_func _recv_data;
+ static socket_io_func _send_data;
+
+ static socket_io_func _recv_data_tls;
+ static socket_io_func _send_data_tls;
+ static socket_io_func _recv_data_tls_handshake;
+ static socket_io_func _send_data_tls_handshake;
+
void conn_server(int, int);
/** Terminate the connection (from the worker thread). */
@@ -99,7 +110,7 @@ class ConnPool {
void disp_terminate();
public:
- Conn(): ready_send(false) {}
+ Conn(): ready_send(false), send_data_func(nullptr), recv_data_func(nullptr) {}
Conn(const Conn &) = delete;
Conn(Conn &&other) = delete;
@@ -158,6 +169,8 @@ class ConnPool {
const double conn_server_timeout;
const size_t seg_buff_size;
const size_t queue_capacity;
+ const bool enable_tls;
+ tls_context_t tls_ctx;
/* owned by user loop */
protected:
@@ -212,6 +225,21 @@ class ConnPool {
std::this_thread::get_id());
return;
}
+ auto cpool = conn->cpool;
+ if (cpool->enable_tls)
+ {
+ conn->tls = new TLS(
+ cpool->tls_ctx, client_fd,
+ conn->mode == Conn::ConnMode::PASSIVE);
+ conn->send_data_func = Conn::_send_data_tls_handshake;
+ conn->recv_data_func = Conn::_recv_data_tls_handshake;
+ }
+ else
+ {
+ conn->send_data_func = Conn::_send_data;
+ conn->recv_data_func = Conn::_recv_data;
+ cpool->update_conn(conn, true);
+ }
assert(conn->fd != -1);
SALTICIDAE_LOG_INFO("worker %x got %s",
std::this_thread::get_id(),
@@ -224,16 +252,16 @@ class ConnPool {
{
conn->ev_socket.del();
conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE);
- conn->send_data(client_fd, FdEvent::WRITE);
+ conn->send_data_func(conn, client_fd, FdEvent::WRITE);
}
return false;
});
conn->ev_socket = FdEvent(ec, client_fd, [this, conn=conn](int fd, int what) {
try {
if (what & FdEvent::READ)
- conn->recv_data(fd, what);
+ conn->recv_data_func(conn, fd, what);
else
- conn->send_data(fd, what);
+ conn->send_data_func(conn, fd, what);
} catch (...) { on_fatal_error(std::current_exception()); }
});
conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE);
@@ -301,6 +329,9 @@ class ConnPool {
size_t _seg_buff_size;
size_t _nworker;
size_t _queue_capacity;
+ bool _enable_tls;
+ std::string _tls_cert_file;
+ std::string _tls_key_file;
public:
Config():
@@ -308,7 +339,10 @@ class ConnPool {
_conn_server_timeout(2),
_seg_buff_size(4096),
_nworker(1),
- _queue_capacity(0) {}
+ _queue_capacity(0),
+ _enable_tls(true),
+ _tls_cert_file("./server.pem"),
+ _tls_key_file("./server.pem") {}
Config &max_listen_backlog(int x) {
_max_listen_backlog = x;
@@ -334,6 +368,11 @@ class ConnPool {
_queue_capacity = x;
return *this;
}
+
+ Config &enable_tls(bool x) {
+ _enable_tls = x;
+ return *this;
+ }
};
ConnPool(const EventContext &ec, const Config &config):
@@ -342,9 +381,19 @@ class ConnPool {
conn_server_timeout(config._conn_server_timeout),
seg_buff_size(config._seg_buff_size),
queue_capacity(config._queue_capacity),
+ enable_tls(config._enable_tls),
+ tls_ctx(nullptr),
listen_fd(-1),
nworker(config._nworker),
system_state(0) {
+ if (enable_tls)
+ {
+ tls_ctx = new TLSContext();
+ tls_ctx->use_cert_file(config._tls_cert_file);
+ tls_ctx->use_priv_key_file(config._tls_key_file);
+ if (!tls_ctx->check_priv_key())
+ throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR);
+ }
workers = new Worker[nworker];
user_tcall = new ThreadCall(ec);
disp_ec = workers[0].get_ec();
@@ -353,7 +402,8 @@ class ConnPool {
disp_error_cb = [this](const std::exception_ptr err) {
user_tcall->async_call([this, err](ThreadCall::Handle &) {
stop_workers();
- if (error_cb) error_cb(err, true);
+ std::rethrow_exception(err);
+ //if (error_cb) error_cb(err, true);
});
disp_ec.stop();
workers[0].stop_tcall();
diff --git a/include/salticidae/crypto.h b/include/salticidae/crypto.h
index 772cce1..1e6daa1 100644
--- a/include/salticidae/crypto.h
+++ b/include/salticidae/crypto.h
@@ -26,7 +26,9 @@
#define _SALTICIDAE_CRYPTO_H
#include "salticidae/type.h"
+#include "salticidae/util.h"
#include <openssl/sha.h>
+#include <openssl/ssl.h>
namespace salticidae {
@@ -114,6 +116,82 @@ class SHA1 {
}
};
+class TLSContext {
+ SSL_CTX *ctx;
+ friend class TLS;
+ public:
+ static void init_tls() { SSL_library_init(); }
+ TLSContext(): ctx(SSL_CTX_new(TLS_method())) {
+ if (ctx == nullptr)
+ throw std::runtime_error("TLSContext init error");
+ }
+
+ void use_cert_file(const std::string &fname) {
+ auto ret = SSL_CTX_use_certificate_file(ctx, fname.c_str(), SSL_FILETYPE_PEM);
+ if (ret <= 0)
+ throw SalticidaeError(SALTI_ERROR_TLS_CERT_ERROR);
+ }
+
+ void use_priv_key_file(const std::string &fname) {
+ auto ret = SSL_CTX_use_PrivateKey_file(ctx, fname.c_str(), SSL_FILETYPE_PEM);
+ if (ret <= 0)
+ throw SalticidaeError(SALTI_ERROR_TLS_KEY_ERROR);
+ }
+
+ bool check_priv_key() {
+ return SSL_CTX_check_private_key(ctx) > 0;
+ }
+
+ ~TLSContext() { SSL_CTX_free(ctx); }
+};
+
+using tls_context_t = ArcObj<TLSContext>;
+
+class TLS {
+ SSL *ssl;
+ public:
+ TLS(const tls_context_t &ctx, int fd, bool accept): ssl(SSL_new(ctx->ctx)) {
+ if (ssl == nullptr)
+ throw std::runtime_error("TLS init error");
+ if (!SSL_set_fd(ssl, fd))
+ throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR);
+ if (accept)
+ SSL_set_accept_state(ssl);
+ else
+ SSL_set_connect_state(ssl);
+ }
+
+ bool do_handshake(int &want_io_type) { /* 0 for read, 1 for write */
+ auto ret = SSL_do_handshake(ssl);
+ if (ret == 1) return true;
+ auto err = SSL_get_error(ssl, ret);
+ if (err == SSL_ERROR_WANT_WRITE)
+ want_io_type = 1;
+ else if (err == SSL_ERROR_WANT_READ)
+ want_io_type = 0;
+ else
+ throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR);
+ return false;
+ }
+
+ inline int send(const void *buff, size_t size) {
+ return SSL_write(ssl, buff, size);
+ }
+
+ inline int recv(void *buff, size_t size) {
+ return SSL_read(ssl, buff, size);
+ }
+
+ int get_error(int ret) {
+ return SSL_get_error(ssl, ret);
+ }
+
+ ~TLS() {
+ SSL_shutdown(ssl);
+ SSL_free(ssl);
+ }
+};
+
}
#endif
diff --git a/include/salticidae/util.h b/include/salticidae/util.h
index 007fcc4..320c78f 100644
--- a/include/salticidae/util.h
+++ b/include/salticidae/util.h
@@ -83,6 +83,9 @@ enum SalticidaeErrorCode {
SALTI_ERROR_OPT_UNKNOWN_ACTION,
SALTI_ERROR_CONFIG_LINE_TOO_LONG,
SALTI_ERROR_OPT_INVALID,
+ SALTI_ERROR_TLS_CERT_ERROR,
+ SALTI_ERROR_TLS_KEY_ERROR,
+ SALTI_ERROR_TLS_GENERIC_ERROR,
SALTI_ERROR_UNKNOWN
};
diff --git a/src/conn.cpp b/src/conn.cpp
index 7f485fd..3ec4284 100644
--- a/src/conn.cpp
+++ b/src/conn.cpp
@@ -51,19 +51,18 @@ ConnPool::Conn::operator std::string() const {
return std::move(s);
}
-/* the following two functions are executed by exactly one worker per Conn object */
+/* the following functions are executed by exactly one worker per Conn object */
-void ConnPool::Conn::send_data(int fd, int events) {
+void ConnPool::Conn::_send_data(const ConnPool::conn_t &conn, int fd, int events) {
if (events & FdEvent::ERROR)
{
- worker_terminate();
+ conn->worker_terminate();
return;
}
- auto conn = self(); /* pin the connection */
- ssize_t ret = seg_buff_size;
+ ssize_t ret = conn->seg_buff_size;
for (;;)
{
- bytearray_t buff_seg = send_buffer.move_pop();
+ bytearray_t buff_seg = conn->send_buffer.move_pop();
ssize_t size = buff_seg.size();
if (!size) break;
ret = send(fd, buff_seg.data(), size, MSG_NOSIGNAL);
@@ -74,37 +73,37 @@ void ConnPool::Conn::send_data(int fd, int events) {
if (ret < 1) /* nothing is sent */
{
/* rewind the whole buff_seg */
- send_buffer.rewind(std::move(buff_seg));
+ conn->send_buffer.rewind(std::move(buff_seg));
if (ret < 0 && errno != EWOULDBLOCK)
{
SALTICIDAE_LOG_INFO("send(%d) failure: %s", fd, strerror(errno));
- worker_terminate();
+ conn->worker_terminate();
return;
}
}
else
/* rewind the leftover */
- send_buffer.rewind(
+ conn->send_buffer.rewind(
bytearray_t(buff_seg.begin() + ret, buff_seg.end()));
/* wait for the next write callback */
- ready_send = false;
+ conn->ready_send = false;
//ev_write.add();
return;
}
}
- ev_socket.del();
- ev_socket.add(FdEvent::READ);
+ conn->ev_socket.del();
+ conn->ev_socket.add(FdEvent::READ);
/* consumed the buffer but endpoint still seems to be writable */
- ready_send = true;
+ conn->ready_send = true;
}
-void ConnPool::Conn::recv_data(int fd, int events) {
+void ConnPool::Conn::_recv_data(const ConnPool::conn_t &conn, int fd, int events) {
if (events & FdEvent::ERROR)
{
- worker_terminate();
+ conn->worker_terminate();
return;
}
- auto conn = self(); /* pin the connection */
+ const size_t seg_buff_size = conn->seg_buff_size;
ssize_t ret = seg_buff_size;
while (ret == (ssize_t)seg_buff_size)
{
@@ -117,21 +116,122 @@ void ConnPool::Conn::recv_data(int fd, int events) {
if (errno == EWOULDBLOCK) break;
SALTICIDAE_LOG_INFO("recv(%d) failure: %s", fd, strerror(errno));
/* connection err or half-opened connection */
- worker_terminate();
+ conn->worker_terminate();
return;
}
if (ret == 0)
{
//SALTICIDAE_LOG_INFO("recv(%d) terminates", fd, strerror(errno));
- worker_terminate();
+ conn->worker_terminate();
return;
}
buff_seg.resize(ret);
- recv_buffer.push(std::move(buff_seg));
+ conn->recv_buffer.push(std::move(buff_seg));
}
//ev_read.add();
- on_read();
+ conn->on_read();
+}
+
+
+void ConnPool::Conn::_send_data_tls(const ConnPool::conn_t &conn, int fd, int events) {
+ if (events & FdEvent::ERROR)
+ {
+ conn->worker_terminate();
+ return;
+ }
+ ssize_t ret = conn->seg_buff_size;
+ auto &tls = conn->tls;
+ for (;;)
+ {
+ bytearray_t buff_seg = conn->send_buffer.move_pop();
+ ssize_t size = buff_seg.size();
+ if (!size) break;
+ ret = tls->send(buff_seg.data(), size);
+ SALTICIDAE_LOG_DEBUG("ssl sent %zd bytes", ret);
+ size -= ret;
+ if (size > 0)
+ {
+ if (ret < 1) /* nothing is sent */
+ {
+ /* rewind the whole buff_seg */
+ conn->send_buffer.rewind(std::move(buff_seg));
+ if (ret < 0 && tls->get_error(ret) != SSL_ERROR_WANT_WRITE)
+ {
+ SALTICIDAE_LOG_INFO("send(%d) failure: %s", fd, strerror(errno));
+ conn->worker_terminate();
+ return;
+ }
+ }
+ else
+ /* rewind the leftover */
+ conn->send_buffer.rewind(
+ bytearray_t(buff_seg.begin() + ret, buff_seg.end()));
+ /* wait for the next write callback */
+ conn->ready_send = false;
+ return;
+ }
+ }
+ conn->ev_socket.del();
+ conn->ev_socket.add(FdEvent::READ);
+ /* consumed the buffer but endpoint still seems to be writable */
+ conn->ready_send = true;
+}
+
+void ConnPool::Conn::_recv_data_tls(const ConnPool::conn_t &conn, int fd, int events) {
+ if (events & FdEvent::ERROR)
+ {
+ conn->worker_terminate();
+ return;
+ }
+ const size_t seg_buff_size = conn->seg_buff_size;
+ ssize_t ret = seg_buff_size;
+ auto &tls = conn->tls;
+ while (ret == (ssize_t)seg_buff_size)
+ {
+ bytearray_t buff_seg;
+ buff_seg.resize(seg_buff_size);
+ ret = tls->recv(buff_seg.data(), seg_buff_size);
+ SALTICIDAE_LOG_DEBUG("ssl read %zd bytes", ret);
+ if (ret < 0)
+ {
+ if (tls->get_error(ret) == SSL_ERROR_WANT_READ) break;
+ SALTICIDAE_LOG_INFO("recv(%d) failure: %s", fd, strerror(errno));
+ /* connection err or half-opened connection */
+ conn->worker_terminate();
+ return;
+ }
+ if (ret == 0)
+ {
+ conn->worker_terminate();
+ return;
+ }
+ buff_seg.resize(ret);
+ conn->recv_buffer.push(std::move(buff_seg));
+ }
+ conn->on_read();
+}
+
+void ConnPool::Conn::_send_data_tls_handshake(const ConnPool::conn_t &conn, int fd, int events) {
+ int ret;
+ if (conn->tls->do_handshake(ret))
+ {
+ conn->send_data_func = _send_data_tls;
+ conn->recv_data_func = _recv_data_tls;
+ conn->cpool->update_conn(conn, true);
+ }
+ else
+ {
+ conn->ev_socket.del();
+ conn->ev_socket.add(ret == 0 ? FdEvent::READ : FdEvent::WRITE);
+ SALTICIDAE_LOG_INFO("tls handshake %d", ret);
+ }
+}
+
+void ConnPool::Conn::_recv_data_tls_handshake(const ConnPool::conn_t &conn, int fd, int events) {
+ conn->ready_send = true;
+ _send_data_tls_handshake(conn, fd, events);
}
+/****/
void ConnPool::Conn::stop() {
if (mode != ConnMode::DEAD)
@@ -188,6 +288,7 @@ void ConnPool::accept_client(int fd, int) {
conn->send_buffer.set_capacity(queue_capacity);
conn->seg_buff_size = seg_buff_size;
conn->fd = client_fd;
+ conn->worker = nullptr;
conn->cpool = this;
conn->mode = Conn::PASSIVE;
conn->addr = addr;
@@ -196,7 +297,6 @@ void ConnPool::accept_client(int fd, int) {
auto &worker = select_worker();
conn->worker = &worker;
conn->on_setup();
- update_conn(conn, true);
worker.feed(conn, client_fd);
}
} catch (ConnPoolError &e) {
@@ -214,7 +314,6 @@ void ConnPool::Conn::conn_server(int fd, int events) {
SALTICIDAE_LOG_INFO("connected to remote %s", std::string(*this).c_str());
worker = &(cpool->select_worker());
on_setup();
- cpool->update_conn(conn, true);
worker->feed(conn, fd);
}
else
@@ -282,6 +381,7 @@ ConnPool::conn_t ConnPool::_connect(const NetAddr &addr) {
conn->send_buffer.set_capacity(queue_capacity);
conn->seg_buff_size = seg_buff_size;
conn->fd = fd;
+ conn->worker = nullptr;
conn->cpool = this;
conn->mode = Conn::ACTIVE;
conn->addr = addr;
diff --git a/src/util.cpp b/src/util.cpp
index 874eb41..fde326a 100644
--- a/src/util.cpp
+++ b/src/util.cpp
@@ -47,7 +47,10 @@ const char *SALTICIDAE_ERROR_STRINGS[] = {
"option name already exists",
"unknown action",
"configuration file line too long",
- "invalid option format"
+ "invalid option format",
+ "unable to load cert",
+ "uable to load key",
+ "tls generic error"
};
const char *TTY_COLOR_RED = "\x1b[31m";