Skip to content

Commit

Permalink
quic test improvements (#56043)
Browse files Browse the repository at this point in the history
* quic test improvements

* fix incorrect use of PassingTestTimeout

* feedback from review
  • Loading branch information
wfurt authored Jul 27, 2021
1 parent 4143ac3 commit 66d23a7
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<TargetFrameworks>$(NetCoreAppCurrent)</TargetFrameworks>
</PropertyGroup>
<ItemGroup>
<Compile Include="System\IO\*" />
<Compile Include="System\IO\*.cs" />
</ItemGroup>
<ItemGroup>
<Compile Include="$(CommonTestPath)System\IO\ConnectedStreams.cs" Link="System\IO\ConnectedStreams.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ public static class TaskTimeoutExtensions
#region WaitAsync polyfills
// Test polyfills when targeting a platform that doesn't have these ConfigureAwait overloads on Task

public static Task WaitAsync(this Task task, int millisecondsTimeout) =>
WaitAsync(task, TimeSpan.FromMilliseconds(millisecondsTimeout), default);

public static Task WaitAsync(this Task task, TimeSpan timeout) =>
WaitAsync(task, timeout, default);

Expand All @@ -28,6 +31,9 @@ public async static Task WaitAsync(this Task task, TimeSpan timeout, Cancellatio
}
}

public static Task<TResult> WaitAsync<TResult>(this Task<TResult> task, int millisecondsTimeout) =>
WaitAsync(task, TimeSpan.FromMilliseconds(millisecondsTimeout), default);

public static Task<TResult> WaitAsync<TResult>(this Task<TResult> task, TimeSpan timeout) =>
WaitAsync(task, timeout, default);

Expand All @@ -48,6 +54,9 @@ public static async Task<TResult> WaitAsync<TResult>(this Task<TResult> task, Ti
public static async Task WhenAllOrAnyFailed(this Task[] tasks, int millisecondsTimeout) =>
await tasks.WhenAllOrAnyFailed().WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout));

public static async Task WhenAllOrAnyFailed(Task t1, Task t2, int millisecondsTimeout) =>
await new Task[] {t1, t2}.WhenAllOrAnyFailed(millisecondsTimeout);

public static async Task WhenAllOrAnyFailed(this Task[] tasks)
{
try
Expand Down
74 changes: 34 additions & 40 deletions src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,19 @@ namespace System.Net.Quic.Tests
[ConditionalClass(typeof(QuicTestBase<MsQuicProviderFactory>), nameof(IsSupported))]
public class MsQuicTests : QuicTestBase<MsQuicProviderFactory>
{
readonly ITestOutputHelper _output;
private static ReadOnlyMemory<byte> s_data = Encoding.UTF8.GetBytes("Hello world!");

public MsQuicTests(ITestOutputHelper output)
{
_output = output;
}
public MsQuicTests(ITestOutputHelper output) : base(output) { }

[Fact]
public async Task UnidirectionalAndBidirectionalStreamCountsWork()
{
using QuicListener listener = CreateQuicListener();
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;

ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Assert.Equal(100, serverConnection.GetRemoteAvailableBidirectionalStreamCount());
Assert.Equal(100, serverConnection.GetRemoteAvailableUnidirectionalStreamCount());
}
Expand All @@ -55,10 +51,10 @@ public async Task UnidirectionalAndBidirectionalChangeValues()
};

using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;

ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Assert.Equal(100, clientConnection.GetRemoteAvailableBidirectionalStreamCount());
Assert.Equal(100, clientConnection.GetRemoteAvailableUnidirectionalStreamCount());
Assert.Equal(10, serverConnection.GetRemoteAvailableBidirectionalStreamCount());
Expand Down Expand Up @@ -112,10 +108,9 @@ public async Task ConnectWithCertificateChain()
};

using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
ValueTask clientTask = clientConnection.ConnectAsync();

using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;
}

[Fact]
Expand Down Expand Up @@ -342,10 +337,10 @@ public async Task ConnectWithClientCertificate()
clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate };

using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions);
ValueTask clientTask = clientConnection.ConnectAsync();
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;

