diff --git a/src/Containers/Microsoft.NET.Build.Containers/FallbackToHttpMessageHandler.cs b/src/Containers/Microsoft.NET.Build.Containers/FallbackToHttpMessageHandler.cs index 5f7754d58bb1..7fcb85735605 100644 --- a/src/Containers/Microsoft.NET.Build.Containers/FallbackToHttpMessageHandler.cs +++ b/src/Containers/Microsoft.NET.Build.Containers/FallbackToHttpMessageHandler.cs @@ -12,13 +12,16 @@ namespace Microsoft.NET.Build.Containers; /// internal sealed partial class FallbackToHttpMessageHandler : DelegatingHandler { + private readonly string _registryName; private readonly string _host; private readonly int _port; private readonly ILogger _logger; private bool _fallbackToHttp; - public FallbackToHttpMessageHandler(string host, int port, HttpMessageHandler innerHandler, ILogger logger) : base(innerHandler) + public FallbackToHttpMessageHandler(string registryName, string host, int port, HttpMessageHandler innerHandler, ILogger logger) + : base(innerHandler) { + _registryName = registryName; _host = host; _port = port; _logger = logger; @@ -38,7 +41,7 @@ protected override async Task SendAsync(HttpRequestMessage { if (canFallback && _fallbackToHttp) { - FallbackToHttp(request); + FallbackToHttp(_registryName, request); canFallback = false; } @@ -51,7 +54,7 @@ protected override async Task SendAsync(HttpRequestMessage { // Try falling back. _logger.LogTrace("Attempt to fall back to http for {uri}.", uri); - FallbackToHttp(request); + FallbackToHttp(_registryName, request); HttpResponseMessage response = await base.SendAsync(request, cancellationToken).ConfigureAwait(false); // Fall back was successful. Use http for all new requests. @@ -76,10 +79,22 @@ internal static bool ShouldAttemptFallbackToHttp(HttpRequestException exception) return exception.HttpRequestError == HttpRequestError.SecureConnectionError; } - private static void FallbackToHttp(HttpRequestMessage request) + private static bool RegistryNameContainsPort(string registryName) + { + // use `container` scheme which does not have a default port. + return new Uri($"container://{registryName}").Port != -1; + } + + private static void FallbackToHttp(string registryName, HttpRequestMessage request) { var uriBuilder = new UriBuilder(request.RequestUri!); uriBuilder.Scheme = "http"; + if (RegistryNameContainsPort(registryName) == false) + { + // registeryName does not contains port number, so reset the port number to -1, otherwise it will be https default port 443 + uriBuilder.Port = -1; + } + request.RequestUri = uriBuilder.Uri; } } diff --git a/src/Containers/Microsoft.NET.Build.Containers/Registry/DefaultRegistryAPI.cs b/src/Containers/Microsoft.NET.Build.Containers/Registry/DefaultRegistryAPI.cs index dff5a8921cc2..7a5ae043e319 100644 --- a/src/Containers/Microsoft.NET.Build.Containers/Registry/DefaultRegistryAPI.cs +++ b/src/Containers/Microsoft.NET.Build.Containers/Registry/DefaultRegistryAPI.cs @@ -37,7 +37,7 @@ internal DefaultRegistryAPI(string registryName, Uri baseUri, bool isInsecureReg private static HttpClient CreateClient(string registryName, Uri baseUri, ILogger logger, bool isInsecureRegistry, RegistryMode mode) { - HttpMessageHandler innerHandler = CreateHttpHandler(baseUri, isInsecureRegistry, logger); + HttpMessageHandler innerHandler = CreateHttpHandler(registryName, baseUri, isInsecureRegistry, logger); HttpMessageHandler clientHandler = new AuthHandshakeMessageHandler(registryName, innerHandler, logger, mode); @@ -56,7 +56,7 @@ private static HttpClient CreateClient(string registryName, Uri baseUri, ILogger return client; } - private static HttpMessageHandler CreateHttpHandler(Uri baseUri, bool allowInsecure, ILogger logger) + private static HttpMessageHandler CreateHttpHandler(string registryName, Uri baseUri, bool allowInsecure, ILogger logger) { var socketsHttpHandler = new SocketsHttpHandler() { @@ -75,7 +75,7 @@ private static HttpMessageHandler CreateHttpHandler(Uri baseUri, bool allowInsec RemoteCertificateValidationCallback = IgnoreCertificateErrorsForSpecificHost(baseUri.Host) }; - return new FallbackToHttpMessageHandler(baseUri.Host, baseUri.Port, socketsHttpHandler, logger); + return new FallbackToHttpMessageHandler(registryName, baseUri.Host, baseUri.Port, socketsHttpHandler, logger); } private static RemoteCertificateValidationCallback IgnoreCertificateErrorsForSpecificHost(string host) diff --git a/test/Microsoft.NET.Build.Containers.UnitTests/FallbackToHttpMessageHandlerTests.cs b/test/Microsoft.NET.Build.Containers.UnitTests/FallbackToHttpMessageHandlerTests.cs new file mode 100644 index 000000000000..0c6933f94473 --- /dev/null +++ b/test/Microsoft.NET.Build.Containers.UnitTests/FallbackToHttpMessageHandlerTests.cs @@ -0,0 +1,75 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Microsoft.NET.Build.Containers.UnitTests +{ + public class FallbackToHttpMessageHandlerTests + { + [Theory] + [InlineData("mcr.microsoft.com", 80)] + [InlineData("mcr.microsoft.com:443", 443)] + [InlineData("mcr.microsoft.com:80", 80)] + [InlineData("mcr.microsoft.com:5555", 5555)] + [InlineData("[2408:8120:245:49a0:f041:d7bb:bb13:5b64]", 80)] + [InlineData("[2408:8120:245:49a0:f041:d7bb:bb13:5b64]:443", 443)] + [InlineData("[2408:8120:245:49a0:f041:d7bb:bb13:5b64]:80", 80)] + [InlineData("[2408:8120:245:49a0:f041:d7bb:bb13:5b64]:5555", 5555)] + public async Task FallBackToHttpPortShouldAsExpected(string registry, int expectedPort) + { + var uri = new Uri($"https://{registry}"); + var handler = new FallbackToHttpMessageHandler( + registry, + uri.Host, + uri.Port, + new ServerMessageHandler(request => + { + // only accept http requests, reject https requests with a secure connection error + + if (request.RequestUri!.Scheme == Uri.UriSchemeHttps) + { + throw new HttpRequestException( + httpRequestError: HttpRequestError.SecureConnectionError + ); + } + else + { + return new HttpResponseMessage(HttpStatusCode.OK) + { + RequestMessage = request, + }; + } + }), + NullLogger.Instance + ); + using var httpClient = new HttpClient(handler); + var response = await httpClient.GetAsync(uri); + Assert.Equal(expectedPort, response.RequestMessage?.RequestUri?.Port); + } + + private sealed class ServerMessageHandler : HttpMessageHandler + { + private readonly Func _server; + + public ServerMessageHandler(Func server) + { + _server = server; + } + + protected override Task SendAsync( + HttpRequestMessage request, + CancellationToken cancellationToken + ) + { + return Task.FromResult(_server(request)); + } + } + } +}