Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(deviceClient): Fix issue with AMQP connection pool and TokenReferesher disposal. #2260

Merged
merged 5 commits into from
Jan 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions iothub/device/src/IotHubConnectionString.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,33 @@ public IotHubConnectionString(IotHubConnectionStringBuilder builder)
}
}

// This constructor is only used for unit testing.
internal IotHubConnectionString(
string ioTHubName = null,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
string ioTHubName = null,
string iotHubName = null,

string deviceId = null,
string moduleId = null,
string hostName = null,
Uri httpsEndpoint = null,
Uri amqpEndpoint = null,
string audience = null,
string sharedAccessKeyName = null,
string sharedAccessKey = null,
string sharedAccessSignature = null,
bool isUsingGateway = false)
{
IotHubName = ioTHubName;
DeviceId = deviceId;
ModuleId = moduleId;
HostName = hostName;
HttpsEndpoint = httpsEndpoint;
AmqpEndpoint = amqpEndpoint;
Audience = audience;
SharedAccessKeyName = sharedAccessKeyName;
SharedAccessKey = sharedAccessKey;
SharedAccessSignature = sharedAccessSignature;
IsUsingGateway = isUsingGateway;
}

public string IotHubName { get; private set; }

public string DeviceId { get; private set; }
Expand Down
2 changes: 1 addition & 1 deletion iothub/device/src/IotHubConnectionStringBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public sealed class IotHubConnectionStringBuilder
/// <summary>
/// Initializes a new instance of the <see cref="IotHubConnectionStringBuilder"/> class.
/// </summary>
private IotHubConnectionStringBuilder()
internal IotHubConnectionStringBuilder()
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ internal class AmqpAuthenticationRefresher : IAmqpAuthenticationRefresher, IDisp
private Task _refreshLoop;
private bool _disposed;

internal AmqpAuthenticationRefresher(DeviceIdentity deviceIdentity, AmqpIotCbsLink amqpCbsLink)
internal AmqpAuthenticationRefresher(IDeviceIdentity deviceIdentity, AmqpIotCbsLink amqpCbsLink)
{
_amqpIotCbsLink = amqpCbsLink;
_connectionString = deviceIdentity.IotHubConnectionString;
Expand Down
20 changes: 13 additions & 7 deletions iothub/device/src/Transport/Amqp/AmqpConnectionHolder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
using Microsoft.Azure.Devices.Client.Extensions;
using Microsoft.Azure.Devices.Client.Transport.AmqpIot;
using System.Collections.Generic;
using System.Linq;

namespace Microsoft.Azure.Devices.Client.Transport.Amqp
{
internal class AmqpConnectionHolder : IAmqpConnectionHolder, IAmqpUnitManager
{
private readonly DeviceIdentity _deviceIdentity;
private readonly IDeviceIdentity _deviceIdentity;
private readonly AmqpIotConnector _amqpIotConnector;
private readonly SemaphoreSlim _lock = new SemaphoreSlim(1, 1);
private readonly HashSet<AmqpUnit> _amqpUnits = new HashSet<AmqpUnit>();
Expand All @@ -23,7 +24,7 @@ internal class AmqpConnectionHolder : IAmqpConnectionHolder, IAmqpUnitManager
private IAmqpAuthenticationRefresher _amqpAuthenticationRefresher;
private volatile bool _disposed;

public AmqpConnectionHolder(DeviceIdentity deviceIdentity)
public AmqpConnectionHolder(IDeviceIdentity deviceIdentity)
{
_deviceIdentity = deviceIdentity;
_amqpIotConnector = new AmqpIotConnector(deviceIdentity.AmqpTransportSettings, deviceIdentity.IotHubConnectionString.HostName);
Expand All @@ -34,7 +35,7 @@ public AmqpConnectionHolder(DeviceIdentity deviceIdentity)
}

public AmqpUnit CreateAmqpUnit(
DeviceIdentity deviceIdentity,
IDeviceIdentity deviceIdentity,
Func<MethodRequestInternal, Task> onMethodCallback,
Action<Twin, string, TwinCollection, IotHubException> twinMessageListener,
Func<string, Message, Task> onModuleMessageReceivedCallback,
Expand Down Expand Up @@ -140,7 +141,7 @@ private void Dispose(bool disposing)
_disposed = true;
}

public async Task<IAmqpAuthenticationRefresher> CreateRefresherAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken)
public async Task<IAmqpAuthenticationRefresher> CreateRefresherAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken)
{
if (Logging.IsEnabled)
{
Expand All @@ -159,7 +160,7 @@ public async Task<IAmqpAuthenticationRefresher> CreateRefresherAsync(DeviceIdent
return amqpAuthenticator;
}

public async Task<AmqpIotSession> OpenSessionAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken)
public async Task<AmqpIotSession> OpenSessionAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken)
{
if (Logging.IsEnabled)
{
Expand Down Expand Up @@ -274,9 +275,14 @@ public void RemoveAmqpUnit(AmqpUnit amqpUnit)
}
}

internal DeviceIdentity GetDeviceIdentityOfAuthenticationProvider()
internal IDeviceIdentity GetDeviceIdentityOfAuthenticationProvider()
{
return _deviceIdentity;
}

internal bool IsEmpty()
{
return !_amqpUnits.Any();
}
}
}
}
53 changes: 30 additions & 23 deletions iothub/device/src/Transport/Amqp/AmqpConnectionPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@ internal class AmqpConnectionPool : IAmqpUnitManager
private readonly IDictionary<string, AmqpConnectionHolder[]> _amqpSasGroupedPool = new Dictionary<string, AmqpConnectionHolder[]>();
private readonly object _lock = new object();