using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
// Verify functionality of the connections.
await PingPong(clientConnection, serverConnection);
// check we completed the client certificate verification.
Expand All @@ -359,10 +354,9 @@ public async Task WaitForAvailableUnidirectionStreamsAsyncWorks()
{
using QuicListener listener = CreateQuicListener(maxUnidirectionalStreams: 1);
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);

ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;
listener.Dispose();

// No stream opened yet, should return immediately.
Expand All @@ -387,9 +381,9 @@ public async Task WaitForAvailableBidirectionStreamsAsyncWorks()
using QuicListener listener = CreateQuicListener(maxBidirectionalStreams: 1);
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);

ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;

// No stream opened yet, should return immediately.
Assert.True(clientConnection.WaitForAvailableBidirectionalStreamsAsync().IsCompletedSuccessfully);
Expand Down Expand Up @@ -425,16 +419,15 @@ public async Task SetListenerTimeoutWorksWithSmallTimeout()
};

using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;

await Assert.ThrowsAsync<QuicOperationAbortedException>(async () => await serverConnection.AcceptStreamAsync().AsTask().WaitAsync(TimeSpan.FromSeconds(100)));
}

[Theory]
[MemberData(nameof(WriteData))]
[ActiveIssue("https://github.com/dotnet/runtime/issues/49157")]
public async Task WriteTests(int[][] writes, WriteType writeType)
{
await RunClientServer(
Expand Down Expand Up @@ -530,9 +523,10 @@ public async Task CallDifferentWriteMethodsWorks()
using QuicListener listener = CreateQuicListener();
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);

ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;


ReadOnlyMemory<byte> helloWorld = Encoding.ASCII.GetBytes("Hello world!");
ReadOnlySequence<byte> ros = CreateReadOnlySequenceFromBytes(helloWorld.ToArray());
Expand Down Expand Up @@ -714,9 +708,9 @@ async Task GetStreamIdWithoutStartWorks()
using QuicListener listener = CreateQuicListener();
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);

ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;

using QuicStream clientStream = clientConnection.OpenBidirectionalStream();
Assert.Equal(0, clientStream.StreamId);
Expand All @@ -737,9 +731,9 @@ async Task GetStreamIdWithoutStartWorks()
using QuicListener listener = CreateQuicListener();
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);

ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;

using QuicStream clientStream = clientConnection.OpenBidirectionalStream();
Assert.Equal(0, clientStream.StreamId);
Expand Down Expand Up @@ -781,7 +775,7 @@ await Task.Run(async () =>
byte[] buffer = new byte[100];
QuicConnectionAbortedException ex = await Assert.ThrowsAsync<QuicConnectionAbortedException>(() => serverStream.ReadAsync(buffer).AsTask());
Assert.Equal(ExpectedErrorCode, ex.ErrorCode);
}).WaitAsync(TimeSpan.FromSeconds(5));
}).WaitAsync(TimeSpan.FromMilliseconds(PassingTestTimeoutMilliseconds));
}

