Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Harden and re-enable DualMode socket tests #80715

Merged
merged 2 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 127 additions & 2 deletions src/libraries/Common/tests/System/Net/Sockets/SocketTestExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,16 @@ public static (Socket client, Socket server) CreateConnectedSocketPair(bool ipv6
{
IPAddress serverAddress = ipv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback;

using Socket listener = new Socket(serverAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
listener.Bind(new IPEndPoint(serverAddress, 0));
// PortBlocker creates a temporary socket of the opposite AddressFamily in the background, so parallel tests won't attempt
// to create their listener sockets on the same port, regardless of address family.
// This should prevent 'listener' from accepting DualMode connections of unrelated tests.
using PortBlocker portBlocker = new PortBlocker(() =>
{
Socket l = new Socket(serverAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
l.BindToAnonymousPort(serverAddress);
return l;
});
Socket listener = portBlocker.MainSocket; // PortBlocker shall dispose this
listener.Listen(1);

IPEndPoint connectTo = (IPEndPoint)listener.LocalEndPoint;
Expand Down Expand Up @@ -84,4 +92,121 @@ public static bool TryConnect(this Socket socket, EndPoint remoteEndpoint, int m
return false;
}
}

/// <summary>
/// A utility to create and bind a socket while blocking it's port for both IPv4 and IPv6
/// by also creating and binding a "shadow" socket of the opposite address family.
/// </summary>
internal class PortBlocker : IDisposable
{
private const int MaxAttempts = 16;
private Socket _shadowSocket;
public Socket MainSocket { get; }

public PortBlocker(Func<Socket> socketFactory)
{
bool success = false;
for (int i = 0; i < MaxAttempts; i++)
{
MainSocket = socketFactory();
if (MainSocket.LocalEndPoint is not IPEndPoint)
{
MainSocket.Dispose();
throw new Exception($"{nameof(socketFactory)} is expected create and bind the socket.");
}

IPAddress shadowAddress = MainSocket.AddressFamily == AddressFamily.InterNetwork ?
IPAddress.IPv6Loopback :
IPAddress.Loopback;
int port = ((IPEndPoint)MainSocket.LocalEndPoint).Port;
IPEndPoint shadowEndPoint = new IPEndPoint(shadowAddress, port);

try
{
_shadowSocket = new Socket(shadowAddress.AddressFamily, MainSocket.SocketType, MainSocket.ProtocolType);
success = TryBindWithoutReuseAddress(_shadowSocket, shadowEndPoint, out _);

if (success) break;
}
catch (SocketException)
{
MainSocket.Dispose();
_shadowSocket?.Dispose();
}
}

if (!success)
{
throw new Exception($"Failed to create the 'shadow' (port blocker) socket in {MaxAttempts} attempts.");
}
}

public void Dispose()
{
MainSocket.Dispose();
_shadowSocket.Dispose();
}

// Socket.Bind() auto-enables SO_REUSEADDR on Unix to allow Bind() during TIME_WAIT to emulate Windows behavior, see SystemNative_Bind() in 'pal_networking.c'.
// To prevent other sockets from succesfully binding to the same port port, we need to avoid this logic when binding the shadow socket.
// This method is doing a custom P/Invoke to bind() on Unix to achieve that.
private static unsafe bool TryBindWithoutReuseAddress(Socket socket, IPEndPoint endPoint, out int port)
{
if (PlatformDetection.IsWindows)
{
try
{
socket.Bind(endPoint);
}
catch (SocketException)
{
port = default;
return false;
}

port = ((IPEndPoint)socket.LocalEndPoint).Port;
return true;
}

SocketAddress addr = endPoint.Serialize();
byte[] data = new byte[addr.Size];
for (int i = 0; i < data.Length; i++)
{
data[i] = addr[i];
}

fixed (byte* dataPtr = data)
{
int result = bind(socket.SafeHandle, (nint)dataPtr, (uint)data.Length);
if (result != 0)
{
port = default;
return false;
}
uint sockLen = (uint)data.Length;
result = getsockname(socket.SafeHandle, (nint)dataPtr, (IntPtr)(&sockLen));
if (result != 0)
{
port = default;
return false;
}

addr = new SocketAddress(endPoint.AddressFamily, (int)sockLen);
}

for (int i = 0; i < data.Length; i++)
{
addr[i] = data[i];
}

port = ((IPEndPoint)endPoint.Create(addr)).Port;
return true;

[Runtime.InteropServices.DllImport("libc", SetLastError = true)]
static extern int bind(SafeSocketHandle socket, IntPtr socketAddress, uint addrLen);

[Runtime.InteropServices.DllImport("libc", SetLastError = true)]
static extern int getsockname(SafeSocketHandle socket, IntPtr socketAddress, IntPtr addrLenPtr);
}
}
}
45 changes: 22 additions & 23 deletions src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,28 @@ await RetryHelper.ExecuteAsync(async () =>
}
}, maxAttempts: 10, retryWhen: e => e is XunitException);
}

[OuterLoop("Connection failure takes long on Windows.")]
[Fact]
public async Task Connect_WithoutListener_ThrowSocketExceptionWithAppropriateInfo()
{
using PortBlocker portBlocker = new PortBlocker(() =>
{
Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
socket.BindToAnonymousPort(IPAddress.Loopback);
return socket;
});
Socket a = portBlocker.MainSocket;
using Socket b = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);

SocketException ex = await Assert.ThrowsAsync<SocketException>(() => ConnectAsync(b, a.LocalEndPoint));
Assert.Contains(Marshal.GetPInvokeErrorMessage(ex.NativeErrorCode), ex.Message);

if (UsesSync)
{
Assert.Contains(a.LocalEndPoint.ToString(), ex.Message);
}
}
}

public sealed class ConnectSync : Connect<SocketHelperArraySync>
Expand All @@ -215,29 +237,6 @@ public ConnectApm(ITestOutputHelper output) : base(output) {}
public sealed class ConnectTask : Connect<SocketHelperTask>
{
public ConnectTask(ITestOutputHelper output) : base(output) {}

[OuterLoop]
[Fact]
[ActiveIssue("https://github.com/dotnet/runtime/issues/79820", TestPlatforms.Linux | TestPlatforms.Android)]
public static void Connect_ThrowSocketException_Success()
{
using (Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
int anonymousPort = socket.BindToAnonymousPort(IPAddress.Loopback);
IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, anonymousPort);
Assert.ThrowsAsync<SocketException>(() => socket.ConnectAsync(ep));
try
{
socket.Connect(ep);
Assert.Fail("Socket Connect should throw SocketException in this case.");
}
catch (SocketException ex)
{
Assert.Contains(Marshal.GetPInvokeErrorMessage(ex.NativeErrorCode), ex.Message);
Assert.Contains(ep.ToString(), ex.Message);
}
}
}
}

public sealed class ConnectEap : Connect<SocketHelperEap>
Expand Down
Loading