protected virtual IDictionary<string, AmqpConnectionHolder[]> GetAmqpSasGroupedPoolDictionary()
{
return _amqpSasGroupedPool;
}

public AmqpUnit CreateAmqpUnit(
DeviceIdentity deviceIdentity,
IDeviceIdentity deviceIdentity,
Func<MethodRequestInternal, Task> onMethodCallback,
Action<Twin, string, TwinCollection, IotHubException> twinMessageListener,
Func<string, Message, Task> onModuleMessageReceivedCallback,
Expand All @@ -36,21 +41,6 @@ public AmqpUnit CreateAmqpUnit(
{
AmqpConnectionHolder[] amqpConnectionHolders = ResolveConnectionGroup(deviceIdentity);
amqpConnectionHolder = ResolveConnectionByHashing(amqpConnectionHolders, deviceIdentity);

// For group sas token authenticated devices over a multiplexed connection, the TokenRefresher
// of the first client connecting will be used for generating the group sas tokens
// and will be associated with the connection itself.
// For this reason, if the device identity of the client is not the one associated with the
// connection, the associated TokenRefresher can be safely disposed.
// Note - This does not cause any identity related issues since the group sas tokens are generated
// against the hub host as the intended audience (without the "device Id").
if (deviceIdentity.AuthenticationModel == AuthenticationModel.SasGrouped
&& !ReferenceEquals(amqpConnectionHolder.GetDeviceIdentityOfAuthenticationProvider(), deviceIdentity)
&& deviceIdentity.IotHubConnectionString?.TokenRefresher != null
&& deviceIdentity.IotHubConnectionString.TokenRefresher.DisposalWithClient)
{
deviceIdentity.IotHubConnectionString.TokenRefresher.Dispose();
azabbasi marked this conversation as resolved.
Show resolved Hide resolved
}
}

if (Logging.IsEnabled)
Expand Down Expand Up @@ -91,16 +81,25 @@ public void RemoveAmqpUnit(AmqpUnit amqpUnit)
Logging.Enter(this, amqpUnit, nameof(RemoveAmqpUnit));
}

DeviceIdentity deviceIdentity = amqpUnit.GetDeviceIdentity();
IDeviceIdentity deviceIdentity = amqpUnit.GetDeviceIdentity();
if (deviceIdentity.IsPooling())
{
AmqpConnectionHolder amqpConnectionHolder;
lock (_lock)
{
AmqpConnectionHolder[] amqpConnectionHolders = ResolveConnectionGroup(deviceIdentity);
amqpConnectionHolder = ResolveConnectionByHashing(amqpConnectionHolders, deviceIdentity);

amqpConnectionHolder.RemoveAmqpUnit(amqpUnit);

// If the connection holder does not have any more units, the entry needs to be nullified.
if (amqpConnectionHolder.IsEmpty())
{
int index = GetDeviceIdentityIndex(deviceIdentity, amqpConnectionHolders.Length);
amqpConnectionHolders[index] = null;
azabbasi marked this conversation as resolved.
Show resolved Hide resolved
amqpConnectionHolder?.Dispose();
azabbasi marked this conversation as resolved.
Show resolved Hide resolved
}
}
amqpConnectionHolder.RemoveAmqpUnit(amqpUnit);
}

if (Logging.IsEnabled)
Expand All @@ -109,7 +108,7 @@ public void RemoveAmqpUnit(AmqpUnit amqpUnit)
}
}

