Skip to content

Commit

Permalink
Remove received bytes length check (dotnet#100619)
Browse files Browse the repository at this point in the history
* Remove received bytes length check

* delete whitespace

* fix typo

* Improve tests

* Update UdpClientTest.cs

* Update UdpClientTest.cs

---------

Co-authored-by: Anton Firszov <[email protected]>
  • Loading branch information
2 people authored and Ruihan-Yin committed May 30, 2024
1 parent 9e0a9e5 commit f06c460
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -379,14 +379,7 @@ public byte[] EndReceive(IAsyncResult asyncResult, ref IPEndPoint? remoteEP)

// Because we don't return the actual length, we need to ensure the returned buffer
// has the appropriate length.
if (received < MaxUDPSize)
{
byte[] newBuffer = new byte[received];
Buffer.BlockCopy(_buffer, 0, newBuffer, 0, received);
return newBuffer;
}

return _buffer;
return _buffer.AsSpan(0, received).ToArray();
}

// Joins a multicast address group.
Expand Down Expand Up @@ -623,11 +616,7 @@ public Task<UdpReceiveResult> ReceiveAsync()
async Task<UdpReceiveResult> WaitAndWrap(Task<SocketReceiveFromResult> task)
{
SocketReceiveFromResult result = await task.ConfigureAwait(false);

byte[] buffer = result.ReceivedBytes < MaxUDPSize ?
_buffer.AsSpan(0, result.ReceivedBytes).ToArray() :
_buffer;

byte[] buffer = _buffer.AsSpan(0, result.ReceivedBytes).ToArray();
return new UdpReceiveResult(buffer, (IPEndPoint)result.RemoteEndPoint);
}
}
Expand All @@ -653,11 +642,7 @@ public ValueTask<UdpReceiveResult> ReceiveAsync(CancellationToken cancellationTo
async ValueTask<UdpReceiveResult> WaitAndWrap(ValueTask<SocketReceiveFromResult> task)
{
SocketReceiveFromResult result = await task.ConfigureAwait(false);

byte[] buffer = result.ReceivedBytes < MaxUDPSize ?
_buffer.AsSpan(0, result.ReceivedBytes).ToArray() :
_buffer;

byte[] buffer = _buffer.AsSpan(0, result.ReceivedBytes).ToArray();
return new UdpReceiveResult(buffer, (IPEndPoint)result.RemoteEndPoint);
}
}
Expand Down Expand Up @@ -845,14 +830,7 @@ public byte[] Receive([NotNull] ref IPEndPoint? remoteEP)

// because we don't return the actual length, we need to ensure the returned buffer
// has the appropriate length.

if (received < MaxUDPSize)
{
byte[] newBuffer = new byte[received];
Buffer.BlockCopy(_buffer, 0, newBuffer, 0, received);
return newBuffer;
}
return _buffer;
return _buffer.AsSpan(0, received).ToArray();
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -517,11 +518,12 @@ public void Send_Receive_Success(bool ipv4)
using (var receiver = new UdpClient(new IPEndPoint(address, 0)))
using (var sender = new UdpClient(new IPEndPoint(address, 0)))
{
sender.Send(new byte[1], 1, new IPEndPoint(address, ((IPEndPoint)receiver.Client.LocalEndPoint).Port));
AssertReceive(receiver);
byte[] data = [1, 2, 3];
sender.Send(data, 2, new IPEndPoint(address, ((IPEndPoint)receiver.Client.LocalEndPoint).Port));
AssertReceive(receiver, [1, 2]);

sender.Send(new ReadOnlySpan<byte>(new byte[1]), new IPEndPoint(address, ((IPEndPoint)receiver.Client.LocalEndPoint).Port));
AssertReceive(receiver);
sender.Send(new ReadOnlySpan<byte>(data), new IPEndPoint(address, ((IPEndPoint)receiver.Client.LocalEndPoint).Port));
AssertReceive(receiver, data);
}
}

