Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

H/3 Server Cert validation callback exception fix #55526

Merged
merged 2 commits into from
Jul 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public CertificateCallbackMapper(Func<HttpRequestMessage, X509Certificate2?, X50
}
}

public static ValueTask<SslStream> EstablishSslConnectionAsync(SslClientAuthenticationOptions sslOptions, HttpRequestMessage request, bool async, Stream stream, CancellationToken cancellationToken)
private static SslClientAuthenticationOptions SetUpRemoteCertificateValidationCallback(SslClientAuthenticationOptions sslOptions, HttpRequestMessage request)
{
// If there's a cert validation callback, and if it came from HttpClientHandler,
// wrap the original delegate in order to change the sender to be the request message (expected by HttpClientHandler's delegate).
Expand All @@ -52,12 +52,13 @@ public static ValueTask<SslStream> EstablishSslConnectionAsync(SslClientAuthenti
};
}

// Create the SslStream, authenticate, and return it.
return EstablishSslConnectionAsyncCore(async, stream, sslOptions, cancellationToken);
return sslOptions;
}

private static async ValueTask<SslStream> EstablishSslConnectionAsyncCore(bool async, Stream stream, SslClientAuthenticationOptions sslOptions, CancellationToken cancellationToken)
public static async ValueTask<SslStream> EstablishSslConnectionAsync(SslClientAuthenticationOptions sslOptions, HttpRequestMessage request, bool async, Stream stream, CancellationToken cancellationToken)
{
sslOptions = SetUpRemoteCertificateValidationCallback(sslOptions, request);

SslStream sslStream = new SslStream(stream);

try
Expand Down Expand Up @@ -104,8 +105,10 @@ private static async ValueTask<SslStream> EstablishSslConnectionAsyncCore(bool a
[SupportedOSPlatform("windows")]
[SupportedOSPlatform("linux")]
[SupportedOSPlatform("macos")]
public static async ValueTask<QuicConnection> ConnectQuicAsync(QuicImplementationProvider quicImplementationProvider, DnsEndPoint endPoint, SslClientAuthenticationOptions? clientAuthenticationOptions, CancellationToken cancellationToken)
public static async ValueTask<QuicConnection> ConnectQuicAsync(HttpRequestMessage request, QuicImplementationProvider quicImplementationProvider, DnsEndPoint endPoint, SslClientAuthenticationOptions clientAuthenticationOptions, CancellationToken cancellationToken)
{
clientAuthenticationOptions = SetUpRemoteCertificateValidationCallback(clientAuthenticationOptions, request);

QuicConnection con = new QuicConnection(quicImplementationProvider, endPoint, clientAuthenticationOptions);
try
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ private async ValueTask<Http3Connection> GetHttp3ConnectionAsync(HttpRequestMess
QuicConnection quicConnection;
try
{
quicConnection = await ConnectHelper.ConnectQuicAsync(Settings._quicImplementationProvider ?? QuicImplementationProviders.Default, new DnsEndPoint(authority.IdnHost, authority.Port), _sslOptionsHttp3, cancellationToken).ConfigureAwait(false);
quicConnection = await ConnectHelper.ConnectQuicAsync(request, Settings._quicImplementationProvider ?? QuicImplementationProviders.Default, new DnsEndPoint(authority.IdnHost, authority.Port), _sslOptionsHttp3!, cancellationToken).ConfigureAwait(false);
}
catch
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,64 @@ public async Task ReservedFrameType_Throws()
await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
}

[Fact]
public async Task ServerCertificateCustomValidationCallback_Succeeds()
{
// Mock doesn't make use of cart validation callback.
if (UseQuicImplementationProvider == QuicImplementationProviders.Mock)
{
return;
}

HttpRequestMessage? callbackRequest = null;
int invocationCount = 0;

var httpClientHandler = CreateHttpClientHandler();
httpClientHandler.ServerCertificateCustomValidationCallback = (request, _, _, _) =>
{
callbackRequest = request;
++invocationCount;
return true;
};

using Http3LoopbackServer server = CreateHttp3LoopbackServer();
using HttpClient client = CreateHttpClient(httpClientHandler);

Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await stream.HandleRequestAsync();
using Http3LoopbackStream stream2 = await connection.AcceptRequestStreamAsync();
await stream2.HandleRequestAsync();
});

var request = new HttpRequestMessage(HttpMethod.Get, server.Address);
request.Version = HttpVersion.Version30;
request.VersionPolicy = HttpVersionPolicy.RequestVersionExact;

var response = await client.SendAsync(request);

response.EnsureSuccessStatusCode();
Assert.Equal(HttpVersion.Version30, response.Version);
Assert.Same(request, callbackRequest);
Assert.Equal(1, invocationCount);

// Second request, the callback shouldn't be hit at all.
callbackRequest = null;

request = new HttpRequestMessage(HttpMethod.Get, server.Address);
request.Version = HttpVersion.Version30;
request.VersionPolicy = HttpVersionPolicy.RequestVersionExact;

response = await client.SendAsync(request);

response.EnsureSuccessStatusCode();
Assert.Equal(HttpVersion.Version30, response.Version);
Assert.Null(callbackRequest);
Assert.Equal(1, invocationCount);
}

[OuterLoop]
[ConditionalTheory(nameof(IsMsQuicSupported))]
[MemberData(nameof(InteropUris))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,10 @@ private static uint HandleEventPeerCertificateReceived(State state, ref Connecti
if (connection._remoteCertificateValidationCallback != null)
{
bool success = connection._remoteCertificateValidationCallback(connection, certificate, chain, sslPolicyErrors);
// Unset the callback to prevent multiple invocations of the callback per a single connection.
// Return the same value as the custom callback just did.
connection._remoteCertificateValidationCallback = (_, _, _, _) => success;

if (!success && NetEventSource.Log.IsEnabled())
NetEventSource.Error(state, $"[Connection#{state.GetHashCode()}] remote certificate rejected by verification callback");
return success ? MsQuicStatusCodes.Success : MsQuicStatusCodes.HandshakeFailure;
Expand Down