private AmqpConnectionHolder[] ResolveConnectionGroup(DeviceIdentity deviceIdentity)
private AmqpConnectionHolder[] ResolveConnectionGroup(IDeviceIdentity deviceIdentity)
{
if (deviceIdentity.AuthenticationModel == AuthenticationModel.SasIndividual)
{
Expand All @@ -123,25 +122,26 @@ private AmqpConnectionHolder[] ResolveConnectionGroup(DeviceIdentity deviceIdent
else
{
string scope = deviceIdentity.IotHubConnectionString.SharedAccessKeyName;
_amqpSasGroupedPool.TryGetValue(scope, out AmqpConnectionHolder[] amqpConnectionHolders);
GetAmqpSasGroupedPoolDictionary().TryGetValue(scope, out AmqpConnectionHolder[] amqpConnectionHolders);
if (amqpConnectionHolders == null)
{
amqpConnectionHolders = new AmqpConnectionHolder[deviceIdentity.AmqpTransportSettings.AmqpConnectionPoolSettings.MaxPoolSize];
_amqpSasGroupedPool.Add(scope, amqpConnectionHolders);
GetAmqpSasGroupedPoolDictionary().Add(scope, amqpConnectionHolders);
}

return amqpConnectionHolders;
}
}

private AmqpConnectionHolder ResolveConnectionByHashing(AmqpConnectionHolder[] pool, DeviceIdentity deviceIdentity)
private AmqpConnectionHolder ResolveConnectionByHashing(AmqpConnectionHolder[] pool, IDeviceIdentity deviceIdentity)
{
if (Logging.IsEnabled)
{
Logging.Enter(this, deviceIdentity, nameof(ResolveConnectionByHashing));
}

int index = Math.Abs(deviceIdentity.GetHashCode()) % pool.Length;
int index = GetDeviceIdentityIndex(deviceIdentity, pool.Length);

if (pool[index] == null)
{
pool[index] = new AmqpConnectionHolder(deviceIdentity);
Expand All @@ -154,5 +154,12 @@ private AmqpConnectionHolder ResolveConnectionByHashing(AmqpConnectionHolder[] p

return pool[index];
}

private static int GetDeviceIdentityIndex(IDeviceIdentity deviceIdentity, int poolLength)
{
return deviceIdentity == null
azabbasi marked this conversation as resolved.
Show resolved Hide resolved
? throw new ArgumentNullException(nameof(deviceIdentity))
: Math.Abs(deviceIdentity.GetHashCode()) % poolLength;
}
}
}
2 changes: 1 addition & 1 deletion iothub/device/src/Transport/Amqp/AmqpTransportHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ internal AmqpTransportHandler(
{
_operationTimeout = transportSettings.OperationTimeout;
_onDesiredStatePatchListener = onDesiredStatePatchReceivedCallback;
var deviceIdentity = new DeviceIdentity(connectionString, transportSettings, context.Get<ProductInfo>(), context.Get<ClientOptions>());
IDeviceIdentity deviceIdentity = new DeviceIdentity(connectionString, transportSettings, context.Get<ProductInfo>(), context.Get<ClientOptions>());
_amqpUnit = AmqpUnitManager.GetInstance().CreateAmqpUnit(
deviceIdentity,
onMethodCallback,
Expand Down
6 changes: 3 additions & 3 deletions iothub/device/src/Transport/Amqp/AmqpUnit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace Microsoft.Azure.Devices.Client.Transport.AmqpIot
internal class AmqpUnit : IDisposable
{
// If the first argument is set to true, we are disconnecting gracefully via CloseAsync.
private readonly DeviceIdentity _deviceIdentity;
private readonly IDeviceIdentity _deviceIdentity;

private readonly Func<MethodRequestInternal, Task> _onMethodCallback;
private readonly Action<Twin, string, TwinCollection, IotHubException> _twinMessageListener;
Expand Down Expand Up @@ -51,7 +51,7 @@ internal class AmqpUnit : IDisposable
private IAmqpAuthenticationRefresher _amqpAuthenticationRefresher;

public AmqpUnit(
DeviceIdentity deviceIdentity,
IDeviceIdentity deviceIdentity,
IAmqpConnectionHolder amqpConnectionHolder,
Func<MethodRequestInternal, Task> onMethodCallback,
Action<Twin, string, TwinCollection, IotHubException> twinMessageListener,
Expand All @@ -70,7 +70,7 @@ public AmqpUnit(
Logging.Associate(this, _deviceIdentity, nameof(_deviceIdentity));
}

internal DeviceIdentity GetDeviceIdentity()
internal IDeviceIdentity GetDeviceIdentity()
{
return _deviceIdentity;
}
Expand Down
4 changes: 2 additions & 2 deletions iothub/device/src/Transport/Amqp/AmqpUnitManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ internal static AmqpUnitManager GetInstance()
}

public AmqpUnit CreateAmqpUnit(
DeviceIdentity deviceIdentity,
IDeviceIdentity deviceIdentity,
Func<MethodRequestInternal, Task> onMethodCallback,
Action<Twin, string, TwinCollection, IotHubException> twinMessageListener,
Func<string, Message, Task> onModuleMessageReceivedCallback,
Expand All @@ -47,9 +47,9 @@ public AmqpUnit CreateAmqpUnit(

public void RemoveAmqpUnit(AmqpUnit amqpUnit)
{
amqpUnit.Dispose();
IAmqpUnitManager amqpConnectionPool = ResolveConnectionPool(amqpUnit.GetDeviceIdentity().IotHubConnectionString.HostName);
amqpConnectionPool.RemoveAmqpUnit(amqpUnit);
amqpUnit.Dispose();
azabbasi marked this conversation as resolved.
Show resolved Hide resolved
}

private IAmqpUnitManager ResolveConnectionPool(string host)
Expand Down
4 changes: 2 additions & 2 deletions iothub/device/src/Transport/Amqp/IAmqpConnectionHolder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ namespace Microsoft.Azure.Devices.Client.Transport.Amqp
{
internal interface IAmqpConnectionHolder : IDisposable
{
Task<AmqpIotSession> OpenSessionAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken);
Task<AmqpIotSession> OpenSessionAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken);

Task<AmqpIotConnection> EnsureConnectionAsync(CancellationToken cancellationToken);

Task<IAmqpAuthenticationRefresher> CreateRefresherAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken);
Task<IAmqpAuthenticationRefresher> CreateRefresherAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken);

void Shutdown();
}
Expand Down
2 changes: 1 addition & 1 deletion iothub/device/src/Transport/Amqp/IAmqpUnitManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace Microsoft.Azure.Devices.Client.Transport.Amqp
internal interface IAmqpUnitManager
{
AmqpUnit CreateAmqpUnit(
DeviceIdentity deviceIdentity,
IDeviceIdentity deviceIdentity,
Func<MethodRequestInternal, Task> onMethodCallback,
Action<Twin, string, TwinCollection, IotHubException> twinMessageListener,
Func<string, Message, Task> onModuleMessageReceivedCallback,
Expand Down
2 changes: 1 addition & 1 deletion iothub/device/src/Transport/AmqpIot/AmqpIotConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ internal async Task<AmqpIotSession> OpenSessionAsync(CancellationToken cancellat
}
}

internal async Task<IAmqpAuthenticationRefresher> CreateRefresherAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken)
internal async Task<IAmqpAuthenticationRefresher> CreateRefresherAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken)
{
if (_amqpConnection.IsClosing())
{
Expand Down
Loading