Skip to content

Commit

Permalink
fix: time out and cancel propagation in managed client
Browse files Browse the repository at this point in the history
  • Loading branch information
marve committed May 3, 2024
1 parent 4f9e54b commit be89178
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 19 deletions.
52 changes: 34 additions & 18 deletions Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,13 @@ static TimeSpan GetRemainingTime(DateTime endTime)
return remainingTime < TimeSpan.Zero ? TimeSpan.Zero : remainingTime;
}

private CancellationTokenSource NewTimeoutToken(CancellationToken linkedToken)
{
var newTimeoutToken = CancellationTokenSource.CreateLinkedTokenSource(linkedToken);
newTimeoutToken.CancelAfter(Options.ClientOptions.Timeout);
return newTimeoutToken;
}

async Task HandleSubscriptionExceptionAsync(Exception exception, List<MqttTopicFilter> addedSubscriptions, List<string> removedSubscriptions)
{
_logger.Warning(exception, "Synchronizing subscriptions failed.");
Expand Down Expand Up @@ -411,7 +418,7 @@ async Task MaintainConnectionAsync(CancellationToken cancellationToken)
{
if (_isCleanDisconnect)
{
using (var disconnectTimeout = new CancellationTokenSource(Options.ClientOptions.Timeout))
using (var disconnectTimeout = NewTimeoutToken(CancellationToken.None))
{
await InternalClient.DisconnectAsync(new MqttClientDisconnectOptions(), disconnectTimeout.Token).ConfigureAwait(false);
}
Expand Down Expand Up @@ -461,7 +468,7 @@ async Task PublishQueuedMessagesAsync(CancellationToken cancellationToken)

cancellationToken.ThrowIfCancellationRequested();

await TryPublishQueuedMessageAsync(message).ConfigureAwait(false);
await TryPublishQueuedMessageAsync(message, cancellationToken).ConfigureAwait(false);
}
}
catch (OperationCanceledException)
Expand All @@ -477,7 +484,7 @@ async Task PublishQueuedMessagesAsync(CancellationToken cancellationToken)
}
}

async Task PublishReconnectSubscriptionsAsync()
async Task PublishReconnectSubscriptionsAsync(CancellationToken cancellationToken)
{
_logger.Info("Publishing subscriptions at reconnect");

Expand All @@ -496,13 +503,13 @@ async Task PublishReconnectSubscriptionsAsync()

if (topicFilters.Count == Options.MaxTopicFiltersInSubscribeUnsubscribePackets)
{
subscribeUnsubscribeResult = await SendSubscribeUnsubscribe(topicFilters, null).ConfigureAwait(false);
subscribeUnsubscribeResult = await SendSubscribeUnsubscribe(topicFilters, null, cancellationToken).ConfigureAwait(false);
topicFilters.Clear();
await HandleSubscriptionsResultAsync(subscribeUnsubscribeResult).ConfigureAwait(false);
}
}

subscribeUnsubscribeResult = await SendSubscribeUnsubscribe(topicFilters, null).ConfigureAwait(false);
subscribeUnsubscribeResult = await SendSubscribeUnsubscribe(topicFilters, null, cancellationToken).ConfigureAwait(false);
await HandleSubscriptionsResultAsync(subscribeUnsubscribeResult).ConfigureAwait(false);
}
}
Expand Down Expand Up @@ -555,13 +562,13 @@ async Task PublishSubscriptionsAsync(TimeSpan timeout, CancellationToken cancell

if (addedTopicFilters.Count == Options.MaxTopicFiltersInSubscribeUnsubscribePackets)
{
subscribeUnsubscribeResult = await SendSubscribeUnsubscribe(addedTopicFilters, null).ConfigureAwait(false);
subscribeUnsubscribeResult = await SendSubscribeUnsubscribe(addedTopicFilters, null, cancellationToken).ConfigureAwait(false);
addedTopicFilters.Clear();
await HandleSubscriptionsResultAsync(subscribeUnsubscribeResult).ConfigureAwait(false);
}
}

subscribeUnsubscribeResult = await SendSubscribeUnsubscribe(addedTopicFilters, null).ConfigureAwait(false);
subscribeUnsubscribeResult = await SendSubscribeUnsubscribe(addedTopicFilters, null, cancellationToken).ConfigureAwait(false);
await HandleSubscriptionsResultAsync(subscribeUnsubscribeResult).ConfigureAwait(false);

var removedTopicFilters = new List<string>();
Expand All @@ -571,13 +578,13 @@ async Task PublishSubscriptionsAsync(TimeSpan timeout, CancellationToken cancell

if (removedTopicFilters.Count == Options.MaxTopicFiltersInSubscribeUnsubscribePackets)
{
subscribeUnsubscribeResult = await SendSubscribeUnsubscribe(null, removedTopicFilters).ConfigureAwait(false);
subscribeUnsubscribeResult = await SendSubscribeUnsubscribe(null, removedTopicFilters, cancellationToken).ConfigureAwait(false);
removedTopicFilters.Clear();
await HandleSubscriptionsResultAsync(subscribeUnsubscribeResult).ConfigureAwait(false);
}
}

