diff --git a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OCSP.Chain.cs b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OCSP.Chain.cs new file mode 100644 index 0000000000000..e2839011a437a --- /dev/null +++ b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OCSP.Chain.cs @@ -0,0 +1,94 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using Microsoft.Win32.SafeHandles; + +internal static partial class Interop +{ + internal static partial class Crypto + { + [LibraryImport(Libraries.CryptoNative, StringMarshalling = StringMarshalling.Utf8)] + private static partial int CryptoNative_X509ChainGetCachedOcspStatus( + SafeX509StoreCtxHandle ctx, + string cachePath, + int chainDepth); + + internal static X509VerifyStatusCode X509ChainGetCachedOcspStatus(SafeX509StoreCtxHandle ctx, string cachePath, int chainDepth) + { + X509VerifyStatusCode response = (X509VerifyStatusCode)CryptoNative_X509ChainGetCachedOcspStatus(ctx, cachePath, chainDepth); + + if (response.Code < 0) + { + Debug.Fail($"Unexpected response from X509ChainGetCachedOcspSuccess: {response}"); + throw new CryptographicException(); + } + + return response; + } + + [LibraryImport(Libraries.CryptoNative)] + private static partial int CryptoNative_X509ChainHasStapledOcsp(SafeX509StoreCtxHandle storeCtx); + + internal static bool X509ChainHasStapledOcsp(SafeX509StoreCtxHandle storeCtx) + { + int resp = CryptoNative_X509ChainHasStapledOcsp(storeCtx); + + if (resp == 1) + { + return true; + } + + Debug.Assert(resp == 0, $"Unexpected response from X509ChainHasStapledOcsp: {resp}"); + return false; + } + + [LibraryImport(Libraries.CryptoNative, StringMarshalling = StringMarshalling.Utf8)] + private static partial int CryptoNative_X509ChainVerifyOcsp( + SafeX509StoreCtxHandle ctx, + SafeOcspRequestHandle req, + SafeOcspResponseHandle resp, + string cachePath, + int chainDepth); + + internal static X509VerifyStatusCode X509ChainVerifyOcsp( + SafeX509StoreCtxHandle ctx, + SafeOcspRequestHandle req, + SafeOcspResponseHandle resp, + string cachePath, + int chainDepth) + { + X509VerifyStatusCode response = (X509VerifyStatusCode)CryptoNative_X509ChainVerifyOcsp(ctx, req, resp, cachePath, chainDepth); + + if (response.Code < 0) + { + Debug.Fail($"Unexpected response from X509ChainGetCachedOcspSuccess: {response}"); + throw new CryptographicException(); + } + + return response; + } + + [LibraryImport(Libraries.CryptoNative)] + private static partial SafeOcspRequestHandle CryptoNative_X509ChainBuildOcspRequest( + SafeX509StoreCtxHandle storeCtx, + int chainDepth); + + internal static SafeOcspRequestHandle X509ChainBuildOcspRequest(SafeX509StoreCtxHandle storeCtx, int chainDepth) + { + SafeOcspRequestHandle req = CryptoNative_X509ChainBuildOcspRequest(storeCtx, chainDepth); + + if (req.IsInvalid) + { + req.Dispose(); + throw CreateOpenSslCryptographicException(); + } + + return req; + } + } +} diff --git a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OCSP.cs b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OCSP.cs index 1a0833b1011ab..59736b39f47e8 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OCSP.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OCSP.cs @@ -4,7 +4,6 @@ using System; using System.Diagnostics; using System.Runtime.InteropServices; -using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using Microsoft.Win32.SafeHandles; @@ -21,97 +20,73 @@ internal static partial class Crypto [LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_EncodeOcspRequest")] internal static partial int EncodeOcspRequest(SafeOcspRequestHandle req, byte[] buf); - [LibraryImport(Libraries.CryptoNative)] - private static partial SafeOcspResponseHandle CryptoNative_DecodeOcspResponse(ref byte buf, int len); - - internal static SafeOcspResponseHandle DecodeOcspResponse(ReadOnlySpan buf) - { - return CryptoNative_DecodeOcspResponse( - ref MemoryMarshal.GetReference(buf), - buf.Length); - } - - [LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_OcspResponseDestroy")] - internal static partial void OcspResponseDestroy(IntPtr ocspReq); - - [LibraryImport(Libraries.CryptoNative, StringMarshalling = StringMarshalling.Utf8)] - private static partial int CryptoNative_X509ChainGetCachedOcspStatus( - SafeX509StoreCtxHandle ctx, - string cachePath, - int chainDepth); + [LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_X509BuildOcspRequest")] + internal static partial SafeOcspRequestHandle X509BuildOcspRequest(IntPtr subject, IntPtr issuer); - internal static X509VerifyStatusCode X509ChainGetCachedOcspStatus(SafeX509StoreCtxHandle ctx, string cachePath, int chainDepth) + [LibraryImport(Libraries.CryptoNative)] + private static unsafe partial int CryptoNative_X509DecodeOcspToExpiration( + byte* buf, + int len, + SafeOcspRequestHandle req, + IntPtr subject, + IntPtr issuer, + ref long expiration); + + internal static unsafe bool X509DecodeOcspToExpiration( + ReadOnlySpan buf, + SafeOcspRequestHandle request, + IntPtr x509Subject, + IntPtr x509Issuer, + out DateTimeOffset expiration) { - X509VerifyStatusCode response = (X509VerifyStatusCode)CryptoNative_X509ChainGetCachedOcspStatus(ctx, cachePath, chainDepth); + long timeT = 0; + int ret; - if (response.Code < 0) + fixed (byte* pBuf = buf) { - Debug.Fail($"Unexpected response from X509ChainGetCachedOcspSuccess: {response}"); - throw new CryptographicException(); + ret = CryptoNative_X509DecodeOcspToExpiration( + pBuf, + buf.Length, + request, + x509Subject, + x509Issuer, + ref timeT); } - return response; - } - - [LibraryImport(Libraries.CryptoNative)] - private static partial int CryptoNative_X509ChainHasStapledOcsp(SafeX509StoreCtxHandle storeCtx); - - internal static bool X509ChainHasStapledOcsp(SafeX509StoreCtxHandle storeCtx) - { - int resp = CryptoNative_X509ChainHasStapledOcsp(storeCtx); - - if (resp == 1) + if (ret == 1) { + if (timeT != 0) + { + expiration = DateTimeOffset.FromUnixTimeSeconds(timeT); + } + else + { + // Something went wrong during the determination of when the response + // should not be used any longer. + // Half an hour sounds fair? + expiration = DateTimeOffset.UtcNow.AddMinutes(30); + } + return true; } - Debug.Assert(resp == 0, $"Unexpected response from X509ChainHasStapledOcsp: {resp}"); + Debug.Assert(ret == 0, $"Unexpected response from X509DecodeOcspToExpiration: {ret}"); + expiration = DateTimeOffset.MinValue; return false; } - [LibraryImport(Libraries.CryptoNative, StringMarshalling = StringMarshalling.Utf8)] - private static partial int CryptoNative_X509ChainVerifyOcsp( - SafeX509StoreCtxHandle ctx, - SafeOcspRequestHandle req, - SafeOcspResponseHandle resp, - string cachePath, - int chainDepth); - - internal static X509VerifyStatusCode X509ChainVerifyOcsp( - SafeX509StoreCtxHandle ctx, - SafeOcspRequestHandle req, - SafeOcspResponseHandle resp, - string cachePath, - int chainDepth) - { - X509VerifyStatusCode response = (X509VerifyStatusCode)CryptoNative_X509ChainVerifyOcsp(ctx, req, resp, cachePath, chainDepth); - - if (response.Code < 0) - { - Debug.Fail($"Unexpected response from X509ChainGetCachedOcspSuccess: {response}"); - throw new CryptographicException(); - } - - return response; - } - [LibraryImport(Libraries.CryptoNative)] - private static partial SafeOcspRequestHandle CryptoNative_X509ChainBuildOcspRequest( - SafeX509StoreCtxHandle storeCtx, - int chainDepth); + private static partial SafeOcspResponseHandle CryptoNative_DecodeOcspResponse(ref byte buf, int len); - internal static SafeOcspRequestHandle X509ChainBuildOcspRequest(SafeX509StoreCtxHandle storeCtx, int chainDepth) + internal static SafeOcspResponseHandle DecodeOcspResponse(ReadOnlySpan buf) { - SafeOcspRequestHandle req = CryptoNative_X509ChainBuildOcspRequest(storeCtx, chainDepth); - - if (req.IsInvalid) - { - req.Dispose(); - throw CreateOpenSslCryptographicException(); - } - - return req; + return CryptoNative_DecodeOcspResponse( + ref MemoryMarshal.GetReference(buf), + buf.Length); } + + [LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_OcspResponseDestroy")] + internal static partial void OcspResponseDestroy(IntPtr ocspReq); } } diff --git a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs index a3e872fe04962..a72275a75900f 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs @@ -234,11 +234,19 @@ internal static unsafe SafeSslContextHandle AllocateSslContext(SafeFreeSslCreden SetSslCertificate(sslCtx, certHandle!, certKeyHandle!); } - if (sslAuthenticationOptions.CertificateContext != null && sslAuthenticationOptions.CertificateContext.IntermediateCertificates.Length > 0) + if (sslAuthenticationOptions.CertificateContext != null) { - if (!Ssl.AddExtraChainCertificates(sslCtx, sslAuthenticationOptions.CertificateContext.IntermediateCertificates)) + if (sslAuthenticationOptions.CertificateContext.IntermediateCertificates.Length > 0) { - throw CreateSslException(SR.net_ssl_use_cert_failed); + if (!Ssl.AddExtraChainCertificates(sslCtx, sslAuthenticationOptions.CertificateContext.IntermediateCertificates)) + { + throw CreateSslException(SR.net_ssl_use_cert_failed); + } + } + + if (sslAuthenticationOptions.CertificateContext.OcspStaplingAvailable) + { + Ssl.SslCtxSetDefaultOcspCallback(sslCtx); } } } @@ -422,27 +430,37 @@ internal static SafeSslHandle AllocateSslHandle(SafeFreeSslCredentials credentia Ssl.SslSetVerifyPeer(sslHandle); } - if (sslAuthenticationOptions.CertificateContext?.Trust?._sendTrustInHandshake == true) + if (sslAuthenticationOptions.CertificateContext != null) { - SslCertificateTrust trust = sslAuthenticationOptions.CertificateContext!.Trust!; - X509Certificate2Collection certList = (trust._trustList ?? trust._store!.Certificates); - - Debug.Assert(certList != null, "certList != null"); - Span handles = certList.Count <= 256 - ? stackalloc IntPtr[256] - : new IntPtr[certList.Count]; - - for (int i = 0; i < certList.Count; i++) + if (sslAuthenticationOptions.CertificateContext.Trust?._sendTrustInHandshake == true) { - handles[i] = certList[i].Handle; + SslCertificateTrust trust = sslAuthenticationOptions.CertificateContext!.Trust!; + X509Certificate2Collection certList = (trust._trustList ?? trust._store!.Certificates); + + Debug.Assert(certList != null, "certList != null"); + Span handles = certList.Count <= 256 ? + stackalloc IntPtr[256] : + new IntPtr[certList.Count]; + + for (int i = 0; i < certList.Count; i++) + { + handles[i] = certList[i].Handle; + } + + if (!Ssl.SslAddClientCAs(sslHandle, handles.Slice(0, certList.Count))) + { + // The method can fail only when the number of cert names exceeds the maximum capacity + // supported by STACK_OF(X509_NAME) structure, which should not happen under normal + // operation. + Debug.Fail("Failed to add issuer to trusted CA list."); + } } - if (!Ssl.SslAddClientCAs(sslHandle, handles.Slice(0, certList.Count))) + byte[]? ocspResponse = sslAuthenticationOptions.CertificateContext.GetOcspResponseNoWaiting(); + + if (ocspResponse != null) { - // The method can fail only when the number of cert names exceeds the maximum capacity - // supported by STACK_OF(X509_NAME) structure, which should not happen under normal - // operation. - Debug.Fail("Failed to add issuer to trusted CA list."); + Ssl.SslStapleOcsp(sslHandle, ocspResponse); } } } diff --git a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs index 3a4cec91d5746..1779a5087e442 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs @@ -248,6 +248,19 @@ internal static unsafe bool SslAddClientCAs(SafeSslHandle ssl, Span x509 } } + [LibraryImport(Libraries.CryptoNative)] + private static unsafe partial void CryptoNative_SslStapleOcsp(SafeSslHandle ssl, byte* buf, int len); + + internal static unsafe void SslStapleOcsp(SafeSslHandle ssl, ReadOnlySpan stapledResponse) + { + Debug.Assert(stapledResponse.Length > 0); + + fixed (byte* ptr = stapledResponse) + { + CryptoNative_SslStapleOcsp(ssl, ptr, stapledResponse.Length); + } + } + internal static bool AddExtraChainCertificates(SafeSslHandle ssl, X509Certificate2[] chain) { // send pre-computed list of intermediates. diff --git a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.SslCtxOptions.cs b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.SslCtxOptions.cs index f3412d69cc431..c2fe0c28811fe 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.SslCtxOptions.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.SslCtxOptions.cs @@ -34,5 +34,8 @@ internal static partial class Ssl [LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslCtxSetEncryptionPolicy")] [return: MarshalAs(UnmanagedType.Bool)] internal static partial bool SetEncryptionPolicy(SafeSslContextHandle ctx, EncryptionPolicy policy); + + [LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslCtxSetDefaultOcspCallback")] + internal static partial void SslCtxSetDefaultOcspCallback(SafeSslContextHandle ctx); } } diff --git a/src/libraries/Common/src/System/Net/Http/X509ResourceClient.cs b/src/libraries/Common/src/System/Net/Http/X509ResourceClient.cs new file mode 100644 index 0000000000000..c2e36bae468eb --- /dev/null +++ b/src/libraries/Common/src/System/Net/Http/X509ResourceClient.cs @@ -0,0 +1,296 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.IO; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.Http +{ + internal static partial class X509ResourceClient + { + private static readonly Func>? s_downloadBytes = CreateDownloadBytesFunc(); + + static partial void ReportNoClient(); + static partial void ReportNegativeTimeout(); + static partial void ReportDownloadStart(long totalMillis, string uri); + static partial void ReportDownloadStop(int bytesDownloaded); + static partial void ReportRedirectsExceeded(); + static partial void ReportRedirected(Uri newUri); + static partial void ReportRedirectNotFollowed(Uri redirectUri); + + internal static byte[]? DownloadAsset(string uri, TimeSpan downloadTimeout) + { + ValueTask task = DownloadAssetCore(uri, downloadTimeout, async: false); + Debug.Assert(task.IsCompletedSuccessfully); + return task.Result; + } + + internal static Task DownloadAssetAsync(string uri, TimeSpan downloadTimeout) + { + ValueTask task = DownloadAssetCore(uri, downloadTimeout, async: true); + return task.AsTask(); + } + + private static async ValueTask DownloadAssetCore(string uri, TimeSpan downloadTimeout, bool async) + { + if (s_downloadBytes is null) + { + ReportNoClient(); + + return null; + } + + if (downloadTimeout <= TimeSpan.Zero) + { + ReportNegativeTimeout(); + + return null; + } + + long totalMillis = (long)downloadTimeout.TotalMilliseconds; + + ReportDownloadStart(totalMillis, uri); + + CancellationTokenSource? cts = totalMillis > int.MaxValue ? null : new CancellationTokenSource((int)totalMillis); + byte[]? ret = null; + + try + { + ret = await s_downloadBytes(uri, cts?.Token ?? default, async).ConfigureAwait(false); + return ret; + } + catch { } + finally + { + cts?.Dispose(); + + ReportDownloadStop(ret?.Length ?? 0); + } + + return null; + } + + private static Func>? CreateDownloadBytesFunc() + { + try + { + // Use reflection to access System.Net.Http: + // Since System.Net.Http.dll explicitly depends on System.Security.Cryptography.X509Certificates.dll, + // the latter can't in turn have an explicit dependency on the former. + + // Get the relevant types needed. + Type? socketsHttpHandlerType = Type.GetType("System.Net.Http.SocketsHttpHandler, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false); + Type? httpMessageHandlerType = Type.GetType("System.Net.Http.HttpMessageHandler, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false); + Type? httpClientType = Type.GetType("System.Net.Http.HttpClient, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false); + Type? httpRequestMessageType = Type.GetType("System.Net.Http.HttpRequestMessage, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false); + Type? httpResponseMessageType = Type.GetType("System.Net.Http.HttpResponseMessage, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false); + Type? httpResponseHeadersType = Type.GetType("System.Net.Http.Headers.HttpResponseHeaders, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false); + Type? httpContentType = Type.GetType("System.Net.Http.HttpContent, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false); + + if (socketsHttpHandlerType == null || httpMessageHandlerType == null || httpClientType == null || httpRequestMessageType == null || + httpResponseMessageType == null || httpResponseHeadersType == null || httpContentType == null) + { + Debug.Fail("Unable to load required type."); + return null; + } + + Type taskOfHttpResponseMessageType = typeof(Task<>).MakeGenericType(httpResponseMessageType); + + // Get the methods on those types. + ConstructorInfo? socketsHttpHandlerCtor = socketsHttpHandlerType.GetConstructor(Type.EmptyTypes); + PropertyInfo? pooledConnectionIdleTimeoutProp = socketsHttpHandlerType.GetProperty("PooledConnectionIdleTimeout"); + PropertyInfo? allowAutoRedirectProp = socketsHttpHandlerType.GetProperty("AllowAutoRedirect"); + ConstructorInfo? httpClientCtor = httpClientType.GetConstructor(new Type[] { httpMessageHandlerType }); + PropertyInfo? requestUriProp = httpRequestMessageType.GetProperty("RequestUri"); + ConstructorInfo? httpRequestMessageCtor = httpRequestMessageType.GetConstructor(Type.EmptyTypes); + MethodInfo? sendMethod = httpClientType.GetMethod("Send", new Type[] { httpRequestMessageType, typeof(CancellationToken) }); + MethodInfo? sendAsyncMethod = httpClientType.GetMethod("SendAsync", new Type[] { httpRequestMessageType, typeof(CancellationToken) }); + PropertyInfo? responseContentProp = httpResponseMessageType.GetProperty("Content"); + PropertyInfo? responseStatusCodeProp = httpResponseMessageType.GetProperty("StatusCode"); + PropertyInfo? responseHeadersProp = httpResponseMessageType.GetProperty("Headers"); + PropertyInfo? responseHeadersLocationProp = httpResponseHeadersType.GetProperty("Location"); + MethodInfo? readAsStreamMethod = httpContentType.GetMethod("ReadAsStream", Type.EmptyTypes); + PropertyInfo? taskOfHttpResponseMessageResultProp = taskOfHttpResponseMessageType.GetProperty("Result"); + + if (socketsHttpHandlerCtor == null || pooledConnectionIdleTimeoutProp == null || + allowAutoRedirectProp == null || httpClientCtor == null || + requestUriProp == null || httpRequestMessageCtor == null || + sendMethod == null || sendAsyncMethod == null || + responseContentProp == null || responseStatusCodeProp == null || + responseHeadersProp == null || responseHeadersLocationProp == null || + readAsStreamMethod == null || taskOfHttpResponseMessageResultProp == null) + { + Debug.Fail("Unable to load required members."); + return null; + } + + // Only keep idle connections around briefly, as a compromise between resource leakage and port exhaustion. + const int PooledConnectionIdleTimeoutSeconds = 15; + const int MaxRedirections = 10; + + // Equivalent of: + // var socketsHttpHandler = new SocketsHttpHandler() { + // PooledConnectionIdleTimeout = TimeSpan.FromSeconds(PooledConnectionIdleTimeoutSeconds), + // AllowAutoRedirect = false + // }; + // var httpClient = new HttpClient(socketsHttpHandler); + // Note: using a ConstructorInfo instead of Activator.CreateInstance, so the ILLinker can see the usage through the lambda method. + object? socketsHttpHandler = socketsHttpHandlerCtor.Invoke(null); + pooledConnectionIdleTimeoutProp.SetValue(socketsHttpHandler, TimeSpan.FromSeconds(PooledConnectionIdleTimeoutSeconds)); + allowAutoRedirectProp.SetValue(socketsHttpHandler, false); + object? httpClient = httpClientCtor.Invoke(new object?[] { socketsHttpHandler }); + + return async (string uriString, CancellationToken cancellationToken, bool async) => + { + Uri uri = new Uri(uriString); + + if (!IsAllowedScheme(uri.Scheme)) + { + return null; + } + + // Equivalent of: + // HttpRequestMessage requestMessage = new HttpRequestMessage() { RequestUri = new Uri(uri) }; + // HttpResponseMessage responseMessage = httpClient.Send(requestMessage, cancellationToken); + // Note: using a ConstructorInfo instead of Activator.CreateInstance, so the ILLinker can see the usage through the lambda method. + object requestMessage = httpRequestMessageCtor.Invoke(null); + requestUriProp.SetValue(requestMessage, uri); + object responseMessage; + + if (async) + { + Task sendTask = (Task)sendAsyncMethod.Invoke(httpClient, new object[] { requestMessage, cancellationToken })!; + await sendTask.ConfigureAwait(false); + responseMessage = taskOfHttpResponseMessageResultProp.GetValue(sendTask)!; + } + else + { + responseMessage = sendMethod.Invoke(httpClient, new object[] { requestMessage, cancellationToken })!; + } + + int redirections = 0; + Uri? redirectUri; + bool hasRedirect; + while (true) + { + int statusCode = (int)responseStatusCodeProp.GetValue(responseMessage)!; + object responseHeaders = responseHeadersProp.GetValue(responseMessage)!; + Uri? location = (Uri?)responseHeadersLocationProp.GetValue(responseHeaders); + redirectUri = GetUriForRedirect((Uri)requestUriProp.GetValue(requestMessage)!, statusCode, location, out hasRedirect); + if (redirectUri == null) + { + break; + } + + ((IDisposable)responseMessage).Dispose(); + + redirections++; + if (redirections > MaxRedirections) + { + ReportRedirectsExceeded(); + + return null; + } + + ReportRedirected(redirectUri); + + // Equivalent of: + // requestMessage = new HttpRequestMessage() { RequestUri = redirectUri }; + // requestMessage.RequestUri = redirectUri; + // responseMessage = httpClient.Send(requestMessage, cancellationToken); + requestMessage = httpRequestMessageCtor.Invoke(null); + requestUriProp.SetValue(requestMessage, redirectUri); + + if (async) + { + Task sendTask = (Task)sendAsyncMethod.Invoke(httpClient, new object[] { requestMessage, cancellationToken })!; + await sendTask.ConfigureAwait(false); + responseMessage = taskOfHttpResponseMessageResultProp.GetValue(sendTask)!; + } + else + { + responseMessage = sendMethod.Invoke(httpClient, new object[] { requestMessage, cancellationToken })!; + } + } + + if (hasRedirect && redirectUri == null) + { + return null; + } + + // Equivalent of: + // using Stream responseStream = resp.Content.ReadAsStream(); + object content = responseContentProp.GetValue(responseMessage)!; + using Stream responseStream = (Stream)readAsStreamMethod.Invoke(content, null)!; + + var result = new MemoryStream(); + responseStream.CopyTo(result); + ((IDisposable)responseMessage).Dispose(); + return result.ToArray(); + }; + } + catch + { + // We shouldn't have any exceptions, but if we do, ignore them all. + return null; + } + } + + private static Uri? GetUriForRedirect(Uri requestUri, int statusCode, Uri? location, out bool hasRedirect) + { + if (!IsRedirectStatusCode(statusCode)) + { + hasRedirect = false; + return null; + } + + hasRedirect = true; + + if (location == null) + { + return null; + } + + // Ensure the redirect location is an absolute URI. + if (!location.IsAbsoluteUri) + { + location = new Uri(requestUri, location); + } + + // Per https://tools.ietf.org/html/rfc7231#section-7.1.2, a redirect location without a + // fragment should inherit the fragment from the original URI. + string requestFragment = requestUri.Fragment; + if (!string.IsNullOrEmpty(requestFragment)) + { + string redirectFragment = location.Fragment; + if (string.IsNullOrEmpty(redirectFragment)) + { + location = new UriBuilder(location) { Fragment = requestFragment }.Uri; + } + } + + if (!IsAllowedScheme(location.Scheme)) + { + ReportRedirectNotFollowed(location); + + return null; + } + + return location; + } + + private static bool IsRedirectStatusCode(int statusCode) + { + // MultipleChoices (300), Moved (301), Found (302), SeeOther (303), TemporaryRedirect (307), PermanentRedirect (308) + return (statusCode >= 300 && statusCode <= 303) || statusCode == 307 || statusCode == 308; + } + + private static bool IsAllowedScheme(string scheme) + { + return string.Equals(scheme, "http", StringComparison.OrdinalIgnoreCase); + } + } +} diff --git a/src/libraries/Common/src/System/Text/UrlBase64Encoding.cs b/src/libraries/Common/src/System/Text/UrlBase64Encoding.cs new file mode 100644 index 0000000000000..18c235c8d6e98 --- /dev/null +++ b/src/libraries/Common/src/System/Text/UrlBase64Encoding.cs @@ -0,0 +1,80 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Diagnostics; + +namespace System.Text +{ + /// + /// This class provides URL-encoded-Base64, which is distinct from the base64url encoding. + /// + internal static class UrlBase64Encoding + { + internal static ArraySegment RentEncode(ReadOnlySpan input) + { + // Every 3 bytes turns into 4 chars for the Base64 operation + int base64Len = ((input.Length + 2) / 3) * 4; + char[] base64 = ArrayPool.Shared.Rent(base64Len); + + if (!Convert.TryToBase64Chars(input, base64, out int charsWritten)) + { + Debug.Fail($"Convert.TryToBase64 failed with {input.Length} bytes to a {base64.Length} buffer"); + throw new UnreachableException(); + } + + Debug.Assert(charsWritten == base64Len); + + // In the degenerate case every char will turn into 3 chars. + int urlEncodedLen = charsWritten * 3; + char[] urlEncoded = ArrayPool.Shared.Rent(urlEncodedLen); + + ReadOnlySpan source = base64.AsSpan(0, base64Len); + Span dest = urlEncoded; + int written = 0; + + while (!source.IsEmpty) + { + int pos = source.IndexOfAny('+', '/', '='); + + if (pos < 0) + { + source.CopyTo(dest); + written += source.Length; + break; + } + + source.Slice(0, pos).CopyTo(dest); + source = source.Slice(pos); + dest = dest.Slice(pos); + written += pos; + + dest[0] = '%'; + + switch (source[0]) + { + case '+': + dest[1] = '2'; + dest[2] = 'B'; + break; + case '/': + dest[1] = '2'; + dest[2] = 'F'; + break; + default: + Debug.Assert(source[0] == '='); + dest[1] = '3'; + dest[2] = 'D'; + break; + } + + source = source.Slice(1); + dest = dest.Slice(3); + written += 3; + } + + ArrayPool.Shared.Return(base64); + return new ArraySegment(urlEncoded, 0, written); + } + } +} diff --git a/src/libraries/System.Net.Security/src/System.Net.Security.csproj b/src/libraries/System.Net.Security/src/System.Net.Security.csproj index d3fd152787f15..50a94c18ae064 100644 --- a/src/libraries/System.Net.Security/src/System.Net.Security.csproj +++ b/src/libraries/System.Net.Security/src/System.Net.Security.csproj @@ -297,8 +297,12 @@ + + + SslContexts; + private bool _staplingForbidden; + private byte[]? _ocspResponse; + private DateTimeOffset _ocspExpiration; + private DateTimeOffset _nextDownload; + private Task? _pendingDownload; + private List? _ocspUrls; + private X509Certificate2? _ca; + private SslStreamCertificateContext(X509Certificate2 target, X509Certificate2[] intermediates, SslCertificateTrust? trust) { Certificate = target; @@ -21,6 +34,220 @@ private SslStreamCertificateContext(X509Certificate2 target, X509Certificate2[] SslContexts = new ConcurrentDictionary(); } - internal static SslStreamCertificateContext Create(X509Certificate2 target) => Create(target, null); + internal static SslStreamCertificateContext Create(X509Certificate2 target) => + Create(target, null, offline: false, trust: null, noOcspFetch: true); + + internal bool OcspStaplingAvailable => _ocspUrls is not null; + + partial void SetNoOcspFetch(bool noOcspFetch) + { + _staplingForbidden = noOcspFetch; + } + + partial void AddRootCertificate(X509Certificate2? rootCertificate) + { + if (IntermediateCertificates.Length == 0) + { + _ca = rootCertificate; + } + else + { + _ca = IntermediateCertificates[0]; + } + + if (!_staplingForbidden) + { + // Create the task, let the download finish in the background. + GetOcspResponseAsync().AsTask(); + } + } + + internal byte[]? GetOcspResponseNoWaiting() + { + try + { + ValueTask task = GetOcspResponseAsync(); + + if (task.IsCompletedSuccessfully) + { + return task.Result; + } + } + catch + { + } + + return null; + } + + internal ValueTask GetOcspResponseAsync() + { + if (_staplingForbidden) + { + return ValueTask.FromResult((byte[]?)null); + } + + DateTimeOffset now = DateTimeOffset.UtcNow; + + if (now > _ocspExpiration) + { + return DownloadOcspAsync(); + } + + if (now > _nextDownload) + { + // Calling DownloadOcsp will activate a Task to initiate + // in the background. Further calls will attach to the + // same Task if it's still running. + // + // We don't want the result here, just the task to background. +#pragma warning disable CA2012 // Use ValueTasks correctly + DownloadOcspAsync(); +#pragma warning restore CA2012 // Use ValueTasks correctly + } + + return ValueTask.FromResult(_ocspResponse); + } + + private ValueTask DownloadOcspAsync() + { + Task? pending = _pendingDownload; + + if (pending is not null && !pending.IsFaulted) + { + return new ValueTask(pending); + } + + if (_ocspUrls is null && _ca is not null) + { + foreach (X509Extension ext in Certificate.Extensions) + { + if (ext is X509AuthorityInformationAccessExtension aia) + { + foreach (string entry in aia.EnumerateOcspUris()) + { + if (Uri.TryCreate(entry, UriKind.Absolute, out Uri? uri)) + { + if (uri.Scheme == UriScheme.Http) + { + (_ocspUrls ??= new List()).Add(entry); + } + } + } + + break; + } + } + } + + if (_ocspUrls is null) + { + _ocspExpiration = _nextDownload = DateTimeOffset.MaxValue; + return new ValueTask((byte[]?)null); + } + + lock (SslContexts) + { + pending = _pendingDownload; + + if (pending is null || pending.IsFaulted) + { + _pendingDownload = pending = FetchOcspAsync(); + } + } + + return new ValueTask(pending); + } + + private async Task FetchOcspAsync() + { + X509Certificate2? caCert = _ca; + Debug.Assert(_ocspUrls is not null); + Debug.Assert(_ocspUrls.Count > 0); + Debug.Assert(caCert is not null); + + IntPtr subject = Certificate.Handle; + IntPtr issuer = caCert.Handle; + + using (SafeOcspRequestHandle ocspRequest = Interop.Crypto.X509BuildOcspRequest(subject, issuer)) + { + byte[] rentedBytes = ArrayPool.Shared.Rent(Interop.Crypto.GetOcspRequestDerSize(ocspRequest)); + int encodingSize = Interop.Crypto.EncodeOcspRequest(ocspRequest, rentedBytes); + ArraySegment encoded = new ArraySegment(rentedBytes, 0, encodingSize); + + ArraySegment rentedChars = UrlBase64Encoding.RentEncode(encoded); + byte[]? ret = null; + + for (int i = 0; i < _ocspUrls.Count; i++) + { + string url = MakeUrl(_ocspUrls[i], rentedChars); + ret = await System.Net.Http.X509ResourceClient.DownloadAssetAsync(url, TimeSpan.MaxValue).ConfigureAwait(false); + + if (ret is not null) + { + if (!Interop.Crypto.X509DecodeOcspToExpiration(ret, ocspRequest, subject, issuer, out DateTimeOffset expiration)) + { + continue; + } + + // Swap the working URL in as the first one we'll try next time. + if (i != 0) + { + string tmp = _ocspUrls[0]; + _ocspUrls[0] = _ocspUrls[i]; + _ocspUrls[i] = tmp; + } + + DateTimeOffset nextCheckA = DateTimeOffset.UtcNow.AddDays(1); + DateTimeOffset nextCheckB = expiration.AddMinutes(-5); + + _ocspResponse = ret; + _ocspExpiration = expiration; + _nextDownload = nextCheckA < nextCheckB ? nextCheckA : nextCheckB; + _pendingDownload = null; + break; + } + } + + ArrayPool.Shared.Return(rentedBytes); + ArrayPool.Shared.Return(rentedChars.Array!); + GC.KeepAlive(Certificate); + GC.KeepAlive(caCert); + return ret; + } + } + + private static string MakeUrl(string baseUri, ArraySegment encodedRequest) + { + Debug.Assert(baseUri.Length > 0); + Debug.Assert(encodedRequest.Count > 0); + + // From https://datatracker.ietf.org/doc/html/rfc6960: + // + // An OCSP request using the GET method is constructed as follows: + // + // GET {url}/{url-encoding of base-64 encoding of the DER encoding of + // the OCSPRequest} + // + // where {url} may be derived from the value of the authority + // information access extension in the certificate being checked for + // revocation + + // Since the certificate isn't expected to have a slash at the end, but might, + // use a custom concat over Uri's built-in combining constructor. + + string uriString; + + if (baseUri.EndsWith('/')) + { + uriString = string.Concat(baseUri, encodedRequest.AsSpan()); + } + else + { + uriString = string.Concat(baseUri, "/", encodedRequest.AsSpan()); + } + + return uriString; + } } } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamCertificateContext.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamCertificateContext.cs index d654e4f570a01..9ea3971f0444b 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamCertificateContext.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamCertificateContext.cs @@ -19,6 +19,16 @@ public static SslStreamCertificateContext Create(X509Certificate2 target, X509Ce } public static SslStreamCertificateContext Create(X509Certificate2 target, X509Certificate2Collection? additionalCertificates, bool offline = false, SslCertificateTrust? trust = null) + { + return Create(target, additionalCertificates, offline, trust, noOcspFetch: false); + } + + internal static SslStreamCertificateContext Create( + X509Certificate2 target, + X509Certificate2Collection? additionalCertificates, + bool offline, + SslCertificateTrust? trust, + bool noOcspFetch) { if (!target.HasPrivateKey) { @@ -26,6 +36,8 @@ public static SslStreamCertificateContext Create(X509Certificate2 target, X509Ce } X509Certificate2[] intermediates = Array.Empty(); + X509Certificate2? root = null; + using (X509Chain chain = new X509Chain()) { if (additionalCertificates != null) @@ -56,12 +68,15 @@ public static SslStreamCertificateContext Create(X509Certificate2 target, X509Ce if (TrimRootCertificate) { count--; + root = chain.ChainElements[chain.ChainElements.Count - 1].Certificate; + foreach (X509ChainStatus status in chain.ChainStatus) { if (status.Status.HasFlag(X509ChainStatusFlags.PartialChain)) { // The last cert isn't a root cert count++; + root = null; break; } } @@ -89,9 +104,19 @@ public static SslStreamCertificateContext Create(X509Certificate2 target, X509Ce } } - return new SslStreamCertificateContext(target, intermediates, trust); + SslStreamCertificateContext ctx = new SslStreamCertificateContext(target, intermediates, trust); + + // On Linux, AddRootCertificate will start a background download of an OCSP response, + // unless this context was built "offline", or this came from the internal Create(X509Certificate2) + ctx.SetNoOcspFetch(offline || noOcspFetch); + ctx.AddRootCertificate(root); + + return ctx; } + partial void AddRootCertificate(X509Certificate2? rootCertificate); + partial void SetNoOcspFetch(bool noOcspFetch); + internal SslStreamCertificateContext Duplicate() { return new SslStreamCertificateContext(new X509Certificate2(Certificate), IntermediateCertificates, Trust); diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/CertificateValidationRemoteServer.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/CertificateValidationRemoteServer.cs index 05dfce4137f05..e76cdfe77c4b1 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/CertificateValidationRemoteServer.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/CertificateValidationRemoteServer.cs @@ -5,7 +5,9 @@ using System.Net; using System.Net.Sockets; using System.Net.Test.Common; +using System.Runtime.InteropServices; using System.Security.Cryptography.X509Certificates; +using System.Security.Cryptography.X509Certificates.Tests.Common; using System.Threading.Tasks; using Microsoft.DotNet.XUnitExtensions; using Xunit; @@ -87,6 +89,178 @@ public async Task DefaultConnect_EndToEnd_Ok(string host) await EndToEndHelper(host); } + [Theory] + [InlineData(true)] + [InlineData(false)] + public Task ConnectWithRevocation_WithCallback(bool checkRevocation) + { + X509RevocationMode mode = checkRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck; + return ConnectWithRevocation_WithCallback_Core(mode); + } + + [PlatformSpecific(TestPlatforms.Linux)] + [Theory] + [OuterLoop("Subject to system load race conditions")] + [InlineData(false)] + [InlineData(true)] + public Task ConnectWithRevocation_StapledOcsp(bool offlineContext) + { + // Offline will only work if + // a) the revocation has been checked recently enough that it is cached, or + // b) the server stapled the response + // + // At high load, the server's background fetch might not have completed before + // this test runs. + return ConnectWithRevocation_WithCallback_Core(X509RevocationMode.Offline, offlineContext); + } + + [Fact] + [PlatformSpecific(TestPlatforms.Linux)] + public Task ConnectWithRevocation_ServerCertWithoutContext_NoStapledOcsp() + { + // Offline will only work if + // a) the revocation has been checked recently enough that it is cached, or + // b) the server stapled the response + // + // At high load, the server's background fetch might not have completed before + // this test runs. + return ConnectWithRevocation_WithCallback_Core(X509RevocationMode.Offline, offlineContext: null); + } + + private async Task ConnectWithRevocation_WithCallback_Core( + X509RevocationMode revocationMode, + bool? offlineContext = false) + { + string offlinePart = offlineContext.HasValue ? offlineContext.GetValueOrDefault().ToString().ToLower() : "null"; + string serverName = $"{revocationMode.ToString().ToLower()}.{offlinePart}.server.example"; + + (Stream clientStream, Stream serverStream) = TestHelper.GetConnectedStreams(); + + CertificateAuthority.BuildPrivatePki( + PkiOptions.EndEntityRevocationViaOcsp | PkiOptions.CrlEverywhere, + out RevocationResponder responder, + out CertificateAuthority rootAuthority, + out CertificateAuthority intermediateAuthority, + out X509Certificate2 serverCert, + subjectName: serverName, + keySize: 2048, + extensions: TestHelper.BuildTlsServerCertExtensions(serverName)); + + SslClientAuthenticationOptions clientOpts = new SslClientAuthenticationOptions + { + TargetHost = serverName, + RemoteCertificateValidationCallback = CertificateValidationCallback, + CertificateRevocationCheckMode = revocationMode, + }; + + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + X509Certificate2 temp = new X509Certificate2(serverCert.Export(X509ContentType.Pkcs12)); + serverCert.Dispose(); + serverCert = temp; + } + + await using (clientStream) + await using (serverStream) + using (responder) + using (rootAuthority) + using (intermediateAuthority) + using (serverCert) + using (X509Certificate2 issuerCert = intermediateAuthority.CloneIssuerCert()) + await using (SslStream tlsClient = new SslStream(clientStream)) + await using (SslStream tlsServer = new SslStream(serverStream)) + { + intermediateAuthority.Revoke(serverCert, serverCert.NotBefore); + + SslServerAuthenticationOptions serverOpts = new SslServerAuthenticationOptions(); + + if (offlineContext.HasValue) + { + serverOpts.ServerCertificateContext = SslStreamCertificateContext.Create( + serverCert, + new X509Certificate2Collection(issuerCert), + offlineContext.GetValueOrDefault()); + + if (revocationMode == X509RevocationMode.Offline) + { + // Give the OCSP response a better chance to finish. + await Task.Delay(200); + } + } + else + { + serverOpts.ServerCertificate = serverCert; + } + + Task serverTask = tlsServer.AuthenticateAsServerAsync(serverOpts); + Task clientTask = tlsClient.AuthenticateAsClientAsync(clientOpts); + + await TestConfiguration.WhenAllOrAnyFailedWithTimeout(clientTask, serverTask); + } + + static bool CertificateValidationCallback( + object sender, + X509Certificate? certificate, + X509Chain? chain, + SslPolicyErrors sslPolicyErrors) + { + Assert.NotNull(certificate); + Assert.NotNull(chain); + + sslPolicyErrors &= ~SslPolicyErrors.RemoteCertificateChainErrors; + + chain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust; + chain.ChainPolicy.CustomTrustStore.Add(chain.ChainElements[^1].Certificate); + + // The offline test will not know about revocation for the intermediate, + // so change the policy to only check the end certificate. + chain.ChainPolicy.RevocationFlag = X509RevocationFlag.EndCertificateOnly; + + if (!chain.Build((X509Certificate2)certificate)) + { + sslPolicyErrors |= SslPolicyErrors.RemoteCertificateChainErrors; + } + + if (chain.ChainPolicy.RevocationMode == X509RevocationMode.NoCheck) + { + X509ChainStatusFlags chainFlags = 0; + + foreach (X509ChainStatus status in chain.ChainStatus) + { + chainFlags |= status.Status; + } + + Assert.Equal(X509ChainStatusFlags.NoError, chainFlags); + + // The call didn't request revocation, so the chain should have been trusted. + Assert.Equal(SslPolicyErrors.None, sslPolicyErrors); + } + else if ((certificate.Subject.Contains(".true.server.") || certificate.Subject.Contains(".null.server.")) && + chain.ChainPolicy.RevocationMode == X509RevocationMode.Offline) + { + // In an Offline chain with an offline context the revocation still shouldn't + // process, because there's no OCSP data. + Assert.Equal(SslPolicyErrors.RemoteCertificateChainErrors, sslPolicyErrors); + + Assert.Contains( + chain.ChainElements[0].ChainElementStatus, + cs => cs.Status == X509ChainStatusFlags.RevocationStatusUnknown); + } + else + { + // Revocation was requested, and the cert is revoked, so the callback should + // say the chain isn't happy. + Assert.Equal(SslPolicyErrors.RemoteCertificateChainErrors, sslPolicyErrors); + + Assert.Contains( + chain.ChainElements[0].ChainElementStatus, + cs => cs.Status == X509ChainStatusFlags.Revoked); + } + + return true; + } + } + private async Task EndToEndHelper(string host) { using (var client = new TcpClient()) diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs index 0a5c14df84112..f4cddb999324d 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs @@ -141,15 +141,13 @@ internal static void CleanupCertificates([CallerMemberName] string? testName = n catch { }; } - internal static (X509Certificate2 certificate, X509Certificate2Collection) GenerateCertificates(string targetName, [CallerMemberName] string? testName = null, bool longChain = false, bool serverCertificate = true) + internal static X509ExtensionCollection BuildTlsServerCertExtensions(string serverName) { - const int keySize = 2048; - if (PlatformDetection.IsWindows && testName != null) - { - CleanupCertificates(testName); - } + return BuildTlsCertExtensions(serverName, true); + } - X509Certificate2Collection chain = new X509Certificate2Collection(); + private static X509ExtensionCollection BuildTlsCertExtensions(string targetName, bool serverCertificate) + { X509ExtensionCollection extensions = new X509ExtensionCollection(); SubjectAlternativeNameBuilder builder = new SubjectAlternativeNameBuilder(); @@ -159,6 +157,20 @@ internal static (X509Certificate2 certificate, X509Certificate2Collection) Gener extensions.Add(s_eeKeyUsage); extensions.Add(serverCertificate ? s_tlsServerEku : s_tlsClientEku); + return extensions; + } + + internal static (X509Certificate2 certificate, X509Certificate2Collection) GenerateCertificates(string targetName, [CallerMemberName] string? testName = null, bool longChain = false, bool serverCertificate = true) + { + const int keySize = 2048; + if (PlatformDetection.IsWindows && testName != null) + { + CleanupCertificates(testName); + } + + X509Certificate2Collection chain = new X509Certificate2Collection(); + X509ExtensionCollection extensions = BuildTlsCertExtensions(targetName, serverCertificate); + CertificateAuthority.BuildPrivatePki( PkiOptions.IssuerRevocationViaCrl, out RevocationResponder responder, diff --git a/src/libraries/System.Security.Cryptography/src/System.Security.Cryptography.csproj b/src/libraries/System.Security.Cryptography/src/System.Security.Cryptography.csproj index 19a6e4ed3f8ad..03476ce8ad204 100644 --- a/src/libraries/System.Security.Cryptography/src/System.Security.Cryptography.csproj +++ b/src/libraries/System.Security.Cryptography/src/System.Security.Cryptography.csproj @@ -33,6 +33,8 @@ Link="Common\System\HexConverter.cs" /> + @@ -641,6 +643,8 @@ Link="Common\Interop\Unix\System.Security.Cryptography.Native\Interop.LookupFriendlyNameByOid.cs" /> + - Common\System\Security\Cryptography\Asn1\DigestInfoAsn.xml @@ -763,6 +765,10 @@ Common\System\Security\Cryptography\Asn1\Pkcs7\EncryptedDataAsn.xml.cs Common\System\Security\Cryptography\Asn1\Pkcs7\EncryptedDataAsn.xml + + diff --git a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/OpenSslCertificateAssetDownloader.cs b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/OpenSslCertificateAssetDownloader.cs index 3045c37765cfe..5d87d20b40f5b 100644 --- a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/OpenSslCertificateAssetDownloader.cs +++ b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/OpenSslCertificateAssetDownloader.cs @@ -12,8 +12,6 @@ namespace System.Security.Cryptography.X509Certificates { internal static class OpenSslCertificateAssetDownloader { - private static readonly Func? s_downloadBytes = CreateDownloadBytesFunc(); - internal static X509Certificate2? DownloadCertificate(string uri, TimeSpan downloadTimeout) { byte[]? data = DownloadAsset(uri, downloadTimeout); @@ -113,252 +111,71 @@ internal static class OpenSslCertificateAssetDownloader private static byte[]? DownloadAsset(string uri, TimeSpan downloadTimeout) { - if (s_downloadBytes is null) - { - if (OpenSslX509ChainEventSource.Log.IsEnabled()) - { - OpenSslX509ChainEventSource.Log.HttpClientNotAvailable(); - } - - return null; - } - - if (downloadTimeout <= TimeSpan.Zero) - { - if (OpenSslX509ChainEventSource.Log.IsEnabled()) - { - OpenSslX509ChainEventSource.Log.DownloadTimeExceeded(); - } - - return null; - } + return System.Net.Http.X509ResourceClient.DownloadAsset(uri, downloadTimeout); + } + } +} - long totalMillis = (long)downloadTimeout.TotalMilliseconds; +namespace System.Net.Http +{ + using OpenSslX509ChainEventSource = System.Security.Cryptography.X509Certificates.OpenSslX509ChainEventSource; + internal partial class X509ResourceClient + { + static partial void ReportNoClient() + { if (OpenSslX509ChainEventSource.Log.IsEnabled()) { - OpenSslX509ChainEventSource.Log.AssetDownloadStart(totalMillis, uri); + OpenSslX509ChainEventSource.Log.HttpClientNotAvailable(); } - - CancellationTokenSource? cts = totalMillis > int.MaxValue ? null : new CancellationTokenSource((int)totalMillis); - byte[]? ret = null; - - try - { - ret = s_downloadBytes(uri, cts?.Token ?? default); - return ret; - } - catch { } - finally - { - cts?.Dispose(); - - if (OpenSslX509ChainEventSource.Log.IsEnabled()) - { - OpenSslX509ChainEventSource.Log.AssetDownloadStop(ret?.Length ?? 0); - } - } - - return null; } - private static Func? CreateDownloadBytesFunc() + static partial void ReportNegativeTimeout() { - try - { - // Use reflection to access System.Net.Http: - // Since System.Net.Http.dll explicitly depends on System.Security.Cryptography.X509Certificates.dll, - // the latter can't in turn have an explicit dependency on the former. - - // Get the relevant types needed. - Type? socketsHttpHandlerType = Type.GetType("System.Net.Http.SocketsHttpHandler, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false); - Type? httpMessageHandlerType = Type.GetType("System.Net.Http.HttpMessageHandler, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false); - Type? httpClientType = Type.GetType("System.Net.Http.HttpClient, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false); - Type? httpRequestMessageType = Type.GetType("System.Net.Http.HttpRequestMessage, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false); - Type? httpResponseMessageType = Type.GetType("System.Net.Http.HttpResponseMessage, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false); - Type? httpResponseHeadersType = Type.GetType("System.Net.Http.Headers.HttpResponseHeaders, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false); - Type? httpContentType = Type.GetType("System.Net.Http.HttpContent, System.Net.Http, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a", throwOnError: false); - if (socketsHttpHandlerType == null || httpMessageHandlerType == null || httpClientType == null || httpRequestMessageType == null || - httpResponseMessageType == null || httpResponseHeadersType == null || httpContentType == null) - { - Debug.Fail("Unable to load required type."); - return null; - } - - // Get the methods on those types. - ConstructorInfo? socketsHttpHandlerCtor = socketsHttpHandlerType.GetConstructor(Type.EmptyTypes); - PropertyInfo? pooledConnectionIdleTimeoutProp = socketsHttpHandlerType.GetProperty("PooledConnectionIdleTimeout"); - PropertyInfo? allowAutoRedirectProp = socketsHttpHandlerType.GetProperty("AllowAutoRedirect"); - ConstructorInfo? httpClientCtor = httpClientType.GetConstructor(new Type[] { httpMessageHandlerType }); - PropertyInfo? requestUriProp = httpRequestMessageType.GetProperty("RequestUri"); - ConstructorInfo? httpRequestMessageCtor = httpRequestMessageType.GetConstructor(Type.EmptyTypes); - MethodInfo? sendMethod = httpClientType.GetMethod("Send", new Type[] { httpRequestMessageType, typeof(CancellationToken) }); - PropertyInfo? responseContentProp = httpResponseMessageType.GetProperty("Content"); - PropertyInfo? responseStatusCodeProp = httpResponseMessageType.GetProperty("StatusCode"); - PropertyInfo? responseHeadersProp = httpResponseMessageType.GetProperty("Headers"); - PropertyInfo? responseHeadersLocationProp = httpResponseHeadersType.GetProperty("Location"); - MethodInfo? readAsStreamMethod = httpContentType.GetMethod("ReadAsStream", Type.EmptyTypes); - - if (socketsHttpHandlerCtor == null || pooledConnectionIdleTimeoutProp == null || allowAutoRedirectProp == null || httpClientCtor == null || - requestUriProp == null || httpRequestMessageCtor == null || sendMethod == null || responseContentProp == null || responseStatusCodeProp == null || - responseHeadersProp == null || responseHeadersLocationProp == null || readAsStreamMethod == null) - { - Debug.Fail("Unable to load required member."); - return null; - } - - // Only keep idle connections around briefly, as a compromise between resource leakage and port exhaustion. - const int PooledConnectionIdleTimeoutSeconds = 15; - const int MaxRedirections = 10; - - // Equivalent of: - // var socketsHttpHandler = new SocketsHttpHandler() { - // PooledConnectionIdleTimeout = TimeSpan.FromSeconds(PooledConnectionIdleTimeoutSeconds), - // AllowAutoRedirect = false - // }; - // var httpClient = new HttpClient(socketsHttpHandler); - // Note: using a ConstructorInfo instead of Activator.CreateInstance, so the ILLinker can see the usage through the lambda method. - object? socketsHttpHandler = socketsHttpHandlerCtor.Invoke(null); - pooledConnectionIdleTimeoutProp.SetValue(socketsHttpHandler, TimeSpan.FromSeconds(PooledConnectionIdleTimeoutSeconds)); - allowAutoRedirectProp.SetValue(socketsHttpHandler, false); - object? httpClient = httpClientCtor.Invoke(new object?[] { socketsHttpHandler }); - - return (string uriString, CancellationToken cancellationToken) => - { - Uri uri = new Uri(uriString); - - if (!IsAllowedScheme(uri.Scheme)) - { - return null; - } - - // Equivalent of: - // HttpRequestMessage requestMessage = new HttpRequestMessage() { RequestUri = new Uri(uri) }; - // HttpResponseMessage responseMessage = httpClient.Send(requestMessage, cancellationToken); - // Note: using a ConstructorInfo instead of Activator.CreateInstance, so the ILLinker can see the usage through the lambda method. - object requestMessage = httpRequestMessageCtor.Invoke(null); - requestUriProp.SetValue(requestMessage, uri); - object responseMessage = sendMethod.Invoke(httpClient, new object[] { requestMessage, cancellationToken })!; - - int redirections = 0; - Uri? redirectUri; - bool hasRedirect; - while (true) - { - int statusCode = (int)responseStatusCodeProp.GetValue(responseMessage)!; - object responseHeaders = responseHeadersProp.GetValue(responseMessage)!; - Uri? location = (Uri?)responseHeadersLocationProp.GetValue(responseHeaders); - redirectUri = GetUriForRedirect((Uri)requestUriProp.GetValue(requestMessage)!, statusCode, location, out hasRedirect); - if (redirectUri == null) - { - break; - } - - ((IDisposable)responseMessage).Dispose(); - - redirections++; - if (redirections > MaxRedirections) - { - if (OpenSslX509ChainEventSource.Log.IsEnabled()) - { - OpenSslX509ChainEventSource.Log.DownloadRedirectsExceeded(); - } - - return null; - } - - if (OpenSslX509ChainEventSource.Log.IsEnabled()) - { - OpenSslX509ChainEventSource.Log.DownloadRedirected(redirectUri); - } - - // Equivalent of: - // requestMessage = new HttpRequestMessage() { RequestUri = redirectUri }; - // requestMessage.RequestUri = redirectUri; - // responseMessage = httpClient.Send(requestMessage, cancellationToken); - requestMessage = httpRequestMessageCtor.Invoke(null); - requestUriProp.SetValue(requestMessage, redirectUri); - responseMessage = sendMethod.Invoke(httpClient, new object[] { requestMessage, cancellationToken })!; - } - - if (hasRedirect && redirectUri == null) - { - return null; - } - - // Equivalent of: - // using Stream responseStream = resp.Content.ReadAsStream(); - object content = responseContentProp.GetValue(responseMessage)!; - using Stream responseStream = (Stream)readAsStreamMethod.Invoke(content, null)!; - - var result = new MemoryStream(); - responseStream.CopyTo(result); - ((IDisposable)responseMessage).Dispose(); - return result.ToArray(); - }; - } - catch + if (OpenSslX509ChainEventSource.Log.IsEnabled()) { - // We shouldn't have any exceptions, but if we do, ignore them all. - return null; + OpenSslX509ChainEventSource.Log.DownloadTimeExceeded(); } } - private static Uri? GetUriForRedirect(Uri requestUri, int statusCode, Uri? location, out bool hasRedirect) + static partial void ReportDownloadStart(long totalMillis, string uri) { - if (!IsRedirectStatusCode(statusCode)) - { - hasRedirect = false; - return null; - } - - hasRedirect = true; - - if (location == null) - { - return null; - } - - // Ensure the redirect location is an absolute URI. - if (!location.IsAbsoluteUri) + if (OpenSslX509ChainEventSource.Log.IsEnabled()) { - location = new Uri(requestUri, location); + OpenSslX509ChainEventSource.Log.AssetDownloadStart(totalMillis, uri); } + } - // Per https://tools.ietf.org/html/rfc7231#section-7.1.2, a redirect location without a - // fragment should inherit the fragment from the original URI. - string requestFragment = requestUri.Fragment; - if (!string.IsNullOrEmpty(requestFragment)) + static partial void ReportDownloadStop(int bytesDownloaded) + { + if (OpenSslX509ChainEventSource.Log.IsEnabled()) { - string redirectFragment = location.Fragment; - if (string.IsNullOrEmpty(redirectFragment)) - { - location = new UriBuilder(location) { Fragment = requestFragment }.Uri; - } + OpenSslX509ChainEventSource.Log.AssetDownloadStop(bytesDownloaded); } + } - if (!IsAllowedScheme(location.Scheme)) + static partial void ReportRedirectsExceeded() + { + if (OpenSslX509ChainEventSource.Log.IsEnabled()) { - if (OpenSslX509ChainEventSource.Log.IsEnabled()) - { - OpenSslX509ChainEventSource.Log.DownloadRedirectNotFollowed(location); - } - - return null; + OpenSslX509ChainEventSource.Log.DownloadRedirectsExceeded(); } - - return location; } - private static bool IsRedirectStatusCode(int statusCode) + static partial void ReportRedirected(Uri newUri) { - // MultipleChoices (300), Moved (301), Found (302), SeeOther (303), TemporaryRedirect (307), PermanentRedirect (308) - return (statusCode >= 300 && statusCode <= 303) || statusCode == 307 || statusCode == 308; + if (OpenSslX509ChainEventSource.Log.IsEnabled()) + { + OpenSslX509ChainEventSource.Log.DownloadRedirected(newUri); + } } - private static bool IsAllowedScheme(string scheme) + static partial void ReportRedirectNotFollowed(Uri redirectUri) { - return string.Equals(scheme, "http", StringComparison.OrdinalIgnoreCase); + if (OpenSslX509ChainEventSource.Log.IsEnabled()) + { + OpenSslX509ChainEventSource.Log.DownloadRedirectNotFollowed(redirectUri); + } } } } diff --git a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/OpenSslX509ChainProcessor.cs b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/OpenSslX509ChainProcessor.cs index ecb780864377c..2922c513bd41e 100644 --- a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/OpenSslX509ChainProcessor.cs +++ b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/OpenSslX509ChainProcessor.cs @@ -764,7 +764,7 @@ private Interop.Crypto.X509VerifyStatusCode CheckOcsp( (handle, buf) => Interop.Crypto.EncodeOcspRequest(handle, buf), req); - ArraySegment urlEncoded = Base64UrlEncode(encoded); + ArraySegment urlEncoded = UrlBase64Encoding.RentEncode(encoded); string requestUrl = UrlPathAppend(baseUri, urlEncoded); // Nothing sensitive is in the encoded request (it was sent via HTTP-non-S) @@ -815,85 +815,12 @@ private static string UrlPathAppend(string baseUri, ReadOnlyMemory resourc Debug.Assert(baseUri.Length > 0); Debug.Assert(resource.Length > 0); - int count = baseUri.Length + resource.Length; - if (baseUri.EndsWith('/')) { - return string.Create( - count, - (baseUri, resource), - (buf, st) => - { - st.baseUri.CopyTo(buf); - st.resource.Span.CopyTo(buf.Slice(st.baseUri.Length)); - }); - } - - return string.Create( - count + 1, - (baseUri, resource), - (buf, st) => - { - st.baseUri.CopyTo(buf); - buf[st.baseUri.Length] = '/'; - st.resource.Span.CopyTo(buf.Slice(st.baseUri.Length + 1)); - }); - } - - private static ArraySegment Base64UrlEncode(ReadOnlySpan input) - { - // Every 3 bytes turns into 4 chars for the Base64 operation - int base64Len = ((input.Length + 2) / 3) * 4; - char[] base64 = ArrayPool.Shared.Rent(base64Len); - - if (!Convert.TryToBase64Chars(input, base64, out int charsWritten)) - { - Debug.Fail($"Convert.TryToBase64 failed with {input.Length} bytes to a {base64.Length} buffer"); - throw new CryptographicException(); - } - - Debug.Assert(charsWritten == base64Len); - - // In the degenerate case every char will turn into 3 chars. - int urlEncodedLen = charsWritten * 3; - char[] urlEncoded = ArrayPool.Shared.Rent(urlEncodedLen); - int writeIdx = 0; - - for (int readIdx = 0; readIdx < charsWritten; readIdx++) - { - char cur = base64[readIdx]; - - if (char.IsAsciiLetterOrDigit(cur)) - { - urlEncoded[writeIdx++] = cur; - } - else if (cur == '+') - { - urlEncoded[writeIdx++] = '%'; - urlEncoded[writeIdx++] = '2'; - urlEncoded[writeIdx++] = 'B'; - } - else if (cur == '/') - { - urlEncoded[writeIdx++] = '%'; - urlEncoded[writeIdx++] = '2'; - urlEncoded[writeIdx++] = 'F'; - } - else if (cur == '=') - { - urlEncoded[writeIdx++] = '%'; - urlEncoded[writeIdx++] = '3'; - urlEncoded[writeIdx++] = 'D'; - } - else - { - Debug.Fail($"'{cur}' is not a valid Base64 character"); - throw new CryptographicException(); - } + return string.Concat(baseUri, resource.Span); } - ArrayPool.Shared.Return(base64); - return new ArraySegment(urlEncoded, 0, writeIdx); + return string.Concat(baseUri, "/", resource.Span); } private X509ChainElement[] BuildChainElements( diff --git a/src/native/libs/System.Security.Cryptography.Native/apibridge.c b/src/native/libs/System.Security.Cryptography.Native/apibridge.c index 73ba3e6cef5d6..c8e1a134b053d 100644 --- a/src/native/libs/System.Security.Cryptography.Native/apibridge.c +++ b/src/native/libs/System.Security.Cryptography.Native/apibridge.c @@ -889,4 +889,13 @@ int local_EVP_PKEY_public_check(EVP_PKEY_CTX* ctx) return -1; } } + + +int local_ASN1_TIME_to_tm(const ASN1_TIME* s, struct tm* tm) +{ + (void)s; + (void)tm; + + return 0; +} #endif diff --git a/src/native/libs/System.Security.Cryptography.Native/apibridge.h b/src/native/libs/System.Security.Cryptography.Native/apibridge.h index 2079d81fd9681..92f4c592ad893 100644 --- a/src/native/libs/System.Security.Cryptography.Native/apibridge.h +++ b/src/native/libs/System.Security.Cryptography.Native/apibridge.h @@ -6,6 +6,7 @@ #pragma once #include "pal_types.h" +int local_ASN1_TIME_to_tm(const ASN1_TIME* s, struct tm* tm); int local_BIO_up_ref(BIO *a); const BIGNUM* local_DSA_get0_key(const DSA* dsa, const BIGNUM** pubKey, const BIGNUM** privKey); void local_DSA_get0_pqg(const DSA* dsa, const BIGNUM** p, const BIGNUM** q, const BIGNUM** g); diff --git a/src/native/libs/System.Security.Cryptography.Native/entrypoints.c b/src/native/libs/System.Security.Cryptography.Native/entrypoints.c index 239d4e9e36afc..d5614264d25da 100644 --- a/src/native/libs/System.Security.Cryptography.Native/entrypoints.c +++ b/src/native/libs/System.Security.Cryptography.Native/entrypoints.c @@ -234,6 +234,7 @@ static const Entry s_cryptoNative[] = DllImportEntry(CryptoNative_RsaSignHash) DllImportEntry(CryptoNative_RsaVerifyHash) DllImportEntry(CryptoNative_UpRefEvpPkey) + DllImportEntry(CryptoNative_X509BuildOcspRequest) DllImportEntry(CryptoNative_X509ChainBuildOcspRequest) DllImportEntry(CryptoNative_X509ChainGetCachedOcspStatus) DllImportEntry(CryptoNative_X509ChainHasStapledOcsp) @@ -288,6 +289,7 @@ static const Entry s_cryptoNative[] = DllImportEntry(CryptoNative_SslCtxAddExtraChainCert) DllImportEntry(CryptoNative_SslCtxSetCaching) DllImportEntry(CryptoNative_SslCtxSetCiphers) + DllImportEntry(CryptoNative_SslCtxSetDefaultOcspCallback) DllImportEntry(CryptoNative_SslCtxSetEncryptionPolicy) DllImportEntry(CryptoNative_SetCiphers) DllImportEntry(CryptoNative_SslCreate) @@ -332,11 +334,13 @@ static const Entry s_cryptoNative[] = DllImportEntry(CryptoNative_SslSetTlsExtHostName) DllImportEntry(CryptoNative_SslSetVerifyPeer) DllImportEntry(CryptoNative_SslShutdown) + DllImportEntry(CryptoNative_SslStapleOcsp) DllImportEntry(CryptoNative_SslUseCertificate) DllImportEntry(CryptoNative_SslUsePrivateKey) DllImportEntry(CryptoNative_SslV2_3Method) DllImportEntry(CryptoNative_SslWrite) DllImportEntry(CryptoNative_Tls13Supported) + DllImportEntry(CryptoNative_X509DecodeOcspToExpiration) DllImportEntry(CryptoNative_X509Duplicate) DllImportEntry(CryptoNative_SslGet0AlpnSelected) }; diff --git a/src/native/libs/System.Security.Cryptography.Native/opensslshim.h b/src/native/libs/System.Security.Cryptography.Native/opensslshim.h index 2ddc0ee09d4eb..554f4144852b5 100644 --- a/src/native/libs/System.Security.Cryptography.Native/opensslshim.h +++ b/src/native/libs/System.Security.Cryptography.Native/opensslshim.h @@ -174,6 +174,7 @@ const EVP_CIPHER* EVP_chacha20_poly1305(void); REQUIRED_FUNCTION(ASN1_STRING_print_ex) \ REQUIRED_FUNCTION(ASN1_TIME_new) \ REQUIRED_FUNCTION(ASN1_TIME_set) \ + FALLBACK_FUNCTION(ASN1_TIME_to_tm) \ REQUIRED_FUNCTION(ASN1_TIME_free) \ REQUIRED_FUNCTION(BASIC_CONSTRAINTS_free) \ REQUIRED_FUNCTION(BIO_ctrl) \ @@ -195,7 +196,9 @@ const EVP_CIPHER* EVP_chacha20_poly1305(void); REQUIRED_FUNCTION(BN_num_bits) \ REQUIRED_FUNCTION(BN_set_word) \ LEGACY_FUNCTION(CRYPTO_add_lock) \ + REQUIRED_FUNCTION(CRYPTO_free) \ REQUIRED_FUNCTION(CRYPTO_get_ex_new_index) \ + REQUIRED_FUNCTION(CRYPTO_malloc) \ LEGACY_FUNCTION(CRYPTO_num_locks) \ LEGACY_FUNCTION(CRYPTO_set_locking_callback) \ REQUIRED_FUNCTION(d2i_ASN1_BIT_STRING) \ @@ -464,6 +467,7 @@ const EVP_CIPHER* EVP_chacha20_poly1305(void); REQUIRED_FUNCTION(SSL_add_client_CA) \ REQUIRED_FUNCTION(SSL_set_alpn_protos) \ REQUIRED_FUNCTION(SSL_set_quiet_shutdown) \ + REQUIRED_FUNCTION(SSL_CTX_callback_ctrl) \ REQUIRED_FUNCTION(SSL_CTX_check_private_key) \ FALLBACK_FUNCTION(SSL_CTX_config) \ REQUIRED_FUNCTION(SSL_CTX_ctrl) \ @@ -649,6 +653,7 @@ FOR_ALL_OPENSSL_FUNCTIONS #define ASN1_TIME_free ASN1_TIME_free_ptr #define ASN1_TIME_new ASN1_TIME_new_ptr #define ASN1_TIME_set ASN1_TIME_set_ptr +#define ASN1_TIME_to_tm ASN1_TIME_to_tm_ptr #define BASIC_CONSTRAINTS_free BASIC_CONSTRAINTS_free_ptr #define BIO_ctrl BIO_ctrl_ptr #define BIO_ctrl_pending BIO_ctrl_pending_ptr @@ -669,7 +674,9 @@ FOR_ALL_OPENSSL_FUNCTIONS #define BN_num_bits BN_num_bits_ptr #define BN_set_word BN_set_word_ptr #define CRYPTO_add_lock CRYPTO_add_lock_ptr +#define CRYPTO_free CRYPTO_free_ptr #define CRYPTO_get_ex_new_index CRYPTO_get_ex_new_index_ptr +#define CRYPTO_malloc CRYPTO_malloc_ptr #define CRYPTO_num_locks CRYPTO_num_locks_ptr #define CRYPTO_set_locking_callback CRYPTO_set_locking_callback_ptr #define d2i_ASN1_BIT_STRING d2i_ASN1_BIT_STRING_ptr @@ -940,6 +947,7 @@ FOR_ALL_OPENSSL_FUNCTIONS #define SSL_add_client_CA SSL_add_client_CA_ptr #define SSL_set_alpn_protos SSL_set_alpn_protos_ptr #define SSL_set_quiet_shutdown SSL_set_quiet_shutdown_ptr +#define SSL_CTX_callback_ctrl SSL_CTX_callback_ctrl_ptr #define SSL_CTX_check_private_key SSL_CTX_check_private_key_ptr #define SSL_CTX_config SSL_CTX_config_ptr #define SSL_CTX_ctrl SSL_CTX_ctrl_ptr @@ -1170,6 +1178,7 @@ FOR_ALL_OPENSSL_FUNCTIONS #elif OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_1_1_0_RTM // Alias "future" API to the local_ version. +#define ASN1_TIME_to_tm local_ASN1_TIME_to_tm #define BIO_up_ref local_BIO_up_ref #define DSA_get0_key local_DSA_get0_key #define DSA_get0_pqg local_DSA_get0_pqg diff --git a/src/native/libs/System.Security.Cryptography.Native/osslcompat_111.h b/src/native/libs/System.Security.Cryptography.Native/osslcompat_111.h index 831a0a98caf36..4630276689e3f 100644 --- a/src/native/libs/System.Security.Cryptography.Native/osslcompat_111.h +++ b/src/native/libs/System.Security.Cryptography.Native/osslcompat_111.h @@ -19,6 +19,7 @@ typedef struct stack_st OPENSSL_STACK; #define OPENSSL_INIT_LOAD_CONFIG 0x00000040L #define OPENSSL_INIT_LOAD_SSL_STRINGS 0x00200000L +int ASN1_TIME_to_tm(const ASN1_TIME* s, struct tm* tm); int BIO_up_ref(BIO* a); const BIGNUM* DSA_get0_key(const DSA* dsa, const BIGNUM** pubKey, const BIGNUM** privKey); void DSA_get0_pqg(const DSA* dsa, const BIGNUM** p, const BIGNUM** q, const BIGNUM** g); diff --git a/src/native/libs/System.Security.Cryptography.Native/pal_ssl.c b/src/native/libs/System.Security.Cryptography.Native/pal_ssl.c index 25a8d5433b334..acfd66db97301 100644 --- a/src/native/libs/System.Security.Cryptography.Native/pal_ssl.c +++ b/src/native/libs/System.Security.Cryptography.Native/pal_ssl.c @@ -5,6 +5,7 @@ #include "openssl.h" #include "pal_evp_pkey.h" #include "pal_evp_pkey_rsa.h" +#include "pal_utilities.h" #include "pal_x509.h" #include @@ -751,6 +752,33 @@ int32_t CryptoNative_SslCtxSetEncryptionPolicy(SSL_CTX* ctx, EncryptionPolicy po return false; } +static int DefaultOcspCallback(SSL* ssl, void* args) +{ + (void)args; + int ret = SSL_TLSEXT_ERR_NOACK; + + if (ssl != NULL) + { + uint8_t* resp; + long len = SSL_get_tlsext_status_ocsp_resp(ssl, &resp); + + // If we've already provided the stapled data, say so. + if (len > 0 && resp != NULL) + { + ret = SSL_TLSEXT_ERR_OK; + } + } + + return ret; +} + +void CryptoNative_SslCtxSetDefaultOcspCallback(SSL_CTX* ctx) +{ + assert(ctx != NULL); + + SSL_CTX_set_tlsext_status_cb(ctx, DefaultOcspCallback); +} + int32_t CryptoNative_SslCtxSetCiphers(SSL_CTX* ctx, const char* cipherList, const char* cipherSuites) { ERR_clear_error(); @@ -1218,3 +1246,22 @@ int32_t CryptoNative_OpenSslGetProtocolSupport(SslProtocols protocol) return ret == 1; } + +void CryptoNative_SslStapleOcsp(SSL* ssl, uint8_t* buf, int32_t len) +{ + assert(ssl != NULL); + assert(buf != NULL); + assert(len > 0); + + // OpenSSL's cleanup of the SSL structure will always call OPENSSL_free on + // the pointer we provide for the OCSP response, so we need to freshly + // duplicate it here, using an OpenSSL allocator. + size_t size = Int32ToSizeT(len); + void* copy = OPENSSL_malloc(size); + memcpy(copy, buf, size); + + if (SSL_set_tlsext_status_ocsp_resp(ssl, copy, len) != 1) + { + OPENSSL_free(copy); + } +} diff --git a/src/native/libs/System.Security.Cryptography.Native/pal_ssl.h b/src/native/libs/System.Security.Cryptography.Native/pal_ssl.h index 2a6cb9881ee54..b968fd59753c4 100644 --- a/src/native/libs/System.Security.Cryptography.Native/pal_ssl.h +++ b/src/native/libs/System.Security.Cryptography.Native/pal_ssl.h @@ -405,6 +405,11 @@ Sets the specified encryption policy on the SSL_CTX. */ PALEXPORT int32_t CryptoNative_SslCtxSetEncryptionPolicy(SSL_CTX* ctx, EncryptionPolicy policy); +/* +Activates the default OCSP stapling callback. +*/ +PALEXPORT void CryptoNative_SslCtxSetDefaultOcspCallback(SSL_CTX* ctx); + /* Sets ciphers (< TLS 1.3) and cipher suites (TLS 1.3) on the SSL_CTX */ @@ -493,3 +498,8 @@ PALEXPORT const char* CryptoNative_GetOpenSslCipherSuiteName(SSL* ssl, int32_t c Checks if given protocol version is supported. */ PALEXPORT int32_t CryptoNative_OpenSslGetProtocolSupport(SslProtocols protocol); + +/* +Staples an encoded OCSP response onto the TLS session +*/ +PALEXPORT void CryptoNative_SslStapleOcsp(SSL* ssl, uint8_t* buf, int32_t len); diff --git a/src/native/libs/System.Security.Cryptography.Native/pal_x509.c b/src/native/libs/System.Security.Cryptography.Native/pal_x509.c index 40f959e931568..c238fe012f8cc 100644 --- a/src/native/libs/System.Security.Cryptography.Native/pal_x509.c +++ b/src/native/libs/System.Security.Cryptography.Native/pal_x509.c @@ -881,12 +881,13 @@ static time_t GetIssuanceWindowStart() return t; } -static X509VerifyStatusCode CheckOcsp(OCSP_REQUEST* req, - OCSP_RESPONSE* resp, - X509* subject, - X509* issuer, - X509_STORE_CTX* storeCtx, - int* canCache) +static X509VerifyStatusCode CheckOcspGetExpiry(OCSP_REQUEST* req, + OCSP_RESPONSE* resp, + X509* subject, + X509* issuer, + X509_STORE_CTX* storeCtx, + int* canCache, + time_t* expiry) { assert(resp != NULL); assert(subject != NULL); @@ -979,6 +980,23 @@ static X509VerifyStatusCode CheckOcsp(OCSP_REQUEST* req, if (X509_cmp_time(thisupd, &oldest) > 0) { *canCache = 1; + + if (expiry != NULL) + { + struct tm updTm = { 0 }; + + if (nextupd != NULL && ASN1_TIME_to_tm(nextupd, &updTm) == 1) + { + *expiry = timegm(&updTm); + } + else if (ASN1_TIME_to_tm(thisupd, &updTm) == 1) + { + // If we're doing server side OCSP stapling and the response + // has no nextUpd, treat it as a 24-hour expiration for refresh + // purposes. + *expiry = timegm(&updTm) + (24 * 60 * 60); + } + } } } } @@ -995,6 +1013,16 @@ static X509VerifyStatusCode CheckOcsp(OCSP_REQUEST* req, return ret; } +static X509VerifyStatusCode CheckOcsp(OCSP_REQUEST* req, + OCSP_RESPONSE* resp, + X509* subject, + X509* issuer, + X509_STORE_CTX* storeCtx, + int* canCache) +{ + return CheckOcspGetExpiry(req, resp, subject, issuer, storeCtx, canCache, NULL); +} + static int Get0CertAndIssuer(X509_STORE_CTX* storeCtx, int chainDepth, X509** subject, X509** issuer) { assert(storeCtx != NULL); @@ -1104,23 +1132,8 @@ int32_t CryptoNative_X509ChainGetCachedOcspStatus(X509_STORE_CTX* storeCtx, char return (int32_t)ret; } -OCSP_REQUEST* CryptoNative_X509ChainBuildOcspRequest(X509_STORE_CTX* storeCtx, int chainDepth) +static OCSP_REQUEST* BuildOcspRequest(X509* subject, X509* issuer) { - if (storeCtx == NULL) - { - return NULL; - } - - ERR_clear_error(); - - X509* subject; - X509* issuer; - - if (!Get0CertAndIssuer(storeCtx, chainDepth, &subject, &issuer)) - { - return NULL; - } - OCSP_CERTID* certId = MakeCertId(subject, issuer); if (certId == NULL) @@ -1151,6 +1164,35 @@ OCSP_REQUEST* CryptoNative_X509ChainBuildOcspRequest(X509_STORE_CTX* storeCtx, i return req; } +OCSP_REQUEST* CryptoNative_X509BuildOcspRequest(X509* subject, X509* issuer) +{ + assert(subject != NULL); + assert(issuer != NULL); + + ERR_clear_error(); + return BuildOcspRequest(subject, issuer); +} + +OCSP_REQUEST* CryptoNative_X509ChainBuildOcspRequest(X509_STORE_CTX* storeCtx, int chainDepth) +{ + if (storeCtx == NULL) + { + return NULL; + } + + ERR_clear_error(); + + X509* subject; + X509* issuer; + + if (!Get0CertAndIssuer(storeCtx, chainDepth, &subject, &issuer)) + { + return NULL; + } + + return BuildOcspRequest(subject, issuer); +} + static int32_t X509ChainVerifyOcsp(X509_STORE_CTX* storeCtx, X509* subject, X509* issuer, OCSP_REQUEST* req, OCSP_RESPONSE* resp, char* cachePath) { X509VerifyStatusCode ret = PAL_X509_V_ERR_UNABLE_TO_GET_CRL; @@ -1237,3 +1279,80 @@ CryptoNative_X509ChainVerifyOcsp(X509_STORE_CTX* storeCtx, OCSP_REQUEST* req, OC return X509ChainVerifyOcsp(storeCtx, subject, issuer, req, resp, cachePath); } + +int32_t CryptoNative_X509DecodeOcspToExpiration(const uint8_t* buf, int32_t len, OCSP_REQUEST* req, X509* subject, X509* issuer, int64_t* expiration) +{ + ERR_clear_error(); + + if (buf == NULL || len == 0) + { + return 0; + } + + OCSP_RESPONSE* resp = d2i_OCSP_RESPONSE(NULL, &buf, len); + + if (resp == NULL) + { + return 0; + } + + X509_STORE* store = X509_STORE_new(); + X509_STORE_CTX* ctx = NULL; + X509Stack* bag = NULL; + + if (store != NULL) + { + bag = sk_X509_new_null(); + } + + if (bag != NULL) + { + if (X509_STORE_add_cert(store, issuer) && sk_X509_push(bag, issuer)) + { + ctx = X509_STORE_CTX_new(); + } + } + + int ret = 0; + + if (ctx != NULL) + { + if (X509_STORE_CTX_init(ctx, store, subject, bag) != 0) + { + int canCache = 0; + time_t expiration_t = 0; + X509VerifyStatusCode code = CheckOcspGetExpiry(req, resp, subject, issuer, ctx, &canCache, &expiration_t); + + if (sizeof(time_t) == sizeof(int64_t)) + { + *expiration = (int64_t)expiration_t; + } + else if (sizeof(time_t) == sizeof(int32_t)) + { + *expiration = (int32_t)expiration_t; + } + + if (code == PAL_X509_V_OK || code == PAL_X509_V_ERR_CERT_REVOKED) + { + ret = 1; + } + } + + X509_STORE_CTX_free(ctx); + } + + if (bag != NULL) + { + // Just free, not pop_free. + // We don't want to downref the issuer cert. + sk_X509_free(bag); + } + + if (store != NULL) + { + X509_STORE_free(store); + } + + OCSP_RESPONSE_free(resp); + return ret; +} diff --git a/src/native/libs/System.Security.Cryptography.Native/pal_x509.h b/src/native/libs/System.Security.Cryptography.Native/pal_x509.h index e6983349de3e0..0168f750114f4 100644 --- a/src/native/libs/System.Security.Cryptography.Native/pal_x509.h +++ b/src/native/libs/System.Security.Cryptography.Native/pal_x509.h @@ -377,6 +377,11 @@ determined by the chain in storeCtx. */ PALEXPORT int32_t CryptoNative_X509ChainGetCachedOcspStatus(X509_STORE_CTX* storeCtx, char* cachePath, int chainDepth); +/* +Build an OCSP request appropriate for the subject certificate (as issued by the issuer certificate) +*/ +PALEXPORT OCSP_REQUEST* CryptoNative_X509BuildOcspRequest(X509* subject, X509* issuer); + /* Build an OCSP request appropriate for the end-entity certificate using the issuer (and trust) as determined by the chain in storeCtx. @@ -397,3 +402,9 @@ PALEXPORT int32_t CryptoNative_X509ChainVerifyOcsp(X509_STORE_CTX* storeCtx, OCSP_RESPONSE* resp, char* cachePath, int chainDepth); + +/* +Decode len bytes of buf into an OCSP response, process it against the OCSP request, and return if the bytes were valid. +If the bytes were valid, and the OCSP response had a nextUpdate value, assign it to expiration. +*/ +PALEXPORT int32_t CryptoNative_X509DecodeOcspToExpiration(const uint8_t* buf, int32_t len, OCSP_REQUEST* req, X509* subject, X509* issuer, int64_t* expiration);