Skip to content

Commit

Permalink
fix(deviceClient): Fix issue with AMQP connection pool and TokenRefer…
Browse files Browse the repository at this point in the history
…esher disposal. (Azure#2260)

* Fix issue with AMQP connection pool

* Dispose the connection holder appropriately

* Add unit testing capability to connection pool
  • Loading branch information
azabbasi authored Jan 7, 2022
1 parent 4c8a608 commit 21e3324
Show file tree
Hide file tree
Showing 16 changed files with 199 additions and 87 deletions.
27 changes: 27 additions & 0 deletions iothub/device/src/IotHubConnectionString.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,33 @@ public IotHubConnectionString(IotHubConnectionStringBuilder builder)
}
}

// This constructor is only used for unit testing.
internal IotHubConnectionString(
string ioTHubName = null,
string deviceId = null,
string moduleId = null,
string hostName = null,
Uri httpsEndpoint = null,
Uri amqpEndpoint = null,
string audience = null,
string sharedAccessKeyName = null,
string sharedAccessKey = null,
string sharedAccessSignature = null,
bool isUsingGateway = false)
{
IotHubName = ioTHubName;
DeviceId = deviceId;
ModuleId = moduleId;
HostName = hostName;
HttpsEndpoint = httpsEndpoint;
AmqpEndpoint = amqpEndpoint;
Audience = audience;
SharedAccessKeyName = sharedAccessKeyName;
SharedAccessKey = sharedAccessKey;
SharedAccessSignature = sharedAccessSignature;
IsUsingGateway = isUsingGateway;
}

public string IotHubName { get; private set; }

public string DeviceId { get; private set; }
Expand Down
2 changes: 1 addition & 1 deletion iothub/device/src/IotHubConnectionStringBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public sealed class IotHubConnectionStringBuilder
/// <summary>
/// Initializes a new instance of the <see cref="IotHubConnectionStringBuilder"/> class.
/// </summary>
private IotHubConnectionStringBuilder()
internal IotHubConnectionStringBuilder()
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ internal class AmqpAuthenticationRefresher : IAmqpAuthenticationRefresher, IDisp
private Task _refreshLoop;
private bool _disposed;

internal AmqpAuthenticationRefresher(DeviceIdentity deviceIdentity, AmqpIotCbsLink amqpCbsLink)
internal AmqpAuthenticationRefresher(IDeviceIdentity deviceIdentity, AmqpIotCbsLink amqpCbsLink)
{
_amqpIotCbsLink = amqpCbsLink;
_connectionString = deviceIdentity.IotHubConnectionString;
Expand Down
20 changes: 13 additions & 7 deletions iothub/device/src/Transport/Amqp/AmqpConnectionHolder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
using Microsoft.Azure.Devices.Client.Extensions;
using Microsoft.Azure.Devices.Client.Transport.AmqpIot;
using System.Collections.Generic;
using System.Linq;

namespace Microsoft.Azure.Devices.Client.Transport.Amqp
{
internal class AmqpConnectionHolder : IAmqpConnectionHolder, IAmqpUnitManager
{
private readonly DeviceIdentity _deviceIdentity;
private readonly IDeviceIdentity _deviceIdentity;
private readonly AmqpIotConnector _amqpIotConnector;
private readonly SemaphoreSlim _lock = new SemaphoreSlim(1, 1);
private readonly HashSet<AmqpUnit> _amqpUnits = new HashSet<AmqpUnit>();
Expand All @@ -23,7 +24,7 @@ internal class AmqpConnectionHolder : IAmqpConnectionHolder, IAmqpUnitManager
private IAmqpAuthenticationRefresher _amqpAuthenticationRefresher;
private volatile bool _disposed;

public AmqpConnectionHolder(DeviceIdentity deviceIdentity)
public AmqpConnectionHolder(IDeviceIdentity deviceIdentity)
{
_deviceIdentity = deviceIdentity;
_amqpIotConnector = new AmqpIotConnector(deviceIdentity.AmqpTransportSettings, deviceIdentity.IotHubConnectionString.HostName);
Expand All @@ -34,7 +35,7 @@ public AmqpConnectionHolder(DeviceIdentity deviceIdentity)
}

public AmqpUnit CreateAmqpUnit(
DeviceIdentity deviceIdentity,
IDeviceIdentity deviceIdentity,
Func<MethodRequestInternal, Task> onMethodCallback,
Action<Twin, string, TwinCollection, IotHubException> twinMessageListener,
Func<string, Message, Task> onModuleMessageReceivedCallback,
Expand Down Expand Up @@ -140,7 +141,7 @@ private void Dispose(bool disposing)
_disposed = true;
}

public async Task<IAmqpAuthenticationRefresher> CreateRefresherAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken)
public async Task<IAmqpAuthenticationRefresher> CreateRefresherAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken)
{
if (Logging.IsEnabled)
{
Expand All @@ -159,7 +160,7 @@ public async Task<IAmqpAuthenticationRefresher> CreateRefresherAsync(DeviceIdent
return amqpAuthenticator;
}

public async Task<AmqpIotSession> OpenSessionAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken)
public async Task<AmqpIotSession> OpenSessionAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken)
{
if (Logging.IsEnabled)
{
Expand Down Expand Up @@ -274,9 +275,14 @@ public void RemoveAmqpUnit(AmqpUnit amqpUnit)
}
}

internal DeviceIdentity GetDeviceIdentityOfAuthenticationProvider()
internal IDeviceIdentity GetDeviceIdentityOfAuthenticationProvider()
{
return _deviceIdentity;
}

internal bool IsEmpty()
{
return !_amqpUnits.Any();
}
}
}
}
53 changes: 30 additions & 23 deletions iothub/device/src/Transport/Amqp/AmqpConnectionPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@ internal class AmqpConnectionPool : IAmqpUnitManager
private readonly IDictionary<string, AmqpConnectionHolder[]> _amqpSasGroupedPool = new Dictionary<string, AmqpConnectionHolder[]>();
private readonly object _lock = new object();