subscribeUnsubscribeResult = await SendSubscribeUnsubscribe(null, removedTopicFilters).ConfigureAwait(false);
subscribeUnsubscribeResult = await SendSubscribeUnsubscribe(null, removedTopicFilters, cancellationToken).ConfigureAwait(false);
await HandleSubscriptionsResultAsync(subscribeUnsubscribeResult).ConfigureAwait(false);
}
}
Expand All @@ -592,7 +599,7 @@ async Task<ReconnectionResult> ReconnectIfRequiredAsync(CancellationToken cancel
MqttClientConnectResult connectResult = null;
try
{
using (var connectTimeout = new CancellationTokenSource(Options.ClientOptions.Timeout))
using (var connectTimeout = NewTimeoutToken(cancellationToken))
{
connectResult = await InternalClient.ConnectAsync(Options.ClientOptions, connectTimeout.Token).ConfigureAwait(false);
}
Expand All @@ -611,7 +618,7 @@ async Task<ReconnectionResult> ReconnectIfRequiredAsync(CancellationToken cancel
}
}

async Task<SendSubscribeUnsubscribeResult> SendSubscribeUnsubscribe(List<MqttTopicFilter> addedSubscriptions, List<string> removedSubscriptions)
async Task<SendSubscribeUnsubscribeResult> SendSubscribeUnsubscribe(List<MqttTopicFilter> addedSubscriptions, List<string> removedSubscriptions, CancellationToken cancellationToken)
{
var subscribeResults = new List<MqttClientSubscribeResult>();
var unsubscribeResults = new List<MqttClientUnsubscribeResult>();
Expand All @@ -626,8 +633,11 @@ async Task<SendSubscribeUnsubscribeResult> SendSubscribeUnsubscribe(List<MqttTop
unsubscribeOptionsBuilder.WithTopicFilter(removedSubscription);
}

var unsubscribeResult = await InternalClient.UnsubscribeAsync(unsubscribeOptionsBuilder.Build()).ConfigureAwait(false);
unsubscribeResults.Add(unsubscribeResult);
using (var unsubscribeTimeout = NewTimeoutToken(cancellationToken))
{
var unsubscribeResult = await InternalClient.UnsubscribeAsync(unsubscribeOptionsBuilder.Build(), unsubscribeTimeout.Token).ConfigureAwait(false);
unsubscribeResults.Add(unsubscribeResult);
}

//clear because these worked, maybe the subscribe below will fail, only report those
removedSubscriptions.Clear();
Expand All @@ -642,8 +652,11 @@ async Task<SendSubscribeUnsubscribeResult> SendSubscribeUnsubscribe(List<MqttTop
subscribeOptionsBuilder.WithTopicFilter(addedSubscription);
}

var subscribeResult = await InternalClient.SubscribeAsync(subscribeOptionsBuilder.Build()).ConfigureAwait(false);
subscribeResults.Add(subscribeResult);
using (var subscribeTimeout = NewTimeoutToken(cancellationToken))
{
var subscribeResult = await InternalClient.SubscribeAsync(subscribeOptionsBuilder.Build(), subscribeTimeout.Token).ConfigureAwait(false);
subscribeResults.Add(subscribeResult);
}
}
}
catch (Exception exception)
Expand Down Expand Up @@ -705,7 +718,7 @@ async Task TryMaintainConnectionAsync(CancellationToken cancellationToken)
}
else if (connectionState == ReconnectionResult.Reconnected)
{
await PublishReconnectSubscriptionsAsync().ConfigureAwait(false);
await PublishReconnectSubscriptionsAsync(cancellationToken).ConfigureAwait(false);
StartPublishing();
}
else if (connectionState == ReconnectionResult.Recovered)
Expand Down Expand Up @@ -735,7 +748,7 @@ async Task TryMaintainConnectionAsync(CancellationToken cancellationToken)
}
}

async Task TryPublishQueuedMessageAsync(ManagedMqttApplicationMessage message)
async Task TryPublishQueuedMessageAsync(ManagedMqttApplicationMessage message, CancellationToken cancellationToken)
{
Exception transmitException = null;
bool acceptPublish = true;
Expand All @@ -750,7 +763,10 @@ async Task TryPublishQueuedMessageAsync(ManagedMqttApplicationMessage message)

if (acceptPublish)
{
await InternalClient.PublishAsync(message.ApplicationMessage).ConfigureAwait(false);
using (var publishTimeout = NewTimeoutToken(cancellationToken))
{
await InternalClient.PublishAsync(message.ApplicationMessage, publishTimeout.Token).ConfigureAwait(false);
}
}

using (await _messageQueueLock.EnterAsync().ConfigureAwait(false)) //lock to avoid conflict with this.PublishAsync
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Client;
using MQTTnet.Diagnostics;
using MQTTnet.Extensions.ManagedClient;
using MQTTnet.Formatter;
using MQTTnet.Internal;
using MQTTnet.Packets;
using MQTTnet.Protocol;
using MQTTnet.Server;
using MQTTnet.Tests.Mockups;
Expand Down Expand Up @@ -574,16 +576,152 @@ public async Task Subscriptions_Subscribe_Only_New_Subscriptions()
}
}

