Skip to content

Commit

Permalink
fix ReceiveMessageFrom cancellation on Windows
Browse files Browse the repository at this point in the history
  • Loading branch information
antonfirsov committed Jan 19, 2021
1 parent aa1bf42 commit 4e4add1
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,6 @@ public ValueTask<SocketReceiveMessageFromResult> ReceiveMessageFromAsync(Memory<
{
throw new InvalidOperationException(SR.net_sockets_mustbind);
}

if (cancellationToken.IsCancellationRequested)
{
return ValueTask.FromCanceled<SocketReceiveMessageFromResult>(cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -558,25 +558,33 @@ internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeSoc
_wsaRecvMsgWSABufferArrayPinned = GC.AllocateUninitializedArray<WSABuffer>(1, pinned: true);
}

Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None);
_singleBufferHandle = _buffer.Pin();
_singleBufferHandleState = SingleBufferHandleState.Set;
fixed (byte* bufferPtr = &MemoryMarshal.GetReference(_buffer.Span))
{
Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None);
_singleBufferHandleState = SingleBufferHandleState.InProcess;

_wsaRecvMsgWSABufferArrayPinned[0].Pointer = (IntPtr)bufferPtr;
_wsaRecvMsgWSABufferArrayPinned[0].Length = _count;
wsaRecvMsgWSABufferArray = _wsaRecvMsgWSABufferArrayPinned;
wsaRecvMsgWSABufferCount = 1;

_wsaRecvMsgWSABufferArrayPinned[0].Pointer = (IntPtr)_singleBufferHandle.Pointer;
_wsaRecvMsgWSABufferArrayPinned[0].Length = _count;
wsaRecvMsgWSABufferArray = _wsaRecvMsgWSABufferArrayPinned;
wsaRecvMsgWSABufferCount = 1;
return Core();
}
}
else
{
// Use the multi-buffer WSABuffer.
wsaRecvMsgWSABufferArray = _wsaBufferArrayPinned!;
wsaRecvMsgWSABufferCount = (uint)_bufferListInternal!.Count;

return Core();
}

// Fill in WSAMessageBuffer.
unsafe
// Fill in WSAMessageBuffer, run WSARecvMsg and process the IOCP result.
// Logic is in a separate method so we can share code between the (pinned) single buffer and the multi-buffer case
SocketError Core()
{
// Fill in WSAMessageBuffer.
Interop.Winsock.WSAMsg* pMessage = (Interop.Winsock.WSAMsg*)Marshal.UnsafeAddrOfPinnedArrayElement(_wsaMessageBufferPinned, 0);
pMessage->socketAddress = PtrSocketAddressBuffer;
pMessage->addressLength = (uint)_socketAddress.Size;
Expand All @@ -596,26 +604,26 @@ internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeSoc
pMessage->controlBuffer.Length = _controlBufferPinned.Length;
}
pMessage->flags = _socketFlags;
}