protected virtual IDictionary<string, AmqpConnectionHolder[]> GetAmqpSasGroupedPoolDictionary()
{
return _amqpSasGroupedPool;
}

public AmqpUnit CreateAmqpUnit(
DeviceIdentity deviceIdentity,
IDeviceIdentity deviceIdentity,
Func<MethodRequestInternal, Task> onMethodCallback,
Action<Twin, string, TwinCollection, IotHubException> twinMessageListener,
Func<string, Message, Task> onModuleMessageReceivedCallback,
Expand All @@ -36,21 +41,6 @@ public AmqpUnit CreateAmqpUnit(
{
AmqpConnectionHolder[] amqpConnectionHolders = ResolveConnectionGroup(deviceIdentity);
amqpConnectionHolder = ResolveConnectionByHashing(amqpConnectionHolders, deviceIdentity);

// For group sas token authenticated devices over a multiplexed connection, the TokenRefresher
// of the first client connecting will be used for generating the group sas tokens
// and will be associated with the connection itself.
// For this reason, if the device identity of the client is not the one associated with the
// connection, the associated TokenRefresher can be safely disposed.
// Note - This does not cause any identity related issues since the group sas tokens are generated
// against the hub host as the intended audience (without the "device Id").
if (deviceIdentity.AuthenticationModel == AuthenticationModel.SasGrouped
&& !ReferenceEquals(amqpConnectionHolder.GetDeviceIdentityOfAuthenticationProvider(), deviceIdentity)
&& deviceIdentity.IotHubConnectionString?.TokenRefresher != null
&& deviceIdentity.IotHubConnectionString.TokenRefresher.DisposalWithClient)
{
deviceIdentity.IotHubConnectionString.TokenRefresher.Dispose();
}
}

if (Logging.IsEnabled)
Expand Down Expand Up @@ -91,16 +81,25 @@ public void RemoveAmqpUnit(AmqpUnit amqpUnit)
Logging.Enter(this, amqpUnit, nameof(RemoveAmqpUnit));
}

DeviceIdentity deviceIdentity = amqpUnit.GetDeviceIdentity();
IDeviceIdentity deviceIdentity = amqpUnit.GetDeviceIdentity();
if (deviceIdentity.IsPooling())
{
AmqpConnectionHolder amqpConnectionHolder;
lock (_lock)
{
AmqpConnectionHolder[] amqpConnectionHolders = ResolveConnectionGroup(deviceIdentity);
amqpConnectionHolder = ResolveConnectionByHashing(amqpConnectionHolders, deviceIdentity);

amqpConnectionHolder.RemoveAmqpUnit(amqpUnit);

// If the connection holder does not have any more units, the entry needs to be nullified.
if (amqpConnectionHolder.IsEmpty())
{
int index = GetDeviceIdentityIndex(deviceIdentity, amqpConnectionHolders.Length);
amqpConnectionHolders[index] = null;
amqpConnectionHolder?.Dispose();
}
}
amqpConnectionHolder.RemoveAmqpUnit(amqpUnit);
}

if (Logging.IsEnabled)
Expand All @@ -109,7 +108,7 @@ public void RemoveAmqpUnit(AmqpUnit amqpUnit)
}
}

