diff --git a/iothub/device/src/IotHubConnectionString.cs b/iothub/device/src/IotHubConnectionString.cs index 3d8a7349fc..b41bad43e6 100644 --- a/iothub/device/src/IotHubConnectionString.cs +++ b/iothub/device/src/IotHubConnectionString.cs @@ -92,6 +92,33 @@ public IotHubConnectionString(IotHubConnectionStringBuilder builder) } } + // This constructor is only used for unit testing. + internal IotHubConnectionString( + string ioTHubName = null, + string deviceId = null, + string moduleId = null, + string hostName = null, + Uri httpsEndpoint = null, + Uri amqpEndpoint = null, + string audience = null, + string sharedAccessKeyName = null, + string sharedAccessKey = null, + string sharedAccessSignature = null, + bool isUsingGateway = false) + { + IotHubName = ioTHubName; + DeviceId = deviceId; + ModuleId = moduleId; + HostName = hostName; + HttpsEndpoint = httpsEndpoint; + AmqpEndpoint = amqpEndpoint; + Audience = audience; + SharedAccessKeyName = sharedAccessKeyName; + SharedAccessKey = sharedAccessKey; + SharedAccessSignature = sharedAccessSignature; + IsUsingGateway = isUsingGateway; + } + public string IotHubName { get; private set; } public string DeviceId { get; private set; } diff --git a/iothub/device/src/IotHubConnectionStringBuilder.cs b/iothub/device/src/IotHubConnectionStringBuilder.cs index 8c133f9e04..181cd642f6 100644 --- a/iothub/device/src/IotHubConnectionStringBuilder.cs +++ b/iothub/device/src/IotHubConnectionStringBuilder.cs @@ -53,7 +53,7 @@ public sealed class IotHubConnectionStringBuilder /// /// Initializes a new instance of the class. /// - private IotHubConnectionStringBuilder() + internal IotHubConnectionStringBuilder() { } diff --git a/iothub/device/src/Transport/Amqp/AmqpAuthenticationRefresher.cs b/iothub/device/src/Transport/Amqp/AmqpAuthenticationRefresher.cs index 4d9a8d6cad..0b9c09b1b7 100644 --- a/iothub/device/src/Transport/Amqp/AmqpAuthenticationRefresher.cs +++ b/iothub/device/src/Transport/Amqp/AmqpAuthenticationRefresher.cs @@ -21,7 +21,7 @@ internal class AmqpAuthenticationRefresher : IAmqpAuthenticationRefresher, IDisp private Task _refreshLoop; private bool _disposed; - internal AmqpAuthenticationRefresher(DeviceIdentity deviceIdentity, AmqpIotCbsLink amqpCbsLink) + internal AmqpAuthenticationRefresher(IDeviceIdentity deviceIdentity, AmqpIotCbsLink amqpCbsLink) { _amqpIotCbsLink = amqpCbsLink; _connectionString = deviceIdentity.IotHubConnectionString; diff --git a/iothub/device/src/Transport/Amqp/AmqpConnectionHolder.cs b/iothub/device/src/Transport/Amqp/AmqpConnectionHolder.cs index 1e477ed796..46d1b245dc 100644 --- a/iothub/device/src/Transport/Amqp/AmqpConnectionHolder.cs +++ b/iothub/device/src/Transport/Amqp/AmqpConnectionHolder.cs @@ -9,12 +9,13 @@ using Microsoft.Azure.Devices.Client.Extensions; using Microsoft.Azure.Devices.Client.Transport.AmqpIot; using System.Collections.Generic; +using System.Linq; namespace Microsoft.Azure.Devices.Client.Transport.Amqp { internal class AmqpConnectionHolder : IAmqpConnectionHolder, IAmqpUnitManager { - private readonly DeviceIdentity _deviceIdentity; + private readonly IDeviceIdentity _deviceIdentity; private readonly AmqpIotConnector _amqpIotConnector; private readonly SemaphoreSlim _lock = new SemaphoreSlim(1, 1); private readonly HashSet _amqpUnits = new HashSet(); @@ -23,7 +24,7 @@ internal class AmqpConnectionHolder : IAmqpConnectionHolder, IAmqpUnitManager private IAmqpAuthenticationRefresher _amqpAuthenticationRefresher; private volatile bool _disposed; - public AmqpConnectionHolder(DeviceIdentity deviceIdentity) + public AmqpConnectionHolder(IDeviceIdentity deviceIdentity) { _deviceIdentity = deviceIdentity; _amqpIotConnector = new AmqpIotConnector(deviceIdentity.AmqpTransportSettings, deviceIdentity.IotHubConnectionString.HostName); @@ -34,7 +35,7 @@ public AmqpConnectionHolder(DeviceIdentity deviceIdentity) } public AmqpUnit CreateAmqpUnit( - DeviceIdentity deviceIdentity, + IDeviceIdentity deviceIdentity, Func onMethodCallback, Action twinMessageListener, Func onModuleMessageReceivedCallback, @@ -140,7 +141,7 @@ private void Dispose(bool disposing) _disposed = true; } - public async Task CreateRefresherAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken) + public async Task CreateRefresherAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken) { if (Logging.IsEnabled) { @@ -159,7 +160,7 @@ public async Task CreateRefresherAsync(DeviceIdent return amqpAuthenticator; } - public async Task OpenSessionAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken) + public async Task OpenSessionAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken) { if (Logging.IsEnabled) { @@ -274,9 +275,14 @@ public void RemoveAmqpUnit(AmqpUnit amqpUnit) } } - internal DeviceIdentity GetDeviceIdentityOfAuthenticationProvider() + internal IDeviceIdentity GetDeviceIdentityOfAuthenticationProvider() { return _deviceIdentity; } + + internal bool IsEmpty() + { + return !_amqpUnits.Any(); + } } -} \ No newline at end of file +} diff --git a/iothub/device/src/Transport/Amqp/AmqpConnectionPool.cs b/iothub/device/src/Transport/Amqp/AmqpConnectionPool.cs index caffb063b4..dfcc9327d1 100644 --- a/iothub/device/src/Transport/Amqp/AmqpConnectionPool.cs +++ b/iothub/device/src/Transport/Amqp/AmqpConnectionPool.cs @@ -16,8 +16,13 @@ internal class AmqpConnectionPool : IAmqpUnitManager private readonly IDictionary _amqpSasGroupedPool = new Dictionary(); private readonly object _lock = new object(); + protected virtual IDictionary GetAmqpSasGroupedPoolDictionary() + { + return _amqpSasGroupedPool; + } + public AmqpUnit CreateAmqpUnit( - DeviceIdentity deviceIdentity, + IDeviceIdentity deviceIdentity, Func onMethodCallback, Action twinMessageListener, Func onModuleMessageReceivedCallback, @@ -36,21 +41,6 @@ public AmqpUnit CreateAmqpUnit( { AmqpConnectionHolder[] amqpConnectionHolders = ResolveConnectionGroup(deviceIdentity); amqpConnectionHolder = ResolveConnectionByHashing(amqpConnectionHolders, deviceIdentity); - - // For group sas token authenticated devices over a multiplexed connection, the TokenRefresher - // of the first client connecting will be used for generating the group sas tokens - // and will be associated with the connection itself. - // For this reason, if the device identity of the client is not the one associated with the - // connection, the associated TokenRefresher can be safely disposed. - // Note - This does not cause any identity related issues since the group sas tokens are generated - // against the hub host as the intended audience (without the "device Id"). - if (deviceIdentity.AuthenticationModel == AuthenticationModel.SasGrouped - && !ReferenceEquals(amqpConnectionHolder.GetDeviceIdentityOfAuthenticationProvider(), deviceIdentity) - && deviceIdentity.IotHubConnectionString?.TokenRefresher != null - && deviceIdentity.IotHubConnectionString.TokenRefresher.DisposalWithClient) - { - deviceIdentity.IotHubConnectionString.TokenRefresher.Dispose(); - } } if (Logging.IsEnabled) @@ -91,7 +81,7 @@ public void RemoveAmqpUnit(AmqpUnit amqpUnit) Logging.Enter(this, amqpUnit, nameof(RemoveAmqpUnit)); } - DeviceIdentity deviceIdentity = amqpUnit.GetDeviceIdentity(); + IDeviceIdentity deviceIdentity = amqpUnit.GetDeviceIdentity(); if (deviceIdentity.IsPooling()) { AmqpConnectionHolder amqpConnectionHolder; @@ -99,8 +89,17 @@ public void RemoveAmqpUnit(AmqpUnit amqpUnit) { AmqpConnectionHolder[] amqpConnectionHolders = ResolveConnectionGroup(deviceIdentity); amqpConnectionHolder = ResolveConnectionByHashing(amqpConnectionHolders, deviceIdentity); + + amqpConnectionHolder.RemoveAmqpUnit(amqpUnit); + + // If the connection holder does not have any more units, the entry needs to be nullified. + if (amqpConnectionHolder.IsEmpty()) + { + int index = GetDeviceIdentityIndex(deviceIdentity, amqpConnectionHolders.Length); + amqpConnectionHolders[index] = null; + amqpConnectionHolder?.Dispose(); + } } - amqpConnectionHolder.RemoveAmqpUnit(amqpUnit); } if (Logging.IsEnabled) @@ -109,7 +108,7 @@ public void RemoveAmqpUnit(AmqpUnit amqpUnit) } } - private AmqpConnectionHolder[] ResolveConnectionGroup(DeviceIdentity deviceIdentity) + private AmqpConnectionHolder[] ResolveConnectionGroup(IDeviceIdentity deviceIdentity) { if (deviceIdentity.AuthenticationModel == AuthenticationModel.SasIndividual) { @@ -123,25 +122,26 @@ private AmqpConnectionHolder[] ResolveConnectionGroup(DeviceIdentity deviceIdent else { string scope = deviceIdentity.IotHubConnectionString.SharedAccessKeyName; - _amqpSasGroupedPool.TryGetValue(scope, out AmqpConnectionHolder[] amqpConnectionHolders); + GetAmqpSasGroupedPoolDictionary().TryGetValue(scope, out AmqpConnectionHolder[] amqpConnectionHolders); if (amqpConnectionHolders == null) { amqpConnectionHolders = new AmqpConnectionHolder[deviceIdentity.AmqpTransportSettings.AmqpConnectionPoolSettings.MaxPoolSize]; - _amqpSasGroupedPool.Add(scope, amqpConnectionHolders); + GetAmqpSasGroupedPoolDictionary().Add(scope, amqpConnectionHolders); } return amqpConnectionHolders; } } - private AmqpConnectionHolder ResolveConnectionByHashing(AmqpConnectionHolder[] pool, DeviceIdentity deviceIdentity) + private AmqpConnectionHolder ResolveConnectionByHashing(AmqpConnectionHolder[] pool, IDeviceIdentity deviceIdentity) { if (Logging.IsEnabled) { Logging.Enter(this, deviceIdentity, nameof(ResolveConnectionByHashing)); } - int index = Math.Abs(deviceIdentity.GetHashCode()) % pool.Length; + int index = GetDeviceIdentityIndex(deviceIdentity, pool.Length); + if (pool[index] == null) { pool[index] = new AmqpConnectionHolder(deviceIdentity); @@ -154,5 +154,12 @@ private AmqpConnectionHolder ResolveConnectionByHashing(AmqpConnectionHolder[] p return pool[index]; } + + private static int GetDeviceIdentityIndex(IDeviceIdentity deviceIdentity, int poolLength) + { + return deviceIdentity == null + ? throw new ArgumentNullException(nameof(deviceIdentity)) + : Math.Abs(deviceIdentity.GetHashCode()) % poolLength; + } } } diff --git a/iothub/device/src/Transport/Amqp/AmqpTransportHandler.cs b/iothub/device/src/Transport/Amqp/AmqpTransportHandler.cs index df4ab761d6..116eb0cbf9 100644 --- a/iothub/device/src/Transport/Amqp/AmqpTransportHandler.cs +++ b/iothub/device/src/Transport/Amqp/AmqpTransportHandler.cs @@ -49,7 +49,7 @@ internal AmqpTransportHandler( { _operationTimeout = transportSettings.OperationTimeout; _onDesiredStatePatchListener = onDesiredStatePatchReceivedCallback; - var deviceIdentity = new DeviceIdentity(connectionString, transportSettings, context.Get(), context.Get()); + IDeviceIdentity deviceIdentity = new DeviceIdentity(connectionString, transportSettings, context.Get(), context.Get()); _amqpUnit = AmqpUnitManager.GetInstance().CreateAmqpUnit( deviceIdentity, onMethodCallback, diff --git a/iothub/device/src/Transport/Amqp/AmqpUnit.cs b/iothub/device/src/Transport/Amqp/AmqpUnit.cs index 7487347cc2..04f809dd8c 100644 --- a/iothub/device/src/Transport/Amqp/AmqpUnit.cs +++ b/iothub/device/src/Transport/Amqp/AmqpUnit.cs @@ -16,7 +16,7 @@ namespace Microsoft.Azure.Devices.Client.Transport.AmqpIot internal class AmqpUnit : IDisposable { // If the first argument is set to true, we are disconnecting gracefully via CloseAsync. - private readonly DeviceIdentity _deviceIdentity; + private readonly IDeviceIdentity _deviceIdentity; private readonly Func _onMethodCallback; private readonly Action _twinMessageListener; @@ -51,7 +51,7 @@ internal class AmqpUnit : IDisposable private IAmqpAuthenticationRefresher _amqpAuthenticationRefresher; public AmqpUnit( - DeviceIdentity deviceIdentity, + IDeviceIdentity deviceIdentity, IAmqpConnectionHolder amqpConnectionHolder, Func onMethodCallback, Action twinMessageListener, @@ -70,7 +70,7 @@ public AmqpUnit( Logging.Associate(this, _deviceIdentity, nameof(_deviceIdentity)); } - internal DeviceIdentity GetDeviceIdentity() + internal IDeviceIdentity GetDeviceIdentity() { return _deviceIdentity; } diff --git a/iothub/device/src/Transport/Amqp/AmqpUnitManager.cs b/iothub/device/src/Transport/Amqp/AmqpUnitManager.cs index f9d4723f84..f83cdda1f3 100644 --- a/iothub/device/src/Transport/Amqp/AmqpUnitManager.cs +++ b/iothub/device/src/Transport/Amqp/AmqpUnitManager.cs @@ -28,7 +28,7 @@ internal static AmqpUnitManager GetInstance() } public AmqpUnit CreateAmqpUnit( - DeviceIdentity deviceIdentity, + IDeviceIdentity deviceIdentity, Func onMethodCallback, Action twinMessageListener, Func onModuleMessageReceivedCallback, @@ -47,9 +47,9 @@ public AmqpUnit CreateAmqpUnit( public void RemoveAmqpUnit(AmqpUnit amqpUnit) { - amqpUnit.Dispose(); IAmqpUnitManager amqpConnectionPool = ResolveConnectionPool(amqpUnit.GetDeviceIdentity().IotHubConnectionString.HostName); amqpConnectionPool.RemoveAmqpUnit(amqpUnit); + amqpUnit.Dispose(); } private IAmqpUnitManager ResolveConnectionPool(string host) diff --git a/iothub/device/src/Transport/Amqp/IAmqpConnectionHolder.cs b/iothub/device/src/Transport/Amqp/IAmqpConnectionHolder.cs index 1fd4fa5fa4..5def753284 100644 --- a/iothub/device/src/Transport/Amqp/IAmqpConnectionHolder.cs +++ b/iothub/device/src/Transport/Amqp/IAmqpConnectionHolder.cs @@ -10,11 +10,11 @@ namespace Microsoft.Azure.Devices.Client.Transport.Amqp { internal interface IAmqpConnectionHolder : IDisposable { - Task OpenSessionAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken); + Task OpenSessionAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken); Task EnsureConnectionAsync(CancellationToken cancellationToken); - Task CreateRefresherAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken); + Task CreateRefresherAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken); void Shutdown(); } diff --git a/iothub/device/src/Transport/Amqp/IAmqpUnitManager.cs b/iothub/device/src/Transport/Amqp/IAmqpUnitManager.cs index 509b5da002..46dbb412b5 100644 --- a/iothub/device/src/Transport/Amqp/IAmqpUnitManager.cs +++ b/iothub/device/src/Transport/Amqp/IAmqpUnitManager.cs @@ -12,7 +12,7 @@ namespace Microsoft.Azure.Devices.Client.Transport.Amqp internal interface IAmqpUnitManager { AmqpUnit CreateAmqpUnit( - DeviceIdentity deviceIdentity, + IDeviceIdentity deviceIdentity, Func onMethodCallback, Action twinMessageListener, Func onModuleMessageReceivedCallback, diff --git a/iothub/device/src/Transport/AmqpIot/AmqpIotConnection.cs b/iothub/device/src/Transport/AmqpIot/AmqpIotConnection.cs index 36004aeb1d..b97c0fc8ce 100644 --- a/iothub/device/src/Transport/AmqpIot/AmqpIotConnection.cs +++ b/iothub/device/src/Transport/AmqpIot/AmqpIotConnection.cs @@ -84,7 +84,7 @@ internal async Task OpenSessionAsync(CancellationToken cancellat } } - internal async Task CreateRefresherAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken) + internal async Task CreateRefresherAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken) { if (_amqpConnection.IsClosing()) { diff --git a/iothub/device/src/Transport/AmqpIot/AmqpIotSession.cs b/iothub/device/src/Transport/AmqpIot/AmqpIotSession.cs index cd6c6f26c1..e03bfa93bc 100644 --- a/iothub/device/src/Transport/AmqpIot/AmqpIotSession.cs +++ b/iothub/device/src/Transport/AmqpIot/AmqpIotSession.cs @@ -58,7 +58,7 @@ internal bool IsClosing() #region Telemetry links internal async Task OpenTelemetrySenderLinkAsync( - DeviceIdentity deviceIdentity, + IDeviceIdentity deviceIdentity, CancellationToken cancellationToken) { return await OpenSendingAmqpLinkAsync( @@ -75,7 +75,7 @@ internal async Task OpenTelemetrySenderLinkAsync( } internal async Task OpenMessageReceiverLinkAsync( - DeviceIdentity deviceIdentity, + IDeviceIdentity deviceIdentity, CancellationToken cancellationToken) { return await OpenReceivingAmqpLinkAsync( @@ -96,7 +96,7 @@ internal async Task OpenMessageReceiverLinkAsync( #region EventLink internal async Task OpenEventsReceiverLinkAsync( - DeviceIdentity deviceIdentity, + IDeviceIdentity deviceIdentity, CancellationToken cancellationToken) { return await OpenReceivingAmqpLinkAsync( @@ -117,7 +117,7 @@ internal async Task OpenEventsReceiverLinkAsync( #region MethodLink internal async Task OpenMethodsSenderLinkAsync( - DeviceIdentity deviceIdentity, + IDeviceIdentity deviceIdentity, string correlationIdSuffix, CancellationToken cancellationToken) { @@ -135,7 +135,7 @@ internal async Task OpenMethodsSenderLinkAsync( } internal async Task OpenMethodsReceiverLinkAsync( - DeviceIdentity deviceIdentity, + IDeviceIdentity deviceIdentity, string correlationIdSuffix, CancellationToken cancellationToken) { @@ -157,7 +157,7 @@ internal async Task OpenMethodsReceiverLinkAsync( #region TwinLink internal async Task OpenTwinReceiverLinkAsync( - DeviceIdentity deviceIdentity, + IDeviceIdentity deviceIdentity, string correlationIdSuffix, CancellationToken cancellationToken) { @@ -175,7 +175,7 @@ internal async Task OpenTwinReceiverLinkAsync( } internal async Task OpenTwinSenderLinkAsync( - DeviceIdentity deviceIdentity, + IDeviceIdentity deviceIdentity, string correlationIdSuffix, CancellationToken cancellationToken) { @@ -197,7 +197,7 @@ internal async Task OpenTwinSenderLinkAsync( #region Common link handling private static async Task OpenSendingAmqpLinkAsync( - DeviceIdentity deviceIdentity, + IDeviceIdentity deviceIdentity, AmqpSession amqpSession, byte? senderSettleMode, byte? receiverSettleMode, @@ -278,7 +278,7 @@ private static async Task OpenSendingAmqpLinkAsync( } private static async Task OpenReceivingAmqpLinkAsync( - DeviceIdentity deviceIdentity, + IDeviceIdentity deviceIdentity, AmqpSession amqpSession, byte? senderSettleMode, byte? receiverSettleMode, @@ -353,7 +353,7 @@ private static async Task OpenReceivingAmqpLinkAsync( } } - private static string BuildLinkAddress(DeviceIdentity deviceIdentity, string deviceTemplate, string moduleTemplate) + private static string BuildLinkAddress(IDeviceIdentity deviceIdentity, string deviceTemplate, string moduleTemplate) { string path = string.IsNullOrEmpty(deviceIdentity.IotHubConnectionString.ModuleId) ? string.Format( diff --git a/iothub/device/src/Transport/DeviceIdentity.cs b/iothub/device/src/Transport/DeviceIdentity.cs index dc1505eac6..6b13877ca4 100644 --- a/iothub/device/src/Transport/DeviceIdentity.cs +++ b/iothub/device/src/Transport/DeviceIdentity.cs @@ -12,16 +12,20 @@ namespace Microsoft.Azure.Devices.Client.Transport /// - connection string /// - transport settings /// - internal class DeviceIdentity + internal class DeviceIdentity : IDeviceIdentity { - internal IotHubConnectionString IotHubConnectionString { get; } - internal AmqpTransportSettings AmqpTransportSettings { get; } - internal ProductInfo ProductInfo { get; } - internal AuthenticationModel AuthenticationModel { get; } - internal string Audience { get; } - internal ClientOptions Options { get; } + public IotHubConnectionString IotHubConnectionString { get; } + public AmqpTransportSettings AmqpTransportSettings { get; } + public ProductInfo ProductInfo { get; } + public AuthenticationModel AuthenticationModel { get; } + public string Audience { get; } + public ClientOptions Options { get; } - internal DeviceIdentity(IotHubConnectionString iotHubConnectionString, AmqpTransportSettings amqpTransportSettings, ProductInfo productInfo, ClientOptions options) + internal DeviceIdentity( + IotHubConnectionString iotHubConnectionString, + AmqpTransportSettings amqpTransportSettings, + ProductInfo productInfo, + ClientOptions options) { IotHubConnectionString = iotHubConnectionString; AmqpTransportSettings = amqpTransportSettings; @@ -31,14 +35,9 @@ internal DeviceIdentity(IotHubConnectionString iotHubConnectionString, AmqpTrans if (amqpTransportSettings.ClientCertificate == null) { Audience = CreateAudience(IotHubConnectionString); - if (iotHubConnectionString.SharedAccessKeyName == null) - { - AuthenticationModel = AuthenticationModel.SasIndividual; - } - else - { - AuthenticationModel = AuthenticationModel.SasGrouped; - } + AuthenticationModel = iotHubConnectionString.SharedAccessKeyName == null + ? AuthenticationModel.SasIndividual + : AuthenticationModel.SasGrouped; } else { @@ -61,7 +60,7 @@ private static string CreateAudience(IotHubConnectionString connectionString) } } - internal bool IsPooling() + public bool IsPooling() { return (AuthenticationModel != AuthenticationModel.X509) && (AmqpTransportSettings?.AmqpConnectionPoolSettings?.Pooling ?? false); } diff --git a/iothub/device/src/Transport/IDeviceClientEndpointIdentityFactory.cs b/iothub/device/src/Transport/IDeviceClientEndpointIdentityFactory.cs deleted file mode 100644 index b0d61a6686..0000000000 --- a/iothub/device/src/Transport/IDeviceClientEndpointIdentityFactory.cs +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections.Generic; -using System.Text; -using System.IO; - -namespace Microsoft.Azure.Devices.Client.Transport -{ - /// - /// Factory interface to create DeviceClientEndpointIdentity objects for Amqp transport layer - /// - internal interface IDeviceClientEndpointIdentityFactory - { - DeviceIdentity Create(IotHubConnectionString iotHubConnectionString, AmqpTransportSettings amqpTransportSettings, ProductInfo productInfo); - } -} diff --git a/iothub/device/src/Transport/IDeviceIdentity.cs b/iothub/device/src/Transport/IDeviceIdentity.cs new file mode 100644 index 0000000000..f03e181959 --- /dev/null +++ b/iothub/device/src/Transport/IDeviceIdentity.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.Azure.Devices.Client.Transport +{ + internal interface IDeviceIdentity + { + AuthenticationModel AuthenticationModel { get; } + AmqpTransportSettings AmqpTransportSettings { get; } + IotHubConnectionString IotHubConnectionString { get; } + ProductInfo ProductInfo { get; } + ClientOptions Options { get; } + string Audience { get; } + + bool IsPooling(); + } +} diff --git a/iothub/device/tests/Amqp/AmqpConnectionPoolTests.cs b/iothub/device/tests/Amqp/AmqpConnectionPoolTests.cs new file mode 100644 index 0000000000..506aa2ff20 --- /dev/null +++ b/iothub/device/tests/Amqp/AmqpConnectionPoolTests.cs @@ -0,0 +1,74 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using FluentAssertions; +using Microsoft.Azure.Devices.Client.Transport; +using Microsoft.Azure.Devices.Client.Transport.Amqp; +using Microsoft.Azure.Devices.Client.Transport.AmqpIot; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; + +namespace Microsoft.Azure.Devices.Client.Tests.Amqp +{ + [TestClass] + public class AmqpConnectionPoolTests + { + internal class AmqpConnectionPoolTest : AmqpConnectionPool + { + private readonly IDictionary _dictionaryToUse; + + public AmqpConnectionPoolTest(IDictionary dictionaryToUse) + { + _dictionaryToUse = dictionaryToUse; + } + + protected override IDictionary GetAmqpSasGroupedPoolDictionary() + { + return _dictionaryToUse; + } + } + + [TestMethod] + public void AmqpConnectionPool_Add_Remove_ConnectionHolderIsRemoved() + { + string sharedAccessKeyName = "HubOwner"; + uint poolSize = 10; + IDeviceIdentity testDevice = CreatePooledSasGroupedDeviceIdentity(sharedAccessKeyName, poolSize); + IDictionary injectedDictionary = new Dictionary(); + + AmqpConnectionPoolTest pool = new AmqpConnectionPoolTest(injectedDictionary); + + AmqpUnit addedUnit = pool.CreateAmqpUnit(testDevice, null, null, null, null, null); + + injectedDictionary[sharedAccessKeyName].Count().Should().Be((int)poolSize); + + pool.RemoveAmqpUnit(addedUnit); + + foreach (object item in injectedDictionary[sharedAccessKeyName]) + { + item.Should().BeNull(); + } + } + + private IDeviceIdentity CreatePooledSasGroupedDeviceIdentity(string sharedAccessKeyName, uint poolSize) + { + Mock deviceIdentity = new Mock(); + + deviceIdentity.Setup(m => m.IsPooling()).Returns(true); + deviceIdentity.Setup(m => m.AuthenticationModel).Returns(AuthenticationModel.SasGrouped); + deviceIdentity.Setup(m => m.IotHubConnectionString).Returns(new IotHubConnectionString(sharedAccessKeyName: sharedAccessKeyName)); + deviceIdentity.Setup(m => m.AmqpTransportSettings).Returns(new AmqpTransportSettings(TransportType.Amqp_Tcp_Only) + { + AmqpConnectionPoolSettings = new AmqpConnectionPoolSettings() + { + Pooling = true, + MaxPoolSize = poolSize, + } + }); + + return deviceIdentity.Object; + } + } +}