[TestMethod]
public async Task Subscribe_Does_Not_Hang_On_Server_Stop()
{
var timeout = TimeSpan.FromSeconds(2);
var testTimeout = timeout * 2;
const string topic = "test_topic_2";
using (var testEnvironment = CreateTestEnvironment())
using (var managedClient = await CreateManagedClientAsync(testEnvironment, timeout: timeout))
{
testEnvironment.IgnoreClientLogErrors = true;
bool reject = true;
var receivedOnServer = new SemaphoreSlim(0, 1);
var failedOnClient = new SemaphoreSlim(0, 1);
testEnvironment.Server.InterceptingInboundPacketAsync += e =>
{
if (e.Packet is MqttSubscribePacket)
{
if (reject)
{
e.ProcessPacket = false;
}
receivedOnServer.Release();
}
return Task.CompletedTask;
};
managedClient.SynchronizingSubscriptionsFailedAsync += e =>
{
failedOnClient.Release();
return Task.CompletedTask;
};

await managedClient.SubscribeAsync(topic);
Assert.IsTrue(await receivedOnServer.WaitAsync(testTimeout));
Assert.IsTrue(await failedOnClient.WaitAsync(testTimeout));

reject = false;
await managedClient.SubscribeAsync(topic);
Assert.IsTrue(await receivedOnServer.WaitAsync(testTimeout));
}
}

[TestMethod]
public async Task Unsubscribe_Does_Not_Hang_On_Server_Stop()
{
var timeout = TimeSpan.FromSeconds(2);
var testTimeout = timeout * 2;
const string topic = "test_topic_2";
using (var testEnvironment = CreateTestEnvironment())
using (var managedClient = await CreateManagedClientAsync(testEnvironment, timeout: timeout))
{
testEnvironment.IgnoreClientLogErrors = true;
bool reject = true;
var receivedOnServer = new SemaphoreSlim(0, 1);
var failedOnClient = new SemaphoreSlim(0, 1);
testEnvironment.Server.InterceptingInboundPacketAsync += e =>
{
if (e.Packet is MqttUnsubscribePacket)
{
if (reject)
{
e.ProcessPacket = false;
}
receivedOnServer.Release();
}
else if (e.Packet is MqttSubscribePacket)
{
receivedOnServer.Release();
}
return Task.CompletedTask;
};
managedClient.SynchronizingSubscriptionsFailedAsync += e =>
{
failedOnClient.Release();
return Task.CompletedTask;
};

await managedClient.SubscribeAsync(topic);
Assert.IsTrue(await receivedOnServer.WaitAsync(testTimeout));

await managedClient.UnsubscribeAsync(topic);
Assert.IsTrue(await receivedOnServer.WaitAsync(testTimeout));
Assert.IsTrue(await failedOnClient.WaitAsync(testTimeout));

reject = false;
await managedClient.UnsubscribeAsync(topic);
Assert.IsTrue(await receivedOnServer.WaitAsync(testTimeout));
}
}

[TestMethod]
public async Task Publish_Does_Not_Hang_On_Server_Error()
{
var timeout = TimeSpan.FromSeconds(2);
var testTimeout = timeout * 2;

const string topic = "test_topic_42";

using (var testEnvironment = CreateTestEnvironment())
using (var managedClient = await CreateManagedClientAsync(testEnvironment, timeout: timeout))
{
testEnvironment.IgnoreClientLogErrors = true;
bool reject = true;
var receivedOnServer = new TaskCompletionSource();
managedClient.ApplicationMessageProcessedAsync += e => Task.FromResult(reject &= e.Exception is null);
testEnvironment.Server.InterceptingInboundPacketAsync += e =>
{
if (e.Packet is MqttPublishPacket)
{
if (reject)
{
e.ProcessPacket = false;
}
else
{
receivedOnServer.TrySetResult();
}
}
return Task.CompletedTask;
};


await managedClient.EnqueueAsync(new MqttApplicationMessage { Topic = topic, Payload = new byte[] { 1 }, Retain = true, QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce });

var timeoutTask = Task.Delay(testTimeout);

var firstDone = await Task.WhenAny(receivedOnServer.Task, timeoutTask);
Assert.AreEqual(receivedOnServer.Task, firstDone, "Client is hung on publish!");
}
}

async Task<MQTTnet.Extensions.ManagedClient.ManagedMqttClient> CreateManagedClientAsync(
TestEnvironment testEnvironment,
IMqttClient underlyingClient = null,
TimeSpan? connectionCheckInterval = null,
string host = "localhost")
string host = "localhost",
TimeSpan? timeout = null)
{
await testEnvironment.StartServer();

var clientOptions = new MqttClientOptionsBuilder().WithTcpServer(host, testEnvironment.ServerPort);

if (timeout != null)
{
clientOptions.WithTimeout(timeout.Value);
}

var managedOptions = new ManagedMqttClientOptionsBuilder().WithClientOptions(clientOptions).Build();

// Use a short connection check interval so that subscription operations are performed quickly
Expand Down

0 comments on commit be89178

Please sign in to comment.