[Fact]
Expand All @@ -807,7 +801,7 @@ await Task.Run(async () =>

byte[] buffer = new byte[100];
await Assert.ThrowsAsync<QuicOperationAbortedException>(() => serverStream.ReadAsync(buffer).AsTask());
}).WaitAsync(TimeSpan.FromSeconds(5));
}).WaitAsync(TimeSpan.FromMilliseconds(PassingTestTimeoutMilliseconds));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ public abstract class QuicConnectionTests<T> : QuicTestBase<T>
{
const int ExpectedErrorCode = 1234;

public QuicConnectionTests(ITestOutputHelper output) : base(output) { }

[Fact]
public async Task TestConnect()
{
Expand Down Expand Up @@ -285,8 +287,14 @@ await RunClientServer(
}
}

public sealed class QuicConnectionTests_MockProvider : QuicConnectionTests<MockProviderFactory> { }
public sealed class QuicConnectionTests_MockProvider : QuicConnectionTests<MockProviderFactory>
{
public QuicConnectionTests_MockProvider(ITestOutputHelper output) : base(output) { }
}

[ConditionalClass(typeof(QuicTestBase<MsQuicProviderFactory>), nameof(QuicTestBase<MsQuicProviderFactory>.IsSupported))]
public sealed class QuicConnectionTests_MsQuicProvider : QuicConnectionTests<MsQuicProviderFactory> { }
public sealed class QuicConnectionTests_MsQuicProvider : QuicConnectionTests<MsQuicProviderFactory>
{
public QuicConnectionTests_MsQuicProvider(ITestOutputHelper output) : base(output) { }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@

using System.Threading.Tasks;
using Xunit;
using Xunit.Abstractions;

namespace System.Net.Quic.Tests
{
public abstract class QuicListenerTests<T> : QuicTestBase<T>
where T : IQuicImplProviderFactory, new()
{
public QuicListenerTests(ITestOutputHelper output) : base(output) { }

[Fact]
public async Task Listener_Backlog_Success()
{
Expand All @@ -25,8 +28,14 @@ await Task.Run(async () =>
}
}

public sealed class QuicListenerTests_MockProvider : QuicListenerTests<MockProviderFactory> { }
public sealed class QuicListenerTests_MockProvider : QuicListenerTests<MockProviderFactory>
{
public QuicListenerTests_MockProvider(ITestOutputHelper output) : base(output) { }
}

[ConditionalClass(typeof(QuicTestBase<MsQuicProviderFactory>), nameof(QuicTestBase<MsQuicProviderFactory>.IsSupported))]
public sealed class QuicListenerTests_MsQuicProvider : QuicListenerTests<MsQuicProviderFactory> { }
public sealed class QuicListenerTests_MsQuicProvider : QuicListenerTests<MsQuicProviderFactory>
{
public QuicListenerTests_MsQuicProvider(ITestOutputHelper output) : base(output) { }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
using System.Collections.Generic;
using System.IO;
using System.IO.Tests;
using System.Net.Sockets;
using System.Net.Quic.Implementations;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
using Xunit.Abstractions;

namespace System.Net.Quic.Tests
{
Expand All @@ -23,11 +26,17 @@ public sealed class MsQuicQuicStreamConformanceTests : QuicStreamConformanceTest
protected override QuicImplementationProvider Provider => QuicImplementationProviders.MsQuic;
protected override bool UsableAfterCanceledReads => false;
protected override bool BlocksOnZeroByteReads => true;

public MsQuicQuicStreamConformanceTests(ITestOutputHelper output)
{
_output = output;
}
}

public abstract class QuicStreamConformanceTests : ConnectedStreamConformanceTests
{
public X509Certificate2 ServerCertificate = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate();
public ITestOutputHelper _output;

public bool RemoteCertificateValidationCallback(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors)
{
Expand Down Expand Up @@ -75,21 +84,31 @@ await WhenAllOrAnyFailed(
}),
Task.Run(async () =>
{
connection2 = new QuicConnection(
provider,
listener.ListenEndPoint,
GetSslClientAuthenticationOptions());
await connection2.ConnectAsync();
stream2 = connection2.OpenBidirectionalStream();
// OpenBidirectionalStream only allocates ID. We will force stream opening
// by Writing there and receiving data on the other side.
await stream2.WriteAsync(buffer);
try
{
connection2 = new QuicConnection(
provider,
listener.ListenEndPoint,
GetSslClientAuthenticationOptions());
await connection2.ConnectAsync();
stream2 = connection2.OpenBidirectionalStream();
// OpenBidirectionalStream only allocates ID. We will force stream opening
// by Writing there and receiving data on the other side.
await stream2.WriteAsync(buffer);
}
catch (Exception ex)
{
_output?.WriteLine($"Failed to {ex.Message}");
throw;
}
}));

// No need to keep the listener once we have connected connection and streams
listener.Dispose();

var result = new StreamPairWithOtherDisposables(stream1, stream2);
result.Disposables.Add(connection1);
result.Disposables.Add(connection2);
result.Disposables.Add(listener);

return result;
}
Expand Down
Loading

0 comments on commit 66d23a7

Please sign in to comment.