diff --git a/include/rtc/websocket.hpp b/include/rtc/websocket.hpp index 154a04bc3..6f76f7a09 100644 --- a/include/rtc/websocket.hpp +++ b/include/rtc/websocket.hpp @@ -39,6 +39,7 @@ class RTC_CPP_EXPORT WebSocket final : private CheshireCat, pub optional connectionTimeout; // zero to disable optional pingInterval; // zero to disable optional maxOutstandingPings; + optional caCertificatePemFile; }; WebSocket(); diff --git a/src/impl/verifiedtlstransport.cpp b/src/impl/verifiedtlstransport.cpp index 14973229d..d9606522c 100644 --- a/src/impl/verifiedtlstransport.cpp +++ b/src/impl/verifiedtlstransport.cpp @@ -13,9 +13,11 @@ namespace rtc::impl { +static const string PemBeginCertificateTag = "-----BEGIN CERTIFICATE-----"; + VerifiedTlsTransport::VerifiedTlsTransport( variant, shared_ptr> lower, string host, - certificate_ptr certificate, state_callback callback) + certificate_ptr certificate, state_callback callback, [[maybe_unused]] optional cacert) : TlsTransport(std::move(lower), std::move(host), std::move(certificate), std::move(callback)) { PLOG_DEBUG << "Setting up TLS certificate verification"; @@ -24,13 +26,36 @@ VerifiedTlsTransport::VerifiedTlsTransport( gnutls_session_set_verify_cert(mSession, mHost->c_str(), 0); #elif USE_MBEDTLS mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_REQUIRED); + mbedtls_x509_crt_init(&mCaCert); + try { + if (cacert) { + if (cacert->find(PemBeginCertificateTag) == string::npos) { + // *cacert is a file path + mbedtls::check(mbedtls_x509_crt_parse_file(&mCaCert, cacert->c_str())); + } else { + // *cacert is a PEM content + mbedtls::check(mbedtls_x509_crt_parse( + &mCaCert, reinterpret_cast(cacert->c_str()), + cacert->size())); + } + mbedtls_ssl_conf_ca_chain(&mConf, &mCaCert, NULL); + } + } catch (...) { + mbedtls_x509_crt_free(&mCaCert); + throw; + } #else SSL_set_verify(mSsl, SSL_VERIFY_PEER, NULL); SSL_set_verify_depth(mSsl, 4); #endif } -VerifiedTlsTransport::~VerifiedTlsTransport() { stop(); } +VerifiedTlsTransport::~VerifiedTlsTransport() { + stop(); +#if USE_MBEDTLS + mbedtls_x509_crt_free(&mCaCert); +#endif +} } // namespace rtc::impl diff --git a/src/impl/verifiedtlstransport.hpp b/src/impl/verifiedtlstransport.hpp index 352f2a049..0d38feba5 100644 --- a/src/impl/verifiedtlstransport.hpp +++ b/src/impl/verifiedtlstransport.hpp @@ -18,8 +18,14 @@ namespace rtc::impl { class VerifiedTlsTransport final : public TlsTransport { public: VerifiedTlsTransport(variant, shared_ptr> lower, - string host, certificate_ptr certificate, state_callback callback); + string host, certificate_ptr certificate, state_callback callback, + optional cacert); ~VerifiedTlsTransport(); + +private: +#if USE_MBEDTLS + mbedtls_x509_crt mCaCert; +#endif }; } // namespace rtc::impl diff --git a/src/impl/websocket.cpp b/src/impl/websocket.cpp index 1794a4c9b..77fe34bbe 100644 --- a/src/impl/websocket.cpp +++ b/src/impl/websocket.cpp @@ -358,7 +358,8 @@ shared_ptr WebSocket::initTlsTransport() { shared_ptr transport; if (verify) transport = std::make_shared(lower, mHostname.value(), - mCertificate, stateChangeCallback); + mCertificate, stateChangeCallback, + config.caCertificatePemFile); else transport = std::make_shared(lower, mHostname, mCertificate, stateChangeCallback);