From 823585c4db2ef6752d12f489c83edab577b86099 Mon Sep 17 00:00:00 2001 From: Determinant Date: Thu, 20 Jun 2019 23:52:56 -0400 Subject: finish test_msgnet_tls example --- include/salticidae/conn.h | 20 +++++++++++++------- include/salticidae/crypto.h | 21 ++++++++++++++++----- 2 files changed, 29 insertions(+), 12 deletions(-) (limited to 'include') diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h index 59d93fc..a3da96c 100644 --- a/include/salticidae/conn.h +++ b/include/salticidae/conn.h @@ -112,7 +112,9 @@ class ConnPool { void disp_terminate(); public: - Conn(): ready_send(false), send_data_func(nullptr), recv_data_func(nullptr) {} + Conn(): ready_send(false), + send_data_func(nullptr), recv_data_func(nullptr), + tls(nullptr), peer_cert(nullptr) {} Conn(const Conn &) = delete; Conn(Conn &&other) = delete; @@ -133,7 +135,7 @@ class ConnPool { operator std::string() const; const NetAddr &get_addr() const { return addr; } - const X509 &get_peer_cert() const { return *peer_cert; } + const X509 *get_peer_cert() const { return peer_cert.get(); } ConnMode get_mode() const { return mode; } ConnPool *get_pool() const { return cpool; } MPSCWriteBuffer &get_send_buffer() { return send_buffer; } @@ -190,12 +192,16 @@ class ConnPool { void update_conn(const conn_t &conn, bool connected) { user_tcall->async_call([this, conn, connected](ThreadCall::Handle &) { - if ((!conn_cb || - conn_cb(conn, connected)) && - enable_tls && connected) - conn->worker->get_tcall()->async_call([conn](ThreadCall::Handle &) { - conn->recv_data_func = Conn::_recv_data_tls; + bool ret = !conn_cb || conn_cb(conn, connected); + if (enable_tls && connected) + { + conn->worker->get_tcall()->async_call([conn, ret](ThreadCall::Handle &) { + if (ret) + conn->recv_data_func = Conn::_recv_data_tls; + else + conn->worker_terminate(); }); + } }); } diff --git a/include/salticidae/crypto.h b/include/salticidae/crypto.h index bcfd9dc..7eec030 100644 --- a/include/salticidae/crypto.h +++ b/include/salticidae/crypto.h @@ -168,8 +168,8 @@ class PKey { return PKey(key); } - bytearray_t get_pubkey_der() { - uint8_t *der; + bytearray_t get_pubkey_der() const { + uint8_t *der = nullptr; auto ret = i2d_PublicKey(key, &der); if (ret <= 0) throw SalticidaeError(SALTI_ERROR_TLS_KEY); @@ -179,8 +179,8 @@ class PKey { return std::move(res); } - bytearray_t get_privkey_der() { - uint8_t *der; + bytearray_t get_privkey_der() const { + uint8_t *der = nullptr; auto ret = i2d_PrivateKey(key, &der); if (ret <= 0) throw SalticidaeError(SALTI_ERROR_TLS_KEY); @@ -229,13 +229,24 @@ class X509 { return X509(x509); } - PKey get_pubkey() { + PKey get_pubkey() const { auto key = X509_get_pubkey(x509); if (key == nullptr) throw SalticidaeError(SALTI_ERROR_TLS_X509); return PKey(key); } + bytearray_t get_der() const { + uint8_t *der = nullptr; + auto ret = i2d_X509(x509, &der); + if (ret <= 0) + throw SalticidaeError(SALTI_ERROR_TLS_X509); + bytearray_t res(der, der + ret); + OPENSSL_cleanse(der, ret); + OPENSSL_free(der); + return std::move(res); + } + ~X509() { if (x509) X509_free(x509); } }; -- cgit v1.2.3