NativeOverlapped* overlapped = AllocateNativeOverlapped();
try
{
SocketError socketError = socket.WSARecvMsg(
handle,
Marshal.UnsafeAddrOfPinnedArrayElement(_wsaMessageBufferPinned, 0),
out int bytesTransferred,
overlapped,
IntPtr.Zero);
NativeOverlapped* overlapped = AllocateNativeOverlapped();
try
{
SocketError socketError = socket.WSARecvMsg(
handle,
Marshal.UnsafeAddrOfPinnedArrayElement(_wsaMessageBufferPinned, 0),
out int bytesTransferred,
overlapped,
IntPtr.Zero);

return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped, cancellationToken);
}
catch
{
_singleBufferHandleState = SingleBufferHandleState.None;
FreeNativeOverlapped(overlapped);
_singleBufferHandle.Dispose();
throw;
return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped, cancellationToken);
}
catch
{
_singleBufferHandleState = SingleBufferHandleState.None;
FreeNativeOverlapped(overlapped);
_singleBufferHandle.Dispose();
throw;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,17 @@ public ReceiveFrom_CancellableTask(ITestOutputHelper output) : base(output) { }
public async Task WhenCanceled_Throws(IPAddress loopback, bool precanceled)
{
using var socket = new Socket(loopback.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
using var dummy = new Socket(loopback.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
socket.BindToAnonymousPort(loopback);
dummy.BindToAnonymousPort(loopback);
Memory<byte> buffer = new byte[1];

CancellationTokenSource cts = new CancellationTokenSource();
if (precanceled) cts.Cancel();
else cts.CancelAfter(100);

OperationCanceledException ex = await Assert.ThrowsAnyAsync<OperationCanceledException>(
() => socket.ReceiveFromAsync(buffer, SocketFlags.None, GetGetDummyTestEndpoint(loopback.AddressFamily), cts.Token).AsTask())
() => socket.ReceiveFromAsync(buffer, SocketFlags.None, dummy.LocalEndPoint, cts.Token).AsTask())
.TimeoutAfter(10_000);
Assert.Equal(cts.Token, ex.CancellationToken);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
Expand Down Expand Up @@ -93,21 +94,19 @@ public ReceiveMessageFrom_CancellableTask(ITestOutputHelper output) : base(outpu
public async Task WhenCanceled_Throws(IPAddress loopback, bool precanceled)
{
using var socket = new Socket(loopback.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
using var dummy = new Socket(loopback.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
socket.BindToAnonymousPort(loopback);
dummy.BindToAnonymousPort(loopback);
Memory<byte> buffer = new byte[1];

CancellationTokenSource cts = new CancellationTokenSource();
if (precanceled) cts.Cancel();
else cts.CancelAfter(100);

OperationCanceledException ex = await Assert.ThrowsAnyAsync<OperationCanceledException>(
() => socket.ReceiveMessageFromAsync(buffer, SocketFlags.None, GetGetDummyTestEndpoint(loopback.AddressFamily), cts.Token).AsTask())
() => socket.ReceiveMessageFromAsync(buffer, SocketFlags.None, dummy.LocalEndPoint, cts.Token).AsTask())
.TimeoutAfter(10_000);
Assert.Equal(cts.Token, ex.CancellationToken);

IPEndPoint GetGetDummyTestEndpoint(AddressFamily addressFamily = AddressFamily.InterNetwork) =>
addressFamily == AddressFamily.InterNetwork ?
new IPEndPoint(IPAddress.Parse("1.2.3.4"), 1234) : new IPEndPoint(IPAddress.Parse("1:2:3::4"), 1234);
}
}

Expand All @@ -124,7 +123,7 @@ public ReceiveMessageFrom_Eap(ITestOutputHelper output) : base(output) { }
[InlineData(true, 2)]
public void ReceiveSentMessages_ReuseEventArgs_Success(bool ipv4, int bufferMode)
{
const int DatagramsToSend = 30;
const int DatagramsToSend = 5;
const int TimeoutMs = 30_000;

AddressFamily family;
Expand Down Expand Up @@ -156,34 +155,40 @@ public void ReceiveSentMessages_ReuseEventArgs_Success(bool ipv4, int bufferMode
sender.Bind(new IPEndPoint(loopback, 0));
saea.RemoteEndPoint = new IPEndPoint(any, 0);

Random random = new Random(0);
byte[] sendBuffer = new byte[1024];
random.NextBytes(sendBuffer);

for (int i = 0; i < DatagramsToSend; i++)
{
byte[] receiveBuffer = new byte[1024];
switch (bufferMode)
{
case 0: // single buffer
saea.SetBuffer(new byte[1024], 0, 1024);
saea.SetBuffer(receiveBuffer, 0, 1024);
break;
case 1: // single buffer in buffer list
saea.BufferList = new List<ArraySegment<byte>>
{
new ArraySegment<byte>(new byte[1024])
new ArraySegment<byte>(receiveBuffer)
};
break;
case 2: // multiple buffers in buffer list
saea.BufferList = new List<ArraySegment<byte>>
{
new ArraySegment<byte>(new byte[512]),
new ArraySegment<byte>(new byte[512])
new ArraySegment<byte>(receiveBuffer, 0, 512),
new ArraySegment<byte>(receiveBuffer, 512, 512)
};
break;
}

bool pending = receiver.ReceiveMessageFromAsync(saea);
sender.SendTo(new byte[1024], new IPEndPoint(loopback, port));
sender.SendTo(sendBuffer, new IPEndPoint(loopback, port));
if (pending) Assert.True(completed.Wait(TimeoutMs), "Expected operation to complete within timeout");
completed.Reset();

Assert.Equal(1024, saea.BytesTransferred);
AssertExtensions.SequenceEqual(sendBuffer, receiveBuffer);
Assert.Equal(sender.LocalEndPoint, saea.RemoteEndPoint);
Assert.Equal(((IPEndPoint)sender.LocalEndPoint).Address, saea.ReceiveMessageFromPacketInfo.Address);
}
Expand Down

0 comments on commit 4e4add1

Please sign in to comment.