aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
Diffstat (limited to 'include')
-rw-r--r--include/salticidae/conn.h63
-rw-r--r--include/salticidae/crypto.h19
-rw-r--r--include/salticidae/network.h2
-rw-r--r--include/salticidae/util.h2
4 files changed, 69 insertions, 17 deletions
diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h
index a791057..59d93fc 100644
--- a/include/salticidae/conn.h
+++ b/include/salticidae/conn.h
@@ -58,7 +58,7 @@ class ConnPool {
/** The handle to a bi-directional connection. */
using conn_t = ArcObj<Conn>;
/** The type of callback invoked when connection status is changed. */
- using conn_callback_t = std::function<void(const conn_t &, bool)>;
+ using conn_callback_t = std::function<bool(const conn_t &, bool)>;
using error_callback_t = std::function<void(const std::exception_ptr, bool)>;
/** Abstraction for a bi-directional connection. */
class Conn {
@@ -93,7 +93,7 @@ class ConnPool {
socket_io_func *send_data_func;
socket_io_func *recv_data_func;
BoxObj<TLS> tls;
- BoxObj<X509> peer_cert;
+ BoxObj<const X509> peer_cert;
static socket_io_func _recv_data;
static socket_io_func _send_data;
@@ -102,6 +102,7 @@ class ConnPool {
static socket_io_func _send_data_tls;
static socket_io_func _recv_data_tls_handshake;
static socket_io_func _send_data_tls_handshake;
+ static socket_io_func _recv_data_dummy;
void conn_server(int, int);
@@ -189,7 +190,12 @@ class ConnPool {
void update_conn(const conn_t &conn, bool connected) {
user_tcall->async_call([this, conn, connected](ThreadCall::Handle &) {
- if (conn_cb) conn_cb(conn, connected);
+ if ((!conn_cb ||
+ conn_cb(conn, connected)) &&
+ enable_tls && connected)
+ conn->worker->get_tcall()->async_call([conn](ThreadCall::Handle &) {
+ conn->recv_data_func = Conn::_recv_data_tls;
+ });
});
}
@@ -264,7 +270,10 @@ class ConnPool {
conn->recv_data_func(conn, fd, what);
else
conn->send_data_func(conn, fd, what);
- } catch (...) { on_fatal_error(std::current_exception()); }
+ } catch (...) {
+ conn->cpool->recoverable_error(std::current_exception());
+ conn->worker_terminate();
+ }
});
conn->ev_socket.add(FdEvent::READ | FdEvent::WRITE);
nconn++;
@@ -336,6 +345,8 @@ class ConnPool {
std::string _tls_key_file;
RcObj<X509> _tls_cert;
RcObj<PKey> _tls_key;
+ bool _tls_skip_ca_check;
+ SSL_verify_cb _tls_verify_callback;
public:
Config():
@@ -344,11 +355,13 @@ class ConnPool {
_seg_buff_size(4096),
_nworker(1),
_queue_capacity(0),
- _enable_tls(true),
- _tls_cert_file("./all.pem"),
- _tls_key_file("./all.pem"),
+ _enable_tls(false),
+ _tls_cert_file(""),
+ _tls_key_file(""),
_tls_cert(nullptr),
- _tls_key(nullptr) {}
+ _tls_key(nullptr),
+ _tls_skip_ca_check(true),
+ _tls_verify_callback(nullptr) {}
Config &max_listen_backlog(int x) {
_max_listen_backlog = x;
@@ -379,6 +392,36 @@ class ConnPool {
_enable_tls = x;
return *this;
}
+
+ Config &tls_cert_file(const std::string &x) {
+ _tls_cert_file = x;
+ return *this;
+ }
+
+ Config &tls_key_file(const std::string &x) {
+ _tls_key_file = x;
+ return *this;
+ }
+
+ Config &tls_cert(X509 *x) {
+ _tls_cert = x;
+ return *this;
+ }
+
+ Config &tls_key(PKey *x) {
+ _tls_key = x;
+ return *this;
+ }
+
+ Config &tls_skip_ca_check(bool *x) {
+ _tls_skip_ca_check = x;
+ return *this;
+ }
+
+ Config &tls_verify_callback(SSL_verify_cb x) {
+ _tls_verify_callback = x;
+ return *this;
+ }
};
ConnPool(const EventContext &ec, const Config &config):
@@ -403,9 +446,11 @@ class ConnPool {
tls_ctx->use_privkey(*config._tls_key);
else
tls_ctx->use_privkey_file(config._tls_key_file);
+ tls_ctx->set_verify(config._tls_skip_ca_check, config._tls_verify_callback);
if (!tls_ctx->check_privkey())
- throw SalticidaeError(SALTI_ERROR_TLS_GENERIC);
+ throw SalticidaeError(SALTI_ERROR_TLS_KEY_NOT_MATCH);
}
+ signal(SIGPIPE, SIG_IGN);
workers = new Worker[nworker];
user_tcall = new ThreadCall(ec);
disp_ec = workers[0].get_ec();
diff --git a/include/salticidae/crypto.h b/include/salticidae/crypto.h
index 1d20b22..bcfd9dc 100644
--- a/include/salticidae/crypto.h
+++ b/include/salticidae/crypto.h
@@ -128,6 +128,10 @@ static inline int _tls_pem_with_passwd(char *buf, int size, int, void *) {
return _size - 1;
}
+static int _skip_CA_check(int, X509_STORE_CTX *) {
+ return 1;
+}
+
class PKey {
EVP_PKEY *key;
friend class TLSContext;
@@ -271,6 +275,11 @@ class TLSContext {
throw SalticidaeError(SALTI_ERROR_TLS_LOAD_KEY);
}
+ void set_verify(bool skip_ca_check = true, SSL_verify_cb verify_callback = nullptr) {
+ SSL_CTX_set_verify(ctx,
+ SSL_VERIFY_PEER, skip_ca_check ? _skip_CA_check : verify_callback);
+ }
+
bool check_privkey() {
return SSL_CTX_check_private_key(ctx) > 0;
}
@@ -329,13 +338,9 @@ class TLS {
return SSL_get_error(ssl, ret);
}
- ~TLS() {
- if (ssl)
- {
- SSL_shutdown(ssl);
- SSL_free(ssl);
- }
- }
+ void shutdown() { SSL_shutdown(ssl); }
+
+ ~TLS() { if (ssl) SSL_free(ssl); }
};
}
diff --git a/include/salticidae/network.h b/include/salticidae/network.h
index e9fdae6..b703c35 100644
--- a/include/salticidae/network.h
+++ b/include/salticidae/network.h
@@ -996,7 +996,7 @@ void msgnetwork_terminate(msgnetwork_t *self, const msgnetwork_conn_t *conn);
typedef void (*msgnetwork_msg_callback_t)(const msg_t *, const msgnetwork_conn_t *, void *userdata);
void msgnetwork_reg_handler(msgnetwork_t *self, _opcode_t opcode, msgnetwork_msg_callback_t cb, void *userdata);
-typedef void (*msgnetwork_conn_callback_t)(const msgnetwork_conn_t *, bool connected, void *userdata);
+typedef bool (*msgnetwork_conn_callback_t)(const msgnetwork_conn_t *, bool connected, void *userdata);
void msgnetwork_reg_conn_handler(msgnetwork_t *self, msgnetwork_conn_callback_t cb, void *userdata);
diff --git a/include/salticidae/util.h b/include/salticidae/util.h
index dec498c..9a57ae8 100644
--- a/include/salticidae/util.h
+++ b/include/salticidae/util.h
@@ -88,6 +88,8 @@ enum SalticidaeErrorCode {
SALTI_ERROR_TLS_GENERIC,
SALTI_ERROR_TLS_X509,
SALTI_ERROR_TLS_KEY,
+ SALTI_ERROR_TLS_KEY_NOT_MATCH,
+ SALTI_ERROR_TLS_NO_PEER_CERT,
SALTI_ERROR_UNKNOWN
};