Skip to content

Commit

Permalink
Fixed some query cancellation implementation details (#124)
Browse files Browse the repository at this point in the history
Fixed one bug where the wrong cancellation token was passed to the async query impl...
Removed the custom cancellation callback to dispose UdpClient for example, can now use the cancellation token + a callback registration.
Includes other fix from #121
  • Loading branch information
MichaCo authored Jun 12, 2021
1 parent 314f0a7 commit de7fc4c
Show file tree
Hide file tree
Showing 12 changed files with 33 additions and 39 deletions.
2 changes: 1 addition & 1 deletion samples/MiniDig/RandomCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ protected override async Task<int> Execute()
_runSync = SyncArg.HasValue();

_settings = GetLookupSettings();
_settings.EnableAuditTrail = true;
_settings.EnableAuditTrail = false;
_settings.ThrowDnsErrors = false;
_settings.ContinueOnDnsError = false;
_lookup = GetDnsLookup(_settings);
Expand Down
1 change: 0 additions & 1 deletion src/DnsClient/DnsMessageHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ internal abstract class DnsMessageHandler
public abstract Task<DnsResponseMessage> QueryAsync(
IPEndPoint endpoint,
DnsRequestMessage request,
Action<Action> cancelationCallback,
CancellationToken cancellationToken);

// Transient errors will be retried on the same NameServer before the resolver moves on
Expand Down
10 changes: 4 additions & 6 deletions src/DnsClient/DnsTcpMessageHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,18 @@ public override DnsResponseMessage Query(IPEndPoint endpoint, DnsRequestMessage
{
using (var cts = new CancellationTokenSource(timeout))
{
Action onCancel = () => { };
return QueryAsync(endpoint, request, (s) => onCancel = s, cts.Token)
.WithCancellation(onCancel, cts.Token)
return QueryAsync(endpoint, request, cts.Token)
.WithCancellation(cts.Token)
.ConfigureAwait(false).GetAwaiter().GetResult();
}
}

return QueryAsync(endpoint, request, (s) => { }, CancellationToken.None).ConfigureAwait(false).GetAwaiter().GetResult();
return QueryAsync(endpoint, request, CancellationToken.None).ConfigureAwait(false).GetAwaiter().GetResult();
}

public override async Task<DnsResponseMessage> QueryAsync(
IPEndPoint server,
DnsRequestMessage request,
Action<Action> cancelationCallback,
CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
Expand All @@ -48,7 +46,7 @@ public override async Task<DnsResponseMessage> QueryAsync(
_pools.TryAdd(server, new ClientPool(true, server));
}

cancelationCallback(() =>
using var cancelCallback = cancellationToken.Register(() =>
{
if (entry == null)
{
Expand Down
20 changes: 11 additions & 9 deletions src/DnsClient/DnsUdpMessageHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,26 +82,24 @@ public override DnsResponseMessage Query(
public override async Task<DnsResponseMessage> QueryAsync(
IPEndPoint endpoint,
DnsRequestMessage request,
Action<Action> cancelationCallback,
CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

UdpClient udpClient = GetNextUdpClient(endpoint.AddressFamily);

bool mustDispose = false;
try
using var callback = cancellationToken.Register(() =>
{
// setup timeout cancellation, dispose socket (the only way to actually cancel the request in async...
cancelationCallback(() =>
{
#if !NET45
udpClient.Dispose();
udpClient.Dispose();
#else
udpClient.Close();
udpClient.Close();
#endif
});
});

bool mustDispose = false;
try
{
using (var writer = new DnsDatagramWriter())
{
GetRequestData(request, writer);
Expand All @@ -128,6 +126,10 @@ public override async Task<DnsResponseMessage> QueryAsync(
return response;
}
}
catch (SocketException se) when (se.SocketErrorCode == SocketError.OperationAborted)
{
throw new TimeoutException();
}
catch (ObjectDisposedException)
{
// we disposed it in case of a timeout request, lets indicate it actually timed out...
Expand Down
22 changes: 11 additions & 11 deletions src/DnsClient/LookupClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1097,15 +1097,6 @@ private async Task<IDnsQueryResponse> ResolveQueryAsync(
audit?.StartTimer();

DnsResponseMessage response;
Action onCancel = () => { };
Task<DnsResponseMessage> resultTask = handler.QueryAsync(
serverInfo.IPEndPoint,
request,
(cancel) =>
{
onCancel = cancel;
},
cancellationToken);

if (settings.Timeout != System.Threading.Timeout.InfiniteTimeSpan
|| (cancellationToken != CancellationToken.None && cancellationToken.CanBeCanceled))
Expand All @@ -1119,12 +1110,21 @@ private async Task<IDnsQueryResponse> ResolveQueryAsync(
using (cts)
using (linkedCts)
{
response = await resultTask.WithCancellation(onCancel, (linkedCts ?? cts).Token).ConfigureAwait(false);
response = await handler.QueryAsync(
serverInfo.IPEndPoint,
request,
(linkedCts ?? cts).Token)
.WithCancellation((linkedCts ?? cts).Token)
.ConfigureAwait(false);
}
}
else
{
response = await resultTask.ConfigureAwait(false);
response = await handler.QueryAsync(
serverInfo.IPEndPoint,
request,
cancellationToken)
.ConfigureAwait(false);
}

lastQueryResponse = ProcessResponseMessage(
Expand Down
10 changes: 4 additions & 6 deletions src/DnsClient/TaskExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,17 @@
{
internal static class TaskExtensions
{
public static async Task<T> WithCancellation<T>(this Task<T> task, Action onCancel, CancellationToken cancellationToken)
public static async Task<T> WithCancellation<T>(this Task<T> task, CancellationToken cancellationToken)
{
var tcs = new TaskCompletionSource<bool>();

using (cancellationToken.Register(s => ((TaskCompletionSource<bool>)s).TrySetResult(true), tcs))
{
if (task != await Task.WhenAny(task, tcs.Task).ConfigureAwait(false))
{
try
{
onCancel();
}
catch { }
// observe the exception to avoid "System.AggregateException: A Task's exception(s) were
// not observed either by Waiting on the Task or accessing its Exception property."
_ = task.ContinueWith(t => t.Exception);
throw new OperationCanceledException(cancellationToken);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ public override DnsResponseMessage Query(
public override Task<DnsResponseMessage> QueryAsync(
IPEndPoint server,
DnsRequestMessage request,
Action<Action> cancelationCallback,
CancellationToken cancellationToken)
{
// no need to run async here as we don't do any IO
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ public override DnsResponseMessage Query(
public override Task<DnsResponseMessage> QueryAsync(
IPEndPoint server,
DnsRequestMessage request,
Action<Action> cancelationCallback,
CancellationToken cancellationToken)
{
return Task.FromResult(Query(server, request, Timeout.InfiniteTimeSpan));
Expand Down
1 change: 1 addition & 0 deletions test/DnsClient.Tests/DnsClient.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
<PublicSign Condition=" '$(OS)' != 'Windows_NT' ">true</PublicSign>
<PackageId>DnsClient.Tests</PackageId>
<IsPackable>false</IsPackable>
<LangVersion>latest</LangVersion>
</PropertyGroup>

<PropertyGroup>
Expand Down
1 change: 0 additions & 1 deletion test/DnsClient.Tests/DnsResponseParsingTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ public override DnsResponseMessage Query(
public override Task<DnsResponseMessage> QueryAsync(
IPEndPoint server,
DnsRequestMessage request,
Action<Action> cancelationCallback,
CancellationToken cancellationToken)
{
// no need to run async here as we don't do any IO
Expand Down
2 changes: 1 addition & 1 deletion test/DnsClient.Tests/LookupClientRetryTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1664,7 +1664,7 @@ public override DnsResponseMessage Query(IPEndPoint endpoint, DnsRequestMessage
return _onQuery(endpoint, request);
}

public override Task<DnsResponseMessage> QueryAsync(IPEndPoint endpoint, DnsRequestMessage request, Action<Action> cancelationCallback, CancellationToken cancellationToken)
public override Task<DnsResponseMessage> QueryAsync(IPEndPoint endpoint, DnsRequestMessage request, CancellationToken cancellationToken)
{
return Task.FromResult(_onQuery(endpoint, request));
}
Expand Down
1 change: 0 additions & 1 deletion test/DnsClient.Tests/LookupConfigurationTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,6 @@ public override DnsResponseMessage Query(
public override Task<DnsResponseMessage> QueryAsync(
IPEndPoint server,
DnsRequestMessage request,
Action<Action> cancelationCallback,
CancellationToken cancellationToken)
{
LastServer = server;
Expand Down

0 comments on commit de7fc4c

Please sign in to comment.