diff --git a/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs b/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs index 57ac10a63d4c..629a5a2f86dd 100644 --- a/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs +++ b/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs @@ -288,6 +288,8 @@ public static ListenOptions UseHttps(this ListenOptions listenOptions, TlsHandsh listenOptions.IsTls = true; listenOptions.Use(next => { + // Set the list of protocols from listen options + callbackOptions.HttpProtocols = listenOptions.Protocols; var middleware = new HttpsConnectionMiddleware(next, callbackOptions, loggerFactory); return middleware.OnConnectionAsync; }); diff --git a/src/Servers/Kestrel/Core/src/Middleware/HttpsConnectionMiddleware.cs b/src/Servers/Kestrel/Core/src/Middleware/HttpsConnectionMiddleware.cs index 46ae37421070..a3f5600ffef1 100644 --- a/src/Servers/Kestrel/Core/src/Middleware/HttpsConnectionMiddleware.cs +++ b/src/Servers/Kestrel/Core/src/Middleware/HttpsConnectionMiddleware.cs @@ -46,6 +46,7 @@ internal class HttpsConnectionMiddleware // The following fields are only set by TlsHandshakeCallbackOptions ctor. private readonly Func>? _tlsCallbackOptions; private readonly object? _tlsCallbackOptionsState; + private readonly HttpProtocols _httpProtocols; // Pool for cancellation tokens that cancel the handshake private readonly CancellationTokenSourcePool _ctsPool = new(); @@ -127,6 +128,7 @@ internal HttpsConnectionMiddleware( _tlsCallbackOptions = tlsCallbackOptions.OnConnection; _tlsCallbackOptionsState = tlsCallbackOptions.OnConnectionState; + _httpProtocols = ValidateAndNormalizeHttpProtocols(tlsCallbackOptions.HttpProtocols, _logger); _sslStreamFactory = s => new SslStream(s); } @@ -434,6 +436,11 @@ private static async ValueTask ServerOptionsCall var sslOptions = await middleware._tlsCallbackOptions!(callbackContext); feature.AllowDelayedClientCertificateNegotation = callbackContext.AllowDelayedClientCertificateNegotation; + // The callback didn't set ALPN so we will. + if (sslOptions.ApplicationProtocols == null) + { + ConfigureAlpn(sslOptions, middleware._httpProtocols); + } KestrelEventSource.Log.TlsHandshakeStart(context, sslOptions); return sslOptions; diff --git a/src/Servers/Kestrel/Core/src/TlsHandshakeCallbackOptions.cs b/src/Servers/Kestrel/Core/src/TlsHandshakeCallbackOptions.cs index c269005657d6..46531a887839 100644 --- a/src/Servers/Kestrel/Core/src/TlsHandshakeCallbackOptions.cs +++ b/src/Servers/Kestrel/Core/src/TlsHandshakeCallbackOptions.cs @@ -42,5 +42,8 @@ public TimeSpan HandshakeTimeout _handshakeTimeout = value != Timeout.InfiniteTimeSpan ? value : TimeSpan.MaxValue; } } + + // Copied from the ListenOptions to enable ALPN + internal HttpProtocols HttpProtocols { get; set; } } } diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/HttpsConnectionMiddlewareTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/HttpsConnectionMiddlewareTests.cs index f76e84d5db69..2a22ccb143e3 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/HttpsConnectionMiddlewareTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/HttpsConnectionMiddlewareTests.cs @@ -793,6 +793,98 @@ void ConfigureListenOptions(ListenOptions listenOptions) await AssertConnectionResult(stream, true, expectedBody); } + [ConditionalFact] + [OSSkipCondition(OperatingSystems.MacOSX, SkipReason = "ALPN not supported")] + public async Task ServerOptionsSelectionCallback_SetsALPN() + { + static void ConfigureListenOptions(ListenOptions listenOptions) + { + listenOptions.UseHttps((_, _, _, _) => + ValueTask.FromResult(new SslServerAuthenticationOptions() + { + ServerCertificate = _x509Certificate2, + }), state: null); + } + + await using var server = new TestServer(context => Task.CompletedTask, + new TestServiceContext(LoggerFactory), ConfigureListenOptions); + + using var connection = server.CreateConnection(); + var stream = OpenSslStream(connection.Stream); + await stream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions() + { + // Use a random host name to avoid the TLS session resumption cache. + TargetHost = Guid.NewGuid().ToString(), + ApplicationProtocols = new() { SslApplicationProtocol.Http2, SslApplicationProtocol.Http11, }, + }); + Assert.Equal(SslApplicationProtocol.Http2, stream.NegotiatedApplicationProtocol); + } + + [ConditionalFact] + [OSSkipCondition(OperatingSystems.MacOSX, SkipReason = "ALPN not supported")] + public async Task TlsHandshakeCallbackOptionsOverload_SetsALPN() + { + static void ConfigureListenOptions(ListenOptions listenOptions) + { + listenOptions.UseHttps(new TlsHandshakeCallbackOptions() + { + OnConnection = context => + { + return ValueTask.FromResult(new SslServerAuthenticationOptions() + { + ServerCertificate = _x509Certificate2, + }); + } + }); + } + + await using var server = new TestServer(context => Task.CompletedTask, + new TestServiceContext(LoggerFactory), ConfigureListenOptions); + + using var connection = server.CreateConnection(); + var stream = OpenSslStream(connection.Stream); + await stream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions() + { + // Use a random host name to avoid the TLS session resumption cache. + TargetHost = Guid.NewGuid().ToString(), + ApplicationProtocols = new() { SslApplicationProtocol.Http2, SslApplicationProtocol.Http11, }, + }); + Assert.Equal(SslApplicationProtocol.Http2, stream.NegotiatedApplicationProtocol); + } + + [ConditionalFact] + [OSSkipCondition(OperatingSystems.MacOSX, SkipReason = "ALPN not supported")] + public async Task TlsHandshakeCallbackOptionsOverload_EmptyAlpnList_DisablesAlpn() + { + static void ConfigureListenOptions(ListenOptions listenOptions) + { + listenOptions.UseHttps(new TlsHandshakeCallbackOptions() + { + OnConnection = context => + { + return ValueTask.FromResult(new SslServerAuthenticationOptions() + { + ServerCertificate = _x509Certificate2, + ApplicationProtocols = new(), + }); + } + }); + } + + await using var server = new TestServer(context => Task.CompletedTask, + new TestServiceContext(LoggerFactory), ConfigureListenOptions); + + using var connection = server.CreateConnection(); + var stream = OpenSslStream(connection.Stream); + await stream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions() + { + // Use a random host name to avoid the TLS session resumption cache. + TargetHost = Guid.NewGuid().ToString(), + ApplicationProtocols = new() { SslApplicationProtocol.Http2, SslApplicationProtocol.Http11, }, + }); + Assert.Equal(default, stream.NegotiatedApplicationProtocol); + } + [ConditionalFact] [OSSkipCondition(OperatingSystems.MacOSX | OperatingSystems.Linux, SkipReason = "Not supported yet.")] public async Task CanRenegotiateForClientCertificateOnPostIfDrained()