diff options
-rw-r--r-- | CMakeLists.txt | 6 | ||||
-rw-r--r-- | include/salticidae/conn.h | 66 | ||||
-rw-r--r-- | include/salticidae/crypto.h | 78 | ||||
-rw-r--r-- | include/salticidae/util.h | 3 | ||||
-rw-r--r-- | src/conn.cpp | 144 | ||||
-rw-r--r-- | src/util.cpp | 5 |
6 files changed, 268 insertions, 34 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 4154d66..c3ad2ef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,11 +46,11 @@ if(BUILD_SHARED) set_property(TARGET salticidae PROPERTY POSITION_INDEPENDENT_CODE 1) add_library(salticidae_shared SHARED $<TARGET_OBJECTS:salticidae>) set_target_properties(salticidae_shared PROPERTIES OUTPUT_NAME "salticidae") - target_link_libraries(salticidae_shared uv crypto pthread) + target_link_libraries(salticidae_shared uv crypto ssl pthread) endif() add_library(salticidae_static STATIC $<TARGET_OBJECTS:salticidae>) set_target_properties(salticidae_static PROPERTIES OUTPUT_NAME "salticidae") -target_link_libraries(salticidae_static uv crypto pthread) +target_link_libraries(salticidae_static uv crypto ssl pthread) option(BUILD_TEST "build test binaries." OFF) if(BUILD_TEST) @@ -66,7 +66,7 @@ option(SALTICIDAE_NORMAL_LOG "enable regular log" ON) option(SALTICIDAE_MSG_STAT "enable message statistics" ON) option(SALTICIDAE_NOCHECK "disable the sanity check" OFF) option(SALTICIDAE_NOCHECKSUM " disable checksum in messages" OFF) -option(SALTICIDAE_CBINDINGS "enable C bindings" OFF) +option(SALTICIDAE_CBINDINGS "enable C bindings" ON) configure_file(src/config.h.in include/salticidae/config.h @ONLY) diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h index 48902d4..076d64a 100644 --- a/include/salticidae/conn.h +++ b/include/salticidae/conn.h @@ -88,9 +88,20 @@ class ConnPool { TimerEvent ev_send_wait; /** does not need to wait if true */ bool ready_send; + + typedef void (socket_io_func)(const conn_t &, int, int); + socket_io_func *send_data_func; + socket_io_func *recv_data_func; + BoxObj<TLS> tls; - void recv_data(int, int); - void send_data(int, int); + static socket_io_func _recv_data; + static socket_io_func _send_data; + + static socket_io_func _recv_data_tls; + static socket_io_func _send_data_tls; + static socket_io_func _recv_data_tls_handshake; + static socket_io_func _send_data_tls_handshake; + void conn_server(int, int); /** Terminate the connection (from the worker thread). */ @@ -99,7 +110,7 @@ class ConnPool { void disp_terminate(); public: - Conn(): ready_send(false) {} + Conn(): ready_send(false), send_data_func(nullptr), recv_data_func(nullptr) {} Conn(const Conn &) = delete; Conn(Conn &&other) = delete; @@ -158,6 +169,8 @@ class ConnPool { const double conn_server_timeout; const size_t seg_buff_size; const size_t queue_capacity; + const bool enable_tls; + tls_context_t tls_ctx; /* owned by user loop */ protected: @@ -212,6 +225,21 @@ class ConnPool { std::this_thread::get_id()); return; } + auto cpool = conn->cpool; + if (cpool->enable_tls) + { + conn->tls = new TLS( + cpool->tls_ctx, client_fd, + conn->mode == Conn::ConnMode::PASSIVE); + conn->send_data_func = Conn::_send_data_tls_handshake; + conn->recv_data_func = Conn::_recv_data_tls_handshake; + } + else + { + conn->send_data_func = Conn::_send_data; + conn->recv_data_func = Conn::_recv_data; + cpool->update_conn(conn, true); + } assert(conn->fd != -1); SALTICIDAE_LOG_INFO("worker %x got %s", std::this_thread::get_id(), @@ -224,16 +252,16 @@ class ConnPool { { conn->ev_socket.del(); conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE); - conn->send_data(client_fd, FdEvent::WRITE); + conn->send_data_func(conn, client_fd, FdEvent::WRITE); } return false; }); conn->ev_socket = FdEvent(ec, client_fd, [this, conn=conn](int fd, int what) { try { if (what & FdEvent::READ) - conn->recv_data(fd, what); + conn->recv_data_func(conn, fd, what); else - conn->send_data(fd, what); + conn->send_data_func(conn, fd, what); } catch (...) { on_fatal_error(std::current_exception()); } }); conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE); @@ -301,6 +329,9 @@ class ConnPool { size_t _seg_buff_size; size_t _nworker; size_t _queue_capacity; + bool _enable_tls; + std::string _tls_cert_file; + std::string _tls_key_file; public: Config(): @@ -308,7 +339,10 @@ class ConnPool { _conn_server_timeout(2), _seg_buff_size(4096), _nworker(1), - _queue_capacity(0) {} + _queue_capacity(0), + _enable_tls(true), + _tls_cert_file("./server.pem"), + _tls_key_file("./server.pem") {} Config &max_listen_backlog(int x) { _max_listen_backlog = x; @@ -334,6 +368,11 @@ class ConnPool { _queue_capacity = x; return *this; } + + Config &enable_tls(bool x) { + _enable_tls = x; + return *this; + } }; ConnPool(const EventContext &ec, const Config &config): @@ -342,9 +381,19 @@ class ConnPool { conn_server_timeout(config._conn_server_timeout), seg_buff_size(config._seg_buff_size), queue_capacity(config._queue_capacity), + enable_tls(config._enable_tls), + tls_ctx(nullptr), listen_fd(-1), nworker(config._nworker), system_state(0) { + if (enable_tls) + { + tls_ctx = new TLSContext(); + tls_ctx->use_cert_file(config._tls_cert_file); + tls_ctx->use_priv_key_file(config._tls_key_file); + if (!tls_ctx->check_priv_key()) + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR); + } workers = new Worker[nworker]; user_tcall = new ThreadCall(ec); disp_ec = workers[0].get_ec(); @@ -353,7 +402,8 @@ class ConnPool { disp_error_cb = [this](const std::exception_ptr err) { user_tcall->async_call([this, err](ThreadCall::Handle &) { stop_workers(); - if (error_cb) error_cb(err, true); + std::rethrow_exception(err); + //if (error_cb) error_cb(err, true); }); disp_ec.stop(); workers[0].stop_tcall(); diff --git a/include/salticidae/crypto.h b/include/salticidae/crypto.h index 772cce1..1e6daa1 100644 --- a/include/salticidae/crypto.h +++ b/include/salticidae/crypto.h @@ -26,7 +26,9 @@ #define _SALTICIDAE_CRYPTO_H #include "salticidae/type.h" +#include "salticidae/util.h" #include <openssl/sha.h> +#include <openssl/ssl.h> namespace salticidae { @@ -114,6 +116,82 @@ class SHA1 { } }; +class TLSContext { + SSL_CTX *ctx; + friend class TLS; + public: + static void init_tls() { SSL_library_init(); } + TLSContext(): ctx(SSL_CTX_new(TLS_method())) { + if (ctx == nullptr) + throw std::runtime_error("TLSContext init error"); + } + + void use_cert_file(const std::string &fname) { + auto ret = SSL_CTX_use_certificate_file(ctx, fname.c_str(), SSL_FILETYPE_PEM); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_CERT_ERROR); + } + + void use_priv_key_file(const std::string &fname) { + auto ret = SSL_CTX_use_PrivateKey_file(ctx, fname.c_str(), SSL_FILETYPE_PEM); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_KEY_ERROR); + } + + bool check_priv_key() { + return SSL_CTX_check_private_key(ctx) > 0; + } + + ~TLSContext() { SSL_CTX_free(ctx); } +}; + +using tls_context_t = ArcObj<TLSContext>; + +class TLS { + SSL *ssl; + public: + TLS(const tls_context_t &ctx, int fd, bool accept): ssl(SSL_new(ctx->ctx)) { + if (ssl == nullptr) + throw std::runtime_error("TLS init error"); + if (!SSL_set_fd(ssl, fd)) + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR); + if (accept) + SSL_set_accept_state(ssl); + else + SSL_set_connect_state(ssl); + } + + bool do_handshake(int &want_io_type) { /* 0 for read, 1 for write */ + auto ret = SSL_do_handshake(ssl); + if (ret == 1) return true; + auto err = SSL_get_error(ssl, ret); + if (err == SSL_ERROR_WANT_WRITE) + want_io_type = 1; + else if (err == SSL_ERROR_WANT_READ) + want_io_type = 0; + else + throw SalticidaeError(SALTI_ERROR_TLS_GENERIC_ERROR); + return false; + } + + inline int send(const void *buff, size_t size) { + return SSL_write(ssl, buff, size); + } + + inline int recv(void *buff, size_t size) { + return SSL_read(ssl, buff, size); + } + + int get_error(int ret) { + return SSL_get_error(ssl, ret); + } + + ~TLS() { + SSL_shutdown(ssl); + SSL_free(ssl); + } +}; + } #endif diff --git a/include/salticidae/util.h b/include/salticidae/util.h index 007fcc4..320c78f 100644 --- a/include/salticidae/util.h +++ b/include/salticidae/util.h @@ -83,6 +83,9 @@ enum SalticidaeErrorCode { SALTI_ERROR_OPT_UNKNOWN_ACTION, SALTI_ERROR_CONFIG_LINE_TOO_LONG, SALTI_ERROR_OPT_INVALID, + SALTI_ERROR_TLS_CERT_ERROR, + SALTI_ERROR_TLS_KEY_ERROR, + SALTI_ERROR_TLS_GENERIC_ERROR, SALTI_ERROR_UNKNOWN }; diff --git a/src/conn.cpp b/src/conn.cpp index 7f485fd..3ec4284 100644 --- a/src/conn.cpp +++ b/src/conn.cpp @@ -51,19 +51,18 @@ ConnPool::Conn::operator std::string() const { return std::move(s); } -/* the following two functions are executed by exactly one worker per Conn object */ +/* the following functions are executed by exactly one worker per Conn object */ -void ConnPool::Conn::send_data(int fd, int events) { +void ConnPool::Conn::_send_data(const ConnPool::conn_t &conn, int fd, int events) { if (events & FdEvent::ERROR) { - worker_terminate(); + conn->worker_terminate(); return; } - auto conn = self(); /* pin the connection */ - ssize_t ret = seg_buff_size; + ssize_t ret = conn->seg_buff_size; for (;;) { - bytearray_t buff_seg = send_buffer.move_pop(); + bytearray_t buff_seg = conn->send_buffer.move_pop(); ssize_t size = buff_seg.size(); if (!size) break; ret = send(fd, buff_seg.data(), size, MSG_NOSIGNAL); @@ -74,37 +73,37 @@ void ConnPool::Conn::send_data(int fd, int events) { if (ret < 1) /* nothing is sent */ { /* rewind the whole buff_seg */ - send_buffer.rewind(std::move(buff_seg)); + conn->send_buffer.rewind(std::move(buff_seg)); if (ret < 0 && errno != EWOULDBLOCK) { SALTICIDAE_LOG_INFO("send(%d) failure: %s", fd, strerror(errno)); - worker_terminate(); + conn->worker_terminate(); return; } } else /* rewind the leftover */ - send_buffer.rewind( + conn->send_buffer.rewind( bytearray_t(buff_seg.begin() + ret, buff_seg.end())); /* wait for the next write callback */ - ready_send = false; + conn->ready_send = false; //ev_write.add(); return; } } - ev_socket.del(); - ev_socket.add(FdEvent::READ); + conn->ev_socket.del(); + conn->ev_socket.add(FdEvent::READ); /* consumed the buffer but endpoint still seems to be writable */ - ready_send = true; + conn->ready_send = true; } -void ConnPool::Conn::recv_data(int fd, int events) { +void ConnPool::Conn::_recv_data(const ConnPool::conn_t &conn, int fd, int events) { if (events & FdEvent::ERROR) { - worker_terminate(); + conn->worker_terminate(); return; } - auto conn = self(); /* pin the connection */ + const size_t seg_buff_size = conn->seg_buff_size; ssize_t ret = seg_buff_size; while (ret == (ssize_t)seg_buff_size) { @@ -117,21 +116,122 @@ void ConnPool::Conn::recv_data(int fd, int events) { if (errno == EWOULDBLOCK) break; SALTICIDAE_LOG_INFO("recv(%d) failure: %s", fd, strerror(errno)); /* connection err or half-opened connection */ - worker_terminate(); + conn->worker_terminate(); return; } if (ret == 0) { //SALTICIDAE_LOG_INFO("recv(%d) terminates", fd, strerror(errno)); - worker_terminate(); + conn->worker_terminate(); return; } buff_seg.resize(ret); - recv_buffer.push(std::move(buff_seg)); + conn->recv_buffer.push(std::move(buff_seg)); } //ev_read.add(); - on_read(); + conn->on_read(); +} + + +void ConnPool::Conn::_send_data_tls(const ConnPool::conn_t &conn, int fd, int events) { + if (events & FdEvent::ERROR) + { + conn->worker_terminate(); + return; + } + ssize_t ret = conn->seg_buff_size; + auto &tls = conn->tls; + for (;;) + { + bytearray_t buff_seg = conn->send_buffer.move_pop(); + ssize_t size = buff_seg.size(); + if (!size) break; + ret = tls->send(buff_seg.data(), size); + SALTICIDAE_LOG_DEBUG("ssl sent %zd bytes", ret); + size -= ret; + if (size > 0) + { + if (ret < 1) /* nothing is sent */ + { + /* rewind the whole buff_seg */ + conn->send_buffer.rewind(std::move(buff_seg)); + if (ret < 0 && tls->get_error(ret) != SSL_ERROR_WANT_WRITE) + { + SALTICIDAE_LOG_INFO("send(%d) failure: %s", fd, strerror(errno)); + conn->worker_terminate(); + return; + } + } + else + /* rewind the leftover */ + conn->send_buffer.rewind( + bytearray_t(buff_seg.begin() + ret, buff_seg.end())); + /* wait for the next write callback */ + conn->ready_send = false; + return; + } + } + conn->ev_socket.del(); + conn->ev_socket.add(FdEvent::READ); + /* consumed the buffer but endpoint still seems to be writable */ + conn->ready_send = true; +} + +void ConnPool::Conn::_recv_data_tls(const ConnPool::conn_t &conn, int fd, int events) { + if (events & FdEvent::ERROR) + { + conn->worker_terminate(); + return; + } + const size_t seg_buff_size = conn->seg_buff_size; + ssize_t ret = seg_buff_size; + auto &tls = conn->tls; + while (ret == (ssize_t)seg_buff_size) + { + bytearray_t buff_seg; + buff_seg.resize(seg_buff_size); + ret = tls->recv(buff_seg.data(), seg_buff_size); + SALTICIDAE_LOG_DEBUG("ssl read %zd bytes", ret); + if (ret < 0) + { + if (tls->get_error(ret) == SSL_ERROR_WANT_READ) break; + SALTICIDAE_LOG_INFO("recv(%d) failure: %s", fd, strerror(errno)); + /* connection err or half-opened connection */ + conn->worker_terminate(); + return; + } + if (ret == 0) + { + conn->worker_terminate(); + return; + } + buff_seg.resize(ret); + conn->recv_buffer.push(std::move(buff_seg)); + } + conn->on_read(); +} + +void ConnPool::Conn::_send_data_tls_handshake(const ConnPool::conn_t &conn, int fd, int events) { + int ret; + if (conn->tls->do_handshake(ret)) + { + conn->send_data_func = _send_data_tls; + conn->recv_data_func = _recv_data_tls; + conn->cpool->update_conn(conn, true); + } + else + { + conn->ev_socket.del(); + conn->ev_socket.add(ret == 0 ? FdEvent::READ : FdEvent::WRITE); + SALTICIDAE_LOG_INFO("tls handshake %d", ret); + } +} + +void ConnPool::Conn::_recv_data_tls_handshake(const ConnPool::conn_t &conn, int fd, int events) { + conn->ready_send = true; + _send_data_tls_handshake(conn, fd, events); } +/****/ void ConnPool::Conn::stop() { if (mode != ConnMode::DEAD) @@ -188,6 +288,7 @@ void ConnPool::accept_client(int fd, int) { conn->send_buffer.set_capacity(queue_capacity); conn->seg_buff_size = seg_buff_size; conn->fd = client_fd; + conn->worker = nullptr; conn->cpool = this; conn->mode = Conn::PASSIVE; conn->addr = addr; @@ -196,7 +297,6 @@ void ConnPool::accept_client(int fd, int) { auto &worker = select_worker(); conn->worker = &worker; conn->on_setup(); - update_conn(conn, true); worker.feed(conn, client_fd); } } catch (ConnPoolError &e) { @@ -214,7 +314,6 @@ void ConnPool::Conn::conn_server(int fd, int events) { SALTICIDAE_LOG_INFO("connected to remote %s", std::string(*this).c_str()); worker = &(cpool->select_worker()); on_setup(); - cpool->update_conn(conn, true); worker->feed(conn, fd); } else @@ -282,6 +381,7 @@ ConnPool::conn_t ConnPool::_connect(const NetAddr &addr) { conn->send_buffer.set_capacity(queue_capacity); conn->seg_buff_size = seg_buff_size; conn->fd = fd; + conn->worker = nullptr; conn->cpool = this; conn->mode = Conn::ACTIVE; conn->addr = addr; diff --git a/src/util.cpp b/src/util.cpp index 874eb41..fde326a 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -47,7 +47,10 @@ const char *SALTICIDAE_ERROR_STRINGS[] = { "option name already exists", "unknown action", "configuration file line too long", - "invalid option format" + "invalid option format", + "unable to load cert", + "uable to load key", + "tls generic error" }; const char *TTY_COLOR_RED = "\x1b[31m"; |