Older/ToolKit/Util/SSLBox.cpp
amass 9de3af15eb
All checks were successful
Deploy / PullDocker (push) Successful in 12s
Deploy / Build (push) Successful in 1m51s
add ZLMediaKit code for learning.
2024-09-28 23:55:00 +08:00

520 lines
14 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
* Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
*
* This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
*
* Use of this source code is governed by MIT license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "SSLBox.h"
#include "onceToken.h"
#include "SSLUtil.h"
#if defined(ENABLE_OPENSSL)
#include <openssl/ssl.h>
#include <openssl/rand.h>
#include <openssl/crypto.h>
#include <openssl/err.h>
#include <openssl/conf.h>
#include <openssl/bio.h>
#include <openssl/ossl_typ.h>
#endif //defined(ENABLE_OPENSSL)
#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
//openssl版本是否支持sni
#define SSL_ENABLE_SNI
#endif
using namespace std;
namespace toolkit {
static bool s_ignore_invalid_cer = true;
SSL_Initor &SSL_Initor::Instance() {
static SSL_Initor obj;
return obj;
}
void SSL_Initor::ignoreInvalidCertificate(bool ignore) {
s_ignore_invalid_cer = ignore;
}
SSL_Initor::SSL_Initor() {
#if defined(ENABLE_OPENSSL)
SSL_library_init();
SSL_load_error_strings();
OpenSSL_add_all_digests();
OpenSSL_add_all_ciphers();
OpenSSL_add_all_algorithms();
CRYPTO_set_locking_callback([](int mode, int n, const char *file, int line) {
static mutex *s_mutexes = new mutex[CRYPTO_num_locks()];
static onceToken token(nullptr, []() {
delete[] s_mutexes;
});
if (mode & CRYPTO_LOCK) {
s_mutexes[n].lock();
} else {
s_mutexes[n].unlock();
}
});
CRYPTO_set_id_callback([]() -> unsigned long {
#if !defined(_WIN32)
return (unsigned long) pthread_self();
#else
return (unsigned long) GetCurrentThreadId();
#endif
});
setContext("", SSLUtil::makeSSLContext(vector<shared_ptr<X509> >(), nullptr, false), false);
setContext("", SSLUtil::makeSSLContext(vector<shared_ptr<X509> >(), nullptr, true), true);
#endif //defined(ENABLE_OPENSSL)
}
SSL_Initor::~SSL_Initor() {
#if defined(ENABLE_OPENSSL)
EVP_cleanup();
ERR_free_strings();
ERR_clear_error();
#if OPENSSL_VERSION_NUMBER >= 0x10000000L && OPENSSL_VERSION_NUMBER < 0x10100000L
ERR_remove_thread_state(nullptr);
#elif OPENSSL_VERSION_NUMBER < 0x10000000L
ERR_remove_state(0);
#endif
CRYPTO_set_locking_callback(nullptr);
//sk_SSL_COMP_free(SSL_COMP_get_compression_methods());
CRYPTO_cleanup_all_ex_data();
CONF_modules_unload(1);
CONF_modules_free();
#endif //defined(ENABLE_OPENSSL)
}
bool SSL_Initor::loadCertificate(const string &pem_or_p12, bool server_mode, const string &password, bool is_file,
bool is_default) {
auto cers = SSLUtil::loadPublicKey(pem_or_p12, password, is_file);
auto key = SSLUtil::loadPrivateKey(pem_or_p12, password, is_file);
auto ssl_ctx = SSLUtil::makeSSLContext(cers, key, server_mode, true);
if (!ssl_ctx) {
return false;
}
for (auto &cer : cers) {
auto server_name = SSLUtil::getServerName(cer.get());
setContext(server_name, ssl_ctx, server_mode, is_default);
break;
}
return true;
}
int SSL_Initor::findCertificate(SSL *ssl, int *, void *arg) {
#if !defined(ENABLE_OPENSSL) || !defined(SSL_ENABLE_SNI)
return 0;
#else
if (!ssl) {
return SSL_TLSEXT_ERR_ALERT_FATAL;
}
SSL_CTX *ctx = nullptr;
static auto &ref = SSL_Initor::Instance();
const char *vhost = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
if (vhost && vhost[0] != '\0') {
//根据域名找到证书
ctx = ref.getSSLCtx(vhost, (bool) (arg)).get();
if (!ctx) {
//未找到对应的证书
WarnL << "Can not find any certificate of host: " << vhost
<< ", select default certificate of: " << ref._default_vhost[(bool) (arg)];
}
}
if (!ctx) {
//客户端未指定域名或者指定的证书不存在,那么选择一个默认的证书
ctx = ref.getSSLCtx("", (bool) (arg)).get();
}
if (!ctx) {
//未有任何有效的证书
WarnL << "Can not find any available certificate of host: " << (vhost ? vhost : "default host")
<< ", tls handshake failed";
return SSL_TLSEXT_ERR_ALERT_FATAL;
}
SSL_set_SSL_CTX(ssl, ctx);
return SSL_TLSEXT_ERR_OK;
#endif
}
bool SSL_Initor::setContext(const string &vhost, const shared_ptr<SSL_CTX> &ctx, bool server_mode, bool is_default) {
if (!ctx) {
return false;
}
setupCtx(ctx.get());
#if defined(ENABLE_OPENSSL)
if (vhost.empty()) {
_ctx_empty[server_mode] = ctx;
#ifdef SSL_ENABLE_SNI
if (server_mode) {
SSL_CTX_set_tlsext_servername_callback(ctx.get(), findCertificate);
SSL_CTX_set_tlsext_servername_arg(ctx.get(), (void *) server_mode);
}
#endif // SSL_ENABLE_SNI
} else {
_ctxs[server_mode][vhost] = ctx;
if (is_default) {
_default_vhost[server_mode] = vhost;
}
if (vhost.find("*.") == 0) {
//通配符证书
_ctxs_wildcards[server_mode][vhost.substr(1)] = ctx;
}
DebugL << "Add certificate of: " << vhost;
}
return true;
#else
WarnL << "ENABLE_OPENSSL disabled, you can not use any features based on openssl";
return false;
#endif //defined(ENABLE_OPENSSL)
}
void SSL_Initor::setupCtx(SSL_CTX *ctx) {
#if defined(ENABLE_OPENSSL)
//加载默认信任证书
SSLUtil::loadDefaultCAs(ctx);
SSL_CTX_set_cipher_list(ctx, "ALL:!ADH:!LOW:!EXP:!MD5:!3DES:!DES:!IDEA:!RC4:@STRENGTH");
SSL_CTX_set_verify_depth(ctx, 9);
SSL_CTX_set_mode(ctx, SSL_MODE_AUTO_RETRY);
SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_OFF);
SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, [](int ok, X509_STORE_CTX *pStore) {
if (!ok) {
int depth = X509_STORE_CTX_get_error_depth(pStore);
int err = X509_STORE_CTX_get_error(pStore);
WarnL << "SSL_CTX_set_verify callback, depth: " << depth << " ,err: " << X509_verify_cert_error_string(err);
}
return s_ignore_invalid_cer ? 1 : ok;
});
#ifndef SSL_OP_NO_COMPRESSION
#define SSL_OP_NO_COMPRESSION 0
#endif
#ifndef SSL_MODE_RELEASE_BUFFERS /* OpenSSL >= 1.0.0 */
#define SSL_MODE_RELEASE_BUFFERS 0
#endif
unsigned long ssloptions = SSL_OP_ALL
| SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION
| SSL_OP_NO_COMPRESSION;
#ifdef SSL_OP_NO_RENEGOTIATION /* openssl 1.1.0 */
ssloptions |= SSL_OP_NO_RENEGOTIATION;
#endif
#ifdef SSL_OP_NO_SSLv2
ssloptions |= SSL_OP_NO_SSLv2;
#endif
#ifdef SSL_OP_NO_SSLv3
ssloptions |= SSL_OP_NO_SSLv3;
#endif
#ifdef SSL_OP_NO_TLSv1
ssloptions |= SSL_OP_NO_TLSv1;
#endif
#ifdef SSL_OP_NO_TLSv1_1 /* openssl 1.0.1 */
ssloptions |= SSL_OP_NO_TLSv1_1;
#endif
SSL_CTX_set_options(ctx, ssloptions);
#endif //defined(ENABLE_OPENSSL)
}
shared_ptr<SSL> SSL_Initor::makeSSL(bool server_mode) {
#if defined(ENABLE_OPENSSL)
#ifdef SSL_ENABLE_SNI
//openssl 版本支持SNI
return SSLUtil::makeSSL(_ctx_empty[server_mode].get());
#else
//openssl 版本不支持SNI选择默认证书
return SSLUtil::makeSSL(getSSLCtx("",server_mode).get());
#endif//SSL_CTRL_SET_TLSEXT_HOSTNAME
#else
return nullptr;
#endif //defined(ENABLE_OPENSSL)
}
bool SSL_Initor::trustCertificate(X509 *cer, bool server_mode) {
return SSLUtil::trustCertificate(_ctx_empty[server_mode].get(), cer);
}
bool SSL_Initor::trustCertificate(const string &pem_p12_cer, bool server_mode, const string &password, bool is_file) {
auto cers = SSLUtil::loadPublicKey(pem_p12_cer, password, is_file);
for (auto &cer : cers) {
trustCertificate(cer.get(), server_mode);
}
return true;
}
std::shared_ptr<SSL_CTX> SSL_Initor::getSSLCtx(const string &vhost, bool server_mode) {
auto ret = getSSLCtx_l(vhost, server_mode);
if (ret) {
return ret;
}
return getSSLCtxWildcards(vhost, server_mode);
}
std::shared_ptr<SSL_CTX> SSL_Initor::getSSLCtxWildcards(const string &vhost, bool server_mode) {
for (auto &pr : _ctxs_wildcards[server_mode]) {
auto pos = strcasestr(vhost.data(), pr.first.data());
if (pos && pos + pr.first.size() == &vhost.back() + 1) {
return pr.second;
}
}
return nullptr;
}
std::shared_ptr<SSL_CTX> SSL_Initor::getSSLCtx_l(const string &vhost_in, bool server_mode) {
auto vhost = vhost_in;
if (vhost.empty()) {
if (!_default_vhost[server_mode].empty()) {
vhost = _default_vhost[server_mode];
} else {
//没默认主机,选择空主机
if (server_mode) {
WarnL << "Server with ssl must have certification and key";
}
return _ctx_empty[server_mode];
}
}
//根据主机名查找证书
auto it = _ctxs[server_mode].find(vhost);
if (it == _ctxs[server_mode].end()) {
return nullptr;
}
return it->second;
}
string SSL_Initor::defaultVhost(bool server_mode) {
return _default_vhost[server_mode];
}
////////////////////////////////////////////////////SSL_Box////////////////////////////////////////////////////////////
SSL_Box::~SSL_Box() {}
SSL_Box::SSL_Box(bool server_mode, bool enable, int buff_size) {
#if defined(ENABLE_OPENSSL)
_read_bio = BIO_new(BIO_s_mem());
_server_mode = server_mode;
if (enable) {
_ssl = SSL_Initor::Instance().makeSSL(server_mode);
}
if (_ssl) {
_write_bio = BIO_new(BIO_s_mem());
SSL_set_bio(_ssl.get(), _read_bio, _write_bio);
_server_mode ? SSL_set_accept_state(_ssl.get()) : SSL_set_connect_state(_ssl.get());
} else {
WarnL << "makeSSL failed";
}
_send_handshake = false;
_buff_size = buff_size;
#endif //defined(ENABLE_OPENSSL)
}
void SSL_Box::shutdown() {
#if defined(ENABLE_OPENSSL)
_buffer_send.clear();
int ret = SSL_shutdown(_ssl.get());
if (ret != 1) {
ErrorL << "SSL_shutdown failed: " << SSLUtil::getLastError();
} else {
flush();
}
#endif //defined(ENABLE_OPENSSL)
}
void SSL_Box::onRecv(const Buffer::Ptr &buffer) {
if (!buffer->size()) {
return;
}
if (!_ssl) {
if (_on_dec) {
_on_dec(buffer);
}
return;
}
#if defined(ENABLE_OPENSSL)
uint32_t offset = 0;
while (offset < buffer->size()) {
auto nwrite = BIO_write(_read_bio, buffer->data() + offset, buffer->size() - offset);
if (nwrite > 0) {
//部分或全部写入bio完毕
offset += nwrite;
flush();
continue;
}
//nwrite <= 0,出现异常
ErrorL << "Ssl error on BIO_write: " << SSLUtil::getLastError();
shutdown();
break;
}
#endif //defined(ENABLE_OPENSSL)
}
void SSL_Box::onSend(Buffer::Ptr buffer) {
if (!buffer->size()) {
return;
}
if (!_ssl) {
if (_on_enc) {
_on_enc(buffer);
}
return;
}
#if defined(ENABLE_OPENSSL)
if (!_server_mode && !_send_handshake) {
_send_handshake = true;
SSL_do_handshake(_ssl.get());
}
_buffer_send.emplace_back(std::move(buffer));
flush();
#endif //defined(ENABLE_OPENSSL)
}
void SSL_Box::setOnDecData(const function<void(const Buffer::Ptr &)> &cb) {
_on_dec = cb;
}
void SSL_Box::setOnEncData(const function<void(const Buffer::Ptr &)> &cb) {
_on_enc = cb;
}
void SSL_Box::flushWriteBio() {
#if defined(ENABLE_OPENSSL)
int total = 0;
int nread = 0;
auto buffer_bio = _buffer_pool.obtain2();
buffer_bio->setCapacity(_buff_size);
auto buf_size = buffer_bio->getCapacity() - 1;
do {
nread = BIO_read(_write_bio, buffer_bio->data() + total, buf_size - total);
if (nread > 0) {
total += nread;
}
} while (nread > 0 && buf_size - total > 0);
if (!total) {
//未有数据
return;
}
//触发此次回调
buffer_bio->data()[total] = '\0';
buffer_bio->setSize(total);
if (_on_enc) {
_on_enc(buffer_bio);
}
if (nread > 0) {
//还有剩余数据,读取剩余数据
flushWriteBio();
}
#endif //defined(ENABLE_OPENSSL)
}
void SSL_Box::flushReadBio() {
#if defined(ENABLE_OPENSSL)
int total = 0;
int nread = 0;
auto buffer_bio = _buffer_pool.obtain2();
buffer_bio->setCapacity(_buff_size);
auto buf_size = buffer_bio->getCapacity() - 1;
do {
nread = SSL_read(_ssl.get(), buffer_bio->data() + total, buf_size - total);
if (nread > 0) {
total += nread;
}
} while (nread > 0 && buf_size - total > 0);
if (!total) {
//未有数据
return;
}
//触发此次回调
buffer_bio->data()[total] = '\0';
buffer_bio->setSize(total);
if (_on_dec) {
_on_dec(buffer_bio);
}
if (nread > 0) {
//还有剩余数据,读取剩余数据
flushReadBio();
}
#endif //defined(ENABLE_OPENSSL)
}
void SSL_Box::flush() {
#if defined(ENABLE_OPENSSL)
if (_is_flush) {
return;
}
onceToken token([&] {
_is_flush = true;
}, [&]() {
_is_flush = false;
});
flushReadBio();
if (!SSL_is_init_finished(_ssl.get()) || _buffer_send.empty()) {
//ssl未握手结束或没有需要发送的数据
flushWriteBio();
return;
}
//加密数据并发送
while (!_buffer_send.empty()) {
auto &front = _buffer_send.front();
uint32_t offset = 0;
while (offset < front->size()) {
auto nwrite = SSL_write(_ssl.get(), front->data() + offset, front->size() - offset);
if (nwrite > 0) {
//部分或全部写入完毕
offset += nwrite;
flushWriteBio();
continue;
}
//nwrite <= 0,出现异常
break;
}
if (offset != front->size()) {
//这个包未消费完毕,出现了异常,清空数据并断开ssl
ErrorL << "Ssl error on SSL_write: " << SSLUtil::getLastError();
shutdown();
break;
}
//这个包消费完毕,开始消费下一个包
_buffer_send.pop_front();
}
#endif //defined(ENABLE_OPENSSL)
}
bool SSL_Box::setHost(const char *host) {
if (!_ssl) {
return false;
}
#ifdef SSL_ENABLE_SNI
return 0 != SSL_set_tlsext_host_name(_ssl.get(), host);
#else
return false;
#endif//SSL_ENABLE_SNI
}
} /* namespace toolkit */