Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[release/9.0.1xx] [Containers] Fix insecure registry handling to use the correct port for the HTTP protocol #44235

Merged
merged 4 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,16 @@ namespace Microsoft.NET.Build.Containers;
/// </summary>
internal sealed partial class FallbackToHttpMessageHandler : DelegatingHandler
{
private readonly string _registryName;
private readonly string _host;
private readonly int _port;
private readonly ILogger _logger;
private bool _fallbackToHttp;

public FallbackToHttpMessageHandler(string host, int port, HttpMessageHandler innerHandler, ILogger logger) : base(innerHandler)
public FallbackToHttpMessageHandler(string registryName, string host, int port, HttpMessageHandler innerHandler, ILogger logger)
: base(innerHandler)
{
_registryName = registryName;
_host = host;
_port = port;
_logger = logger;
Expand All @@ -38,7 +41,7 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
{
if (canFallback && _fallbackToHttp)
{
FallbackToHttp(request);
FallbackToHttp(_registryName, request);
canFallback = false;
}

Expand All @@ -51,7 +54,7 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
{
// Try falling back.
_logger.LogTrace("Attempt to fall back to http for {uri}.", uri);
FallbackToHttp(request);
FallbackToHttp(_registryName, request);
HttpResponseMessage response = await base.SendAsync(request, cancellationToken).ConfigureAwait(false);

// Fall back was successful. Use http for all new requests.
Expand All @@ -76,10 +79,22 @@ internal static bool ShouldAttemptFallbackToHttp(HttpRequestException exception)
return exception.HttpRequestError == HttpRequestError.SecureConnectionError;
}

private static void FallbackToHttp(HttpRequestMessage request)
private static bool RegistryNameContainsPort(string registryName)
{
// use `container` scheme which does not have a default port.
return new Uri($"container://{registryName}").Port != -1;
}

private static void FallbackToHttp(string registryName, HttpRequestMessage request)
{
var uriBuilder = new UriBuilder(request.RequestUri!);
uriBuilder.Scheme = "http";
if (RegistryNameContainsPort(registryName) == false)
{
// registeryName does not contains port number, so reset the port number to -1, otherwise it will be https default port 443
uriBuilder.Port = -1;
}

request.RequestUri = uriBuilder.Uri;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ internal DefaultRegistryAPI(string registryName, Uri baseUri, bool isInsecureReg

private static HttpClient CreateClient(string registryName, Uri baseUri, ILogger logger, bool isInsecureRegistry, RegistryMode mode)
{
HttpMessageHandler innerHandler = CreateHttpHandler(baseUri, isInsecureRegistry, logger);
HttpMessageHandler innerHandler = CreateHttpHandler(registryName, baseUri, isInsecureRegistry, logger);

HttpMessageHandler clientHandler = new AuthHandshakeMessageHandler(registryName, innerHandler, logger, mode);

Expand All @@ -56,7 +56,7 @@ private static HttpClient CreateClient(string registryName, Uri baseUri, ILogger
return client;
}

private static HttpMessageHandler CreateHttpHandler(Uri baseUri, bool allowInsecure, ILogger logger)
private static HttpMessageHandler CreateHttpHandler(string registryName, Uri baseUri, bool allowInsecure, ILogger logger)
{
var socketsHttpHandler = new SocketsHttpHandler()
{
Expand All @@ -75,7 +75,7 @@ private static HttpMessageHandler CreateHttpHandler(Uri baseUri, bool allowInsec
RemoteCertificateValidationCallback = IgnoreCertificateErrorsForSpecificHost(baseUri.Host)
};

return new FallbackToHttpMessageHandler(baseUri.Host, baseUri.Port, socketsHttpHandler, logger);
return new FallbackToHttpMessageHandler(registryName, baseUri.Host, baseUri.Port, socketsHttpHandler, logger);
}

private static RemoteCertificateValidationCallback IgnoreCertificateErrorsForSpecificHost(string host)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging.Abstractions;

namespace Microsoft.NET.Build.Containers.UnitTests
{
public class FallbackToHttpMessageHandlerTests
{
[Theory]
[InlineData("mcr.microsoft.com", 80)]
[InlineData("mcr.microsoft.com:443", 443)]
[InlineData("mcr.microsoft.com:80", 80)]
[InlineData("mcr.microsoft.com:5555", 5555)]
[InlineData("[2408:8120:245:49a0:f041:d7bb:bb13:5b64]", 80)]
[InlineData("[2408:8120:245:49a0:f041:d7bb:bb13:5b64]:443", 443)]
[InlineData("[2408:8120:245:49a0:f041:d7bb:bb13:5b64]:80", 80)]
[InlineData("[2408:8120:245:49a0:f041:d7bb:bb13:5b64]:5555", 5555)]
public async Task FallBackToHttpPortShouldAsExpected(string registry, int expectedPort)
{
var uri = new Uri($"https://{registry}");
var handler = new FallbackToHttpMessageHandler(
registry,
uri.Host,
uri.Port,
new ServerMessageHandler(request =>
{
// only accept http requests, reject https requests with a secure connection error

if (request.RequestUri!.Scheme == Uri.UriSchemeHttps)
{
throw new HttpRequestException(
httpRequestError: HttpRequestError.SecureConnectionError
);
}
else
{
return new HttpResponseMessage(HttpStatusCode.OK)
{
RequestMessage = request,
};
}
}),
NullLogger.Instance
);
using var httpClient = new HttpClient(handler);
var response = await httpClient.GetAsync(uri);
Assert.Equal(expectedPort, response.RequestMessage?.RequestUri?.Port);
}

private sealed class ServerMessageHandler : HttpMessageHandler
{
private readonly Func<HttpRequestMessage, HttpResponseMessage> _server;

public ServerMessageHandler(Func<HttpRequestMessage, HttpResponseMessage> server)
{
_server = server;
}

protected override Task<HttpResponseMessage> SendAsync(
HttpRequestMessage request,
CancellationToken cancellationToken
)
{
return Task.FromResult(_server(request));
}
}
}
}