From de7fc4c99b04a4d4b4d7cf29bae03eeebea87201 Mon Sep 17 00:00:00 2001 From: MichaC Date: Sat, 12 Jun 2021 23:16:06 +0200 Subject: [PATCH] Fixed some query cancellation implementation details (#124) Fixed one bug where the wrong cancellation token was passed to the async query impl... Removed the custom cancellation callback to dispose UdpClient for example, can now use the cancellation token + a callback registration. Includes other fix from #121 --- samples/MiniDig/RandomCommand.cs | 2 +- src/DnsClient/DnsMessageHandler.cs | 1 - src/DnsClient/DnsTcpMessageHandler.cs | 10 ++++----- src/DnsClient/DnsUdpMessageHandler.cs | 20 +++++++++-------- src/DnsClient/LookupClient.cs | 22 +++++++++---------- src/DnsClient/TaskExtensions.cs | 10 ++++----- .../DnsClientBenchmarks.DatagramReader.cs | 1 - ...ClientBenchmarks.RequestResponseParsing.cs | 1 - test/DnsClient.Tests/DnsClient.Tests.csproj | 1 + .../DnsClient.Tests/DnsResponseParsingTest.cs | 1 - test/DnsClient.Tests/LookupClientRetryTest.cs | 2 +- .../LookupConfigurationTest.cs | 1 - 12 files changed, 33 insertions(+), 39 deletions(-) diff --git a/samples/MiniDig/RandomCommand.cs b/samples/MiniDig/RandomCommand.cs index a6f658f8..f6d38445 100644 --- a/samples/MiniDig/RandomCommand.cs +++ b/samples/MiniDig/RandomCommand.cs @@ -73,7 +73,7 @@ protected override async Task Execute() _runSync = SyncArg.HasValue(); _settings = GetLookupSettings(); - _settings.EnableAuditTrail = true; + _settings.EnableAuditTrail = false; _settings.ThrowDnsErrors = false; _settings.ContinueOnDnsError = false; _lookup = GetDnsLookup(_settings); diff --git a/src/DnsClient/DnsMessageHandler.cs b/src/DnsClient/DnsMessageHandler.cs index 0359e810..7813f98b 100644 --- a/src/DnsClient/DnsMessageHandler.cs +++ b/src/DnsClient/DnsMessageHandler.cs @@ -23,7 +23,6 @@ internal abstract class DnsMessageHandler public abstract Task QueryAsync( IPEndPoint endpoint, DnsRequestMessage request, - Action cancelationCallback, CancellationToken cancellationToken); // Transient errors will be retried on the same NameServer before the resolver moves on diff --git a/src/DnsClient/DnsTcpMessageHandler.cs b/src/DnsClient/DnsTcpMessageHandler.cs index 7e8725a3..0dd733ea 100644 --- a/src/DnsClient/DnsTcpMessageHandler.cs +++ b/src/DnsClient/DnsTcpMessageHandler.cs @@ -23,20 +23,18 @@ public override DnsResponseMessage Query(IPEndPoint endpoint, DnsRequestMessage { using (var cts = new CancellationTokenSource(timeout)) { - Action onCancel = () => { }; - return QueryAsync(endpoint, request, (s) => onCancel = s, cts.Token) - .WithCancellation(onCancel, cts.Token) + return QueryAsync(endpoint, request, cts.Token) + .WithCancellation(cts.Token) .ConfigureAwait(false).GetAwaiter().GetResult(); } } - return QueryAsync(endpoint, request, (s) => { }, CancellationToken.None).ConfigureAwait(false).GetAwaiter().GetResult(); + return QueryAsync(endpoint, request, CancellationToken.None).ConfigureAwait(false).GetAwaiter().GetResult(); } public override async Task QueryAsync( IPEndPoint server, DnsRequestMessage request, - Action cancelationCallback, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -48,7 +46,7 @@ public override async Task QueryAsync( _pools.TryAdd(server, new ClientPool(true, server)); } - cancelationCallback(() => + using var cancelCallback = cancellationToken.Register(() => { if (entry == null) { diff --git a/src/DnsClient/DnsUdpMessageHandler.cs b/src/DnsClient/DnsUdpMessageHandler.cs index 0fd27024..f1c92541 100644 --- a/src/DnsClient/DnsUdpMessageHandler.cs +++ b/src/DnsClient/DnsUdpMessageHandler.cs @@ -82,26 +82,24 @@ public override DnsResponseMessage Query( public override async Task QueryAsync( IPEndPoint endpoint, DnsRequestMessage request, - Action cancelationCallback, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); UdpClient udpClient = GetNextUdpClient(endpoint.AddressFamily); - bool mustDispose = false; - try + using var callback = cancellationToken.Register(() => { - // setup timeout cancellation, dispose socket (the only way to actually cancel the request in async... - cancelationCallback(() => - { #if !NET45 - udpClient.Dispose(); + udpClient.Dispose(); #else - udpClient.Close(); + udpClient.Close(); #endif - }); + }); + bool mustDispose = false; + try + { using (var writer = new DnsDatagramWriter()) { GetRequestData(request, writer); @@ -128,6 +126,10 @@ public override async Task QueryAsync( return response; } } + catch (SocketException se) when (se.SocketErrorCode == SocketError.OperationAborted) + { + throw new TimeoutException(); + } catch (ObjectDisposedException) { // we disposed it in case of a timeout request, lets indicate it actually timed out... diff --git a/src/DnsClient/LookupClient.cs b/src/DnsClient/LookupClient.cs index c2496db4..47191f1f 100644 --- a/src/DnsClient/LookupClient.cs +++ b/src/DnsClient/LookupClient.cs @@ -1097,15 +1097,6 @@ private async Task ResolveQueryAsync( audit?.StartTimer(); DnsResponseMessage response; - Action onCancel = () => { }; - Task resultTask = handler.QueryAsync( - serverInfo.IPEndPoint, - request, - (cancel) => - { - onCancel = cancel; - }, - cancellationToken); if (settings.Timeout != System.Threading.Timeout.InfiniteTimeSpan || (cancellationToken != CancellationToken.None && cancellationToken.CanBeCanceled)) @@ -1119,12 +1110,21 @@ private async Task ResolveQueryAsync( using (cts) using (linkedCts) { - response = await resultTask.WithCancellation(onCancel, (linkedCts ?? cts).Token).ConfigureAwait(false); + response = await handler.QueryAsync( + serverInfo.IPEndPoint, + request, + (linkedCts ?? cts).Token) + .WithCancellation((linkedCts ?? cts).Token) + .ConfigureAwait(false); } } else { - response = await resultTask.ConfigureAwait(false); + response = await handler.QueryAsync( + serverInfo.IPEndPoint, + request, + cancellationToken) + .ConfigureAwait(false); } lastQueryResponse = ProcessResponseMessage( diff --git a/src/DnsClient/TaskExtensions.cs b/src/DnsClient/TaskExtensions.cs index 12fac6c5..fb986b5b 100644 --- a/src/DnsClient/TaskExtensions.cs +++ b/src/DnsClient/TaskExtensions.cs @@ -2,7 +2,7 @@ { internal static class TaskExtensions { - public static async Task WithCancellation(this Task task, Action onCancel, CancellationToken cancellationToken) + public static async Task WithCancellation(this Task task, CancellationToken cancellationToken) { var tcs = new TaskCompletionSource(); @@ -10,11 +10,9 @@ public static async Task WithCancellation(this Task task, Action onCanc { if (task != await Task.WhenAny(task, tcs.Task).ConfigureAwait(false)) { - try - { - onCancel(); - } - catch { } + // observe the exception to avoid "System.AggregateException: A Task's exception(s) were + // not observed either by Waiting on the Task or accessing its Exception property." + _ = task.ContinueWith(t => t.Exception); throw new OperationCanceledException(cancellationToken); } } diff --git a/test-other/Benchmarks/DnsClientBenchmarks.DatagramReader.cs b/test-other/Benchmarks/DnsClientBenchmarks.DatagramReader.cs index b0c79b4e..08e6708b 100644 --- a/test-other/Benchmarks/DnsClientBenchmarks.DatagramReader.cs +++ b/test-other/Benchmarks/DnsClientBenchmarks.DatagramReader.cs @@ -138,7 +138,6 @@ public override DnsResponseMessage Query( public override Task QueryAsync( IPEndPoint server, DnsRequestMessage request, - Action cancelationCallback, CancellationToken cancellationToken) { // no need to run async here as we don't do any IO diff --git a/test-other/Benchmarks/DnsClientBenchmarks.RequestResponseParsing.cs b/test-other/Benchmarks/DnsClientBenchmarks.RequestResponseParsing.cs index fb1764c8..d931baf7 100644 --- a/test-other/Benchmarks/DnsClientBenchmarks.RequestResponseParsing.cs +++ b/test-other/Benchmarks/DnsClientBenchmarks.RequestResponseParsing.cs @@ -166,7 +166,6 @@ public override DnsResponseMessage Query( public override Task QueryAsync( IPEndPoint server, DnsRequestMessage request, - Action cancelationCallback, CancellationToken cancellationToken) { return Task.FromResult(Query(server, request, Timeout.InfiniteTimeSpan)); diff --git a/test/DnsClient.Tests/DnsClient.Tests.csproj b/test/DnsClient.Tests/DnsClient.Tests.csproj index 2303d9ef..8ae2ee17 100644 --- a/test/DnsClient.Tests/DnsClient.Tests.csproj +++ b/test/DnsClient.Tests/DnsClient.Tests.csproj @@ -9,6 +9,7 @@ true DnsClient.Tests false + latest diff --git a/test/DnsClient.Tests/DnsResponseParsingTest.cs b/test/DnsClient.Tests/DnsResponseParsingTest.cs index a878f8ad..1969e887 100644 --- a/test/DnsClient.Tests/DnsResponseParsingTest.cs +++ b/test/DnsClient.Tests/DnsResponseParsingTest.cs @@ -359,7 +359,6 @@ public override DnsResponseMessage Query( public override Task QueryAsync( IPEndPoint server, DnsRequestMessage request, - Action cancelationCallback, CancellationToken cancellationToken) { // no need to run async here as we don't do any IO diff --git a/test/DnsClient.Tests/LookupClientRetryTest.cs b/test/DnsClient.Tests/LookupClientRetryTest.cs index e604f451..93805ad6 100644 --- a/test/DnsClient.Tests/LookupClientRetryTest.cs +++ b/test/DnsClient.Tests/LookupClientRetryTest.cs @@ -1664,7 +1664,7 @@ public override DnsResponseMessage Query(IPEndPoint endpoint, DnsRequestMessage return _onQuery(endpoint, request); } - public override Task QueryAsync(IPEndPoint endpoint, DnsRequestMessage request, Action cancelationCallback, CancellationToken cancellationToken) + public override Task QueryAsync(IPEndPoint endpoint, DnsRequestMessage request, CancellationToken cancellationToken) { return Task.FromResult(_onQuery(endpoint, request)); } diff --git a/test/DnsClient.Tests/LookupConfigurationTest.cs b/test/DnsClient.Tests/LookupConfigurationTest.cs index 84609693..3684204a 100644 --- a/test/DnsClient.Tests/LookupConfigurationTest.cs +++ b/test/DnsClient.Tests/LookupConfigurationTest.cs @@ -1039,7 +1039,6 @@ public override DnsResponseMessage Query( public override Task QueryAsync( IPEndPoint server, DnsRequestMessage request, - Action cancelationCallback, CancellationToken cancellationToken) { LastServer = server;