private AmqpConnectionHolder[] ResolveConnectionGroup(DeviceIdentity deviceIdentity)
private AmqpConnectionHolder[] ResolveConnectionGroup(IDeviceIdentity deviceIdentity)
{
if (deviceIdentity.AuthenticationModel == AuthenticationModel.SasIndividual)
{
Expand All @@ -123,25 +122,26 @@ private AmqpConnectionHolder[] ResolveConnectionGroup(DeviceIdentity deviceIdent
else
{
string scope = deviceIdentity.IotHubConnectionString.SharedAccessKeyName;
_amqpSasGroupedPool.TryGetValue(scope, out AmqpConnectionHolder[] amqpConnectionHolders);
GetAmqpSasGroupedPoolDictionary().TryGetValue(scope, out AmqpConnectionHolder[] amqpConnectionHolders);
if (amqpConnectionHolders == null)
{
amqpConnectionHolders = new AmqpConnectionHolder[deviceIdentity.AmqpTransportSettings.AmqpConnectionPoolSettings.MaxPoolSize];
_amqpSasGroupedPool.Add(scope, amqpConnectionHolders);
GetAmqpSasGroupedPoolDictionary().Add(scope, amqpConnectionHolders);
}

return amqpConnectionHolders;
}
}

private AmqpConnectionHolder ResolveConnectionByHashing(AmqpConnectionHolder[] pool, DeviceIdentity deviceIdentity)
private AmqpConnectionHolder ResolveConnectionByHashing(AmqpConnectionHolder[] pool, IDeviceIdentity deviceIdentity)
{
if (Logging.IsEnabled)
{
Logging.Enter(this, deviceIdentity, nameof(ResolveConnectionByHashing));
}

int index = Math.Abs(deviceIdentity.GetHashCode()) % pool.Length;
int index = GetDeviceIdentityIndex(deviceIdentity, pool.Length);

if (pool[index] == null)
{
pool[index] = new AmqpConnectionHolder(deviceIdentity);
Expand All @@ -154,5 +154,12 @@ private AmqpConnectionHolder ResolveConnectionByHashing(AmqpConnectionHolder[] p

return pool[index];
}

private static int GetDeviceIdentityIndex(IDeviceIdentity deviceIdentity, int poolLength)
{
return deviceIdentity == null
? throw new ArgumentNullException(nameof(deviceIdentity))
: Math.Abs(deviceIdentity.GetHashCode()) % poolLength;
}
}
}
2 changes: 1 addition & 1 deletion iothub/device/src/Transport/Amqp/AmqpTransportHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ internal AmqpTransportHandler(
{
_operationTimeout = transportSettings.OperationTimeout;
_onDesiredStatePatchListener = onDesiredStatePatchReceivedCallback;
var deviceIdentity = new DeviceIdentity(connectionString, transportSettings, context.Get<ProductInfo>(), context.Get<ClientOptions>());
IDeviceIdentity deviceIdentity = new DeviceIdentity(connectionString, transportSettings, context.Get<ProductInfo>(), context.Get<ClientOptions>());
_amqpUnit = AmqpUnitManager.GetInstance().CreateAmqpUnit(
deviceIdentity,
onMethodCallback,
Expand Down
6 changes: 3 additions & 3 deletions iothub/device/src/Transport/Amqp/AmqpUnit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace Microsoft.Azure.Devices.Client.Transport.AmqpIot
internal class AmqpUnit : IDisposable
{
// If the first argument is set to true, we are disconnecting gracefully via CloseAsync.
private readonly DeviceIdentity _deviceIdentity;
private readonly IDeviceIdentity _deviceIdentity;

private readonly Func<MethodRequestInternal, Task> _onMethodCallback;
private readonly Action<Twin, string, TwinCollection, IotHubException> _twinMessageListener;
Expand Down Expand Up @@ -51,7 +51,7 @@ internal class AmqpUnit : IDisposable
private IAmqpAuthenticationRefresher _amqpAuthenticationRefresher;

public AmqpUnit(
DeviceIdentity deviceIdentity,
IDeviceIdentity deviceIdentity,
IAmqpConnectionHolder amqpConnectionHolder,
Func<MethodRequestInternal, Task> onMethodCallback,
Action<Twin, string, TwinCollection, IotHubException> twinMessageListener,
Expand All @@ -70,7 +70,7 @@ public AmqpUnit(
Logging.Associate(this, _deviceIdentity, nameof(_deviceIdentity));
}

internal DeviceIdentity GetDeviceIdentity()
internal IDeviceIdentity GetDeviceIdentity()
{
return _deviceIdentity;
}
Expand Down
4 changes: 2 additions & 2 deletions iothub/device/src/Transport/Amqp/AmqpUnitManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ internal static AmqpUnitManager GetInstance()
}

public AmqpUnit CreateAmqpUnit(
DeviceIdentity deviceIdentity,
IDeviceIdentity deviceIdentity,
Func<MethodRequestInternal, Task> onMethodCallback,
Action<Twin, string, TwinCollection, IotHubException> twinMessageListener,
Func<string, Message, Task> onModuleMessageReceivedCallback,
Expand All @@ -47,9 +47,9 @@ public AmqpUnit CreateAmqpUnit(

public void RemoveAmqpUnit(AmqpUnit amqpUnit)
{
amqpUnit.Dispose();
IAmqpUnitManager amqpConnectionPool = ResolveConnectionPool(amqpUnit.GetDeviceIdentity().IotHubConnectionString.HostName);
amqpConnectionPool.RemoveAmqpUnit(amqpUnit);
amqpUnit.Dispose();
}

private IAmqpUnitManager ResolveConnectionPool(string host)
Expand Down
4 changes: 2 additions & 2 deletions iothub/device/src/Transport/Amqp/IAmqpConnectionHolder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ namespace Microsoft.Azure.Devices.Client.Transport.Amqp
{
internal interface IAmqpConnectionHolder : IDisposable
{
Task<AmqpIotSession> OpenSessionAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken);
Task<AmqpIotSession> OpenSessionAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken);

Task<AmqpIotConnection> EnsureConnectionAsync(CancellationToken cancellationToken);

Task<IAmqpAuthenticationRefresher> CreateRefresherAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken);
Task<IAmqpAuthenticationRefresher> CreateRefresherAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken);

void Shutdown();
}
Expand Down
2 changes: 1 addition & 1 deletion iothub/device/src/Transport/Amqp/IAmqpUnitManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace Microsoft.Azure.Devices.Client.Transport.Amqp
internal interface IAmqpUnitManager
{
AmqpUnit CreateAmqpUnit(
DeviceIdentity deviceIdentity,
IDeviceIdentity deviceIdentity,
Func<MethodRequestInternal, Task> onMethodCallback,
Action<Twin, string, TwinCollection, IotHubException> twinMessageListener,
Func<string, Message, Task> onModuleMessageReceivedCallback,
Expand Down
2 changes: 1 addition & 1 deletion iothub/device/src/Transport/AmqpIot/AmqpIotConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ internal async Task<AmqpIotSession> OpenSessionAsync(CancellationToken cancellat
}
}

internal async Task<IAmqpAuthenticationRefresher> CreateRefresherAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken)
internal async Task<IAmqpAuthenticationRefresher> CreateRefresherAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken)
{
if (_amqpConnection.IsClosing())
{
Expand Down
Loading

0 comments on commit 21e3324

Please sign in to comment.