Expand All @@ -536,11 +538,12 @@ public void Send_Receive_With_HostName_Success(bool ipv4)
using (var receiver = new UdpClient(new IPEndPoint(address, 0)))
using (var sender = new UdpClient(new IPEndPoint(address, 0)))
{
sender.Send(new byte[1], 1, "localhost", ((IPEndPoint)receiver.Client.LocalEndPoint).Port);
AssertReceive(receiver);
byte[] data = [1, 2, 3];
sender.Send(data, 2, "localhost", ((IPEndPoint)receiver.Client.LocalEndPoint).Port);
AssertReceive(receiver, [1, 2]);

sender.Send(new ReadOnlySpan<byte>(new byte[1]), "localhost", ((IPEndPoint)receiver.Client.LocalEndPoint).Port);
AssertReceive(receiver);
sender.Send(new ReadOnlySpan<byte>(data), "localhost", ((IPEndPoint)receiver.Client.LocalEndPoint).Port);
AssertReceive(receiver, data);
}
}

Expand All @@ -551,20 +554,22 @@ public void Send_Receive_Connected_Success()
using (var receiver = new UdpClient("localhost", 0))
using (var sender = new UdpClient("localhost", ((IPEndPoint)receiver.Client.LocalEndPoint).Port))
{
sender.Send(new byte[1], 1);
AssertReceive(receiver);
byte[] data = [1, 2, 3];

sender.Send(data, 2);
AssertReceive(receiver, [1, 2]);

sender.Send(new ReadOnlySpan<byte>(new byte[1]));
AssertReceive(receiver);
sender.Send(new ReadOnlySpan<byte>(data));
AssertReceive(receiver, data);
}
}

private static void AssertReceive(UdpClient receiver)
private static void AssertReceive(UdpClient receiver, byte[] sentData)
{
IPEndPoint remoteEP = null;
byte[] data = receiver.Receive(ref remoteEP);
Assert.NotNull(remoteEP);
Assert.InRange(data.Length, 1, int.MaxValue);
Assert.True(Enumerable.SequenceEqual(sentData, data));
}

[Theory]
Expand All @@ -589,32 +594,35 @@ public void Send_Available_Success(bool ipv4)
public void BeginEndSend_BeginEndReceive_Success(bool ipv4)
{
IPAddress address = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback;
byte[] data = [1, 2, 3];

using (var receiver = new UdpClient(new IPEndPoint(address, 0)))
using (var sender = new UdpClient(new IPEndPoint(address, 0)))
{
sender.EndSend(sender.BeginSend(new byte[1], 1, new IPEndPoint(address, ((IPEndPoint)receiver.Client.LocalEndPoint).Port), null, null));
sender.EndSend(sender.BeginSend(data, 2, new IPEndPoint(address, ((IPEndPoint)receiver.Client.LocalEndPoint).Port), null, null));

IPEndPoint remoteEP = null;
byte[] data = receiver.EndReceive(receiver.BeginReceive(null, null), ref remoteEP);
byte[] receivedData = receiver.EndReceive(receiver.BeginReceive(null, null), ref remoteEP);
Assert.NotNull(remoteEP);
Assert.InRange(data.Length, 1, int.MaxValue);
Assert.True(Enumerable.SequenceEqual(receivedData, new byte[] {1, 2}));
}
}

[Fact]
[PlatformSpecific(TestPlatforms.Windows)] // "localhost" resolves to IPv4 & IPV6 on Windows, but may resolve to only one of those on Unix
public void BeginEndSend_BeginEndReceive_Connected_Success()
{
byte[] data = [1, 2, 3];

using (var receiver = new UdpClient("localhost", 0))
using (var sender = new UdpClient("localhost", ((IPEndPoint)receiver.Client.LocalEndPoint).Port))
{
sender.EndSend(sender.BeginSend(new byte[1], 1, null, null));
sender.EndSend(sender.BeginSend(data, 2, null, null));

IPEndPoint remoteEP = null;
byte[] data = receiver.EndReceive(receiver.BeginReceive(null, null), ref remoteEP);
byte[] receivedData = receiver.EndReceive(receiver.BeginReceive(null, null), ref remoteEP);
Assert.NotNull(remoteEP);
Assert.InRange(data.Length, 1, int.MaxValue);
Assert.True(Enumerable.SequenceEqual(receivedData, new byte[] {1, 2}));
}
}

Expand All @@ -624,15 +632,16 @@ public void BeginEndSend_BeginEndReceive_Connected_Success()
public async Task SendAsync_ReceiveAsync_Success(bool ipv4)
{
IPAddress address = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback;
byte[] data = [1, 2, 3];

using (var receiver = new UdpClient(new IPEndPoint(address, 0)))
using (var sender = new UdpClient(new IPEndPoint(address, 0)))
{
await sender.SendAsync(new byte[1], 1, new IPEndPoint(address, ((IPEndPoint)receiver.Client.LocalEndPoint).Port));
await AssertReceiveAsync(receiver);
await sender.SendAsync(data, 2, new IPEndPoint(address, ((IPEndPoint)receiver.Client.LocalEndPoint).Port));
await AssertReceiveAsync(receiver, [1, 2]);

await sender.SendAsync(new ReadOnlyMemory<byte>(new byte[1]), new IPEndPoint(address, ((IPEndPoint)receiver.Client.LocalEndPoint).Port));
await AssertReceiveAsync(receiver);
await sender.SendAsync(new ReadOnlyMemory<byte>(data), new IPEndPoint(address, ((IPEndPoint)receiver.Client.LocalEndPoint).Port));
await AssertReceiveAsync(receiver, data);
}
}

Expand All @@ -643,15 +652,16 @@ public async Task SendAsync_ReceiveAsync_Success(bool ipv4)
public async Task SendAsync_ReceiveAsync_With_HostName_Success(bool ipv4)
{
IPAddress address = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback;
byte[] data = [1, 2, 3];

using (var receiver = new UdpClient(new IPEndPoint(address, 0)))
using (var sender = new UdpClient(new IPEndPoint(address, 0)))
{
await sender.SendAsync(new byte[1], "localhost", ((IPEndPoint)receiver.Client.LocalEndPoint).Port);
await AssertReceiveAsync(receiver);
await sender.SendAsync(data, "localhost", ((IPEndPoint)receiver.Client.LocalEndPoint).Port);
await AssertReceiveAsync(receiver, data);

await sender.SendAsync(new ReadOnlyMemory<byte>(new byte[1]), "localhost", ((IPEndPoint)receiver.Client.LocalEndPoint).Port);
await AssertReceiveAsync(receiver);
await sender.SendAsync(new ReadOnlyMemory<byte>(data), "localhost", ((IPEndPoint)receiver.Client.LocalEndPoint).Port);
await AssertReceiveAsync(receiver, data);
}
}

Expand All @@ -675,29 +685,30 @@ public async Task ReceiveAsync_Cancel_Throw(bool ipv4)
[PlatformSpecific(TestPlatforms.Windows)] // "localhost" resolves to IPv4 & IPV6 on Windows, but may resolve to only one of those on Unix
public async Task SendAsync_ReceiveAsync_Connected_Success()
{
byte[] data = [1, 2, 3];

using (var receiver = new UdpClient("localhost", 0))
using (var sender = new UdpClient("localhost", ((IPEndPoint)receiver.Client.LocalEndPoint).Port))
{
await sender.SendAsync(new byte[1], 1);
await AssertReceiveAsync(receiver);
await sender.SendAsync(data, 2);
await AssertReceiveAsync(receiver, [1, 2]);

await sender.SendAsync(new ReadOnlyMemory<byte>(new byte[1]));
await AssertReceiveAsync(receiver);
await sender.SendAsync(new ReadOnlyMemory<byte>(data));
await AssertReceiveAsync(receiver, data);

await sender.SendAsync(new ReadOnlyMemory<byte>(new byte[1]), null);
await AssertReceiveAsync(receiver);
await sender.SendAsync(new ReadOnlyMemory<byte>(data), null);
await AssertReceiveAsync(receiver, data);

await sender.SendAsync(new ReadOnlyMemory<byte>(new byte[1]), null, 0);
await AssertReceiveAsync(receiver);
await sender.SendAsync(new ReadOnlyMemory<byte>(data), null, 0);
await AssertReceiveAsync(receiver, data);
}
}

private static async Task AssertReceiveAsync(UdpClient receiver)
private static async Task AssertReceiveAsync(UdpClient receiver, byte[] sentData)
{
UdpReceiveResult result = await receiver.ReceiveAsync();
Assert.NotNull(result.RemoteEndPoint);
Assert.NotNull(result.Buffer);
Assert.InRange(result.Buffer.Length, 1, int.MaxValue);
Assert.True(Enumerable.SequenceEqual(sentData, result.Buffer));
}

[Fact]
Expand Down

0 comments on commit f06c460

Please sign in to comment.