diff --git a/iothub/device/src/IotHubConnectionString.cs b/iothub/device/src/IotHubConnectionString.cs
index 3d8a7349fc..b41bad43e6 100644
--- a/iothub/device/src/IotHubConnectionString.cs
+++ b/iothub/device/src/IotHubConnectionString.cs
@@ -92,6 +92,33 @@ public IotHubConnectionString(IotHubConnectionStringBuilder builder)
}
}
+ // This constructor is only used for unit testing.
+ internal IotHubConnectionString(
+ string ioTHubName = null,
+ string deviceId = null,
+ string moduleId = null,
+ string hostName = null,
+ Uri httpsEndpoint = null,
+ Uri amqpEndpoint = null,
+ string audience = null,
+ string sharedAccessKeyName = null,
+ string sharedAccessKey = null,
+ string sharedAccessSignature = null,
+ bool isUsingGateway = false)
+ {
+ IotHubName = ioTHubName;
+ DeviceId = deviceId;
+ ModuleId = moduleId;
+ HostName = hostName;
+ HttpsEndpoint = httpsEndpoint;
+ AmqpEndpoint = amqpEndpoint;
+ Audience = audience;
+ SharedAccessKeyName = sharedAccessKeyName;
+ SharedAccessKey = sharedAccessKey;
+ SharedAccessSignature = sharedAccessSignature;
+ IsUsingGateway = isUsingGateway;
+ }
+
public string IotHubName { get; private set; }
public string DeviceId { get; private set; }
diff --git a/iothub/device/src/IotHubConnectionStringBuilder.cs b/iothub/device/src/IotHubConnectionStringBuilder.cs
index 8c133f9e04..181cd642f6 100644
--- a/iothub/device/src/IotHubConnectionStringBuilder.cs
+++ b/iothub/device/src/IotHubConnectionStringBuilder.cs
@@ -53,7 +53,7 @@ public sealed class IotHubConnectionStringBuilder
///
/// Initializes a new instance of the class.
///
- private IotHubConnectionStringBuilder()
+ internal IotHubConnectionStringBuilder()
{
}
diff --git a/iothub/device/src/Transport/Amqp/AmqpAuthenticationRefresher.cs b/iothub/device/src/Transport/Amqp/AmqpAuthenticationRefresher.cs
index 4d9a8d6cad..0b9c09b1b7 100644
--- a/iothub/device/src/Transport/Amqp/AmqpAuthenticationRefresher.cs
+++ b/iothub/device/src/Transport/Amqp/AmqpAuthenticationRefresher.cs
@@ -21,7 +21,7 @@ internal class AmqpAuthenticationRefresher : IAmqpAuthenticationRefresher, IDisp
private Task _refreshLoop;
private bool _disposed;
- internal AmqpAuthenticationRefresher(DeviceIdentity deviceIdentity, AmqpIotCbsLink amqpCbsLink)
+ internal AmqpAuthenticationRefresher(IDeviceIdentity deviceIdentity, AmqpIotCbsLink amqpCbsLink)
{
_amqpIotCbsLink = amqpCbsLink;
_connectionString = deviceIdentity.IotHubConnectionString;
diff --git a/iothub/device/src/Transport/Amqp/AmqpConnectionHolder.cs b/iothub/device/src/Transport/Amqp/AmqpConnectionHolder.cs
index 1e477ed796..46d1b245dc 100644
--- a/iothub/device/src/Transport/Amqp/AmqpConnectionHolder.cs
+++ b/iothub/device/src/Transport/Amqp/AmqpConnectionHolder.cs
@@ -9,12 +9,13 @@
using Microsoft.Azure.Devices.Client.Extensions;
using Microsoft.Azure.Devices.Client.Transport.AmqpIot;
using System.Collections.Generic;
+using System.Linq;
namespace Microsoft.Azure.Devices.Client.Transport.Amqp
{
internal class AmqpConnectionHolder : IAmqpConnectionHolder, IAmqpUnitManager
{
- private readonly DeviceIdentity _deviceIdentity;
+ private readonly IDeviceIdentity _deviceIdentity;
private readonly AmqpIotConnector _amqpIotConnector;
private readonly SemaphoreSlim _lock = new SemaphoreSlim(1, 1);
private readonly HashSet _amqpUnits = new HashSet();
@@ -23,7 +24,7 @@ internal class AmqpConnectionHolder : IAmqpConnectionHolder, IAmqpUnitManager
private IAmqpAuthenticationRefresher _amqpAuthenticationRefresher;
private volatile bool _disposed;
- public AmqpConnectionHolder(DeviceIdentity deviceIdentity)
+ public AmqpConnectionHolder(IDeviceIdentity deviceIdentity)
{
_deviceIdentity = deviceIdentity;
_amqpIotConnector = new AmqpIotConnector(deviceIdentity.AmqpTransportSettings, deviceIdentity.IotHubConnectionString.HostName);
@@ -34,7 +35,7 @@ public AmqpConnectionHolder(DeviceIdentity deviceIdentity)
}
public AmqpUnit CreateAmqpUnit(
- DeviceIdentity deviceIdentity,
+ IDeviceIdentity deviceIdentity,
Func onMethodCallback,
Action twinMessageListener,
Func onModuleMessageReceivedCallback,
@@ -140,7 +141,7 @@ private void Dispose(bool disposing)
_disposed = true;
}
- public async Task CreateRefresherAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken)
+ public async Task CreateRefresherAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken)
{
if (Logging.IsEnabled)
{
@@ -159,7 +160,7 @@ public async Task CreateRefresherAsync(DeviceIdent
return amqpAuthenticator;
}
- public async Task OpenSessionAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken)
+ public async Task OpenSessionAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken)
{
if (Logging.IsEnabled)
{
@@ -274,9 +275,14 @@ public void RemoveAmqpUnit(AmqpUnit amqpUnit)
}
}
- internal DeviceIdentity GetDeviceIdentityOfAuthenticationProvider()
+ internal IDeviceIdentity GetDeviceIdentityOfAuthenticationProvider()
{
return _deviceIdentity;
}
+
+ internal bool IsEmpty()
+ {
+ return !_amqpUnits.Any();
+ }
}
-}
\ No newline at end of file
+}
diff --git a/iothub/device/src/Transport/Amqp/AmqpConnectionPool.cs b/iothub/device/src/Transport/Amqp/AmqpConnectionPool.cs
index caffb063b4..dfcc9327d1 100644
--- a/iothub/device/src/Transport/Amqp/AmqpConnectionPool.cs
+++ b/iothub/device/src/Transport/Amqp/AmqpConnectionPool.cs
@@ -16,8 +16,13 @@ internal class AmqpConnectionPool : IAmqpUnitManager
private readonly IDictionary _amqpSasGroupedPool = new Dictionary();
private readonly object _lock = new object();
+ protected virtual IDictionary GetAmqpSasGroupedPoolDictionary()
+ {
+ return _amqpSasGroupedPool;
+ }
+
public AmqpUnit CreateAmqpUnit(
- DeviceIdentity deviceIdentity,
+ IDeviceIdentity deviceIdentity,
Func onMethodCallback,
Action twinMessageListener,
Func onModuleMessageReceivedCallback,
@@ -36,21 +41,6 @@ public AmqpUnit CreateAmqpUnit(
{
AmqpConnectionHolder[] amqpConnectionHolders = ResolveConnectionGroup(deviceIdentity);
amqpConnectionHolder = ResolveConnectionByHashing(amqpConnectionHolders, deviceIdentity);
-
- // For group sas token authenticated devices over a multiplexed connection, the TokenRefresher
- // of the first client connecting will be used for generating the group sas tokens
- // and will be associated with the connection itself.
- // For this reason, if the device identity of the client is not the one associated with the
- // connection, the associated TokenRefresher can be safely disposed.
- // Note - This does not cause any identity related issues since the group sas tokens are generated
- // against the hub host as the intended audience (without the "device Id").
- if (deviceIdentity.AuthenticationModel == AuthenticationModel.SasGrouped
- && !ReferenceEquals(amqpConnectionHolder.GetDeviceIdentityOfAuthenticationProvider(), deviceIdentity)
- && deviceIdentity.IotHubConnectionString?.TokenRefresher != null
- && deviceIdentity.IotHubConnectionString.TokenRefresher.DisposalWithClient)
- {
- deviceIdentity.IotHubConnectionString.TokenRefresher.Dispose();
- }
}
if (Logging.IsEnabled)
@@ -91,7 +81,7 @@ public void RemoveAmqpUnit(AmqpUnit amqpUnit)
Logging.Enter(this, amqpUnit, nameof(RemoveAmqpUnit));
}
- DeviceIdentity deviceIdentity = amqpUnit.GetDeviceIdentity();
+ IDeviceIdentity deviceIdentity = amqpUnit.GetDeviceIdentity();
if (deviceIdentity.IsPooling())
{
AmqpConnectionHolder amqpConnectionHolder;
@@ -99,8 +89,17 @@ public void RemoveAmqpUnit(AmqpUnit amqpUnit)
{
AmqpConnectionHolder[] amqpConnectionHolders = ResolveConnectionGroup(deviceIdentity);
amqpConnectionHolder = ResolveConnectionByHashing(amqpConnectionHolders, deviceIdentity);
+
+ amqpConnectionHolder.RemoveAmqpUnit(amqpUnit);
+
+ // If the connection holder does not have any more units, the entry needs to be nullified.
+ if (amqpConnectionHolder.IsEmpty())
+ {
+ int index = GetDeviceIdentityIndex(deviceIdentity, amqpConnectionHolders.Length);
+ amqpConnectionHolders[index] = null;
+ amqpConnectionHolder?.Dispose();
+ }
}
- amqpConnectionHolder.RemoveAmqpUnit(amqpUnit);
}
if (Logging.IsEnabled)
@@ -109,7 +108,7 @@ public void RemoveAmqpUnit(AmqpUnit amqpUnit)
}
}
- private AmqpConnectionHolder[] ResolveConnectionGroup(DeviceIdentity deviceIdentity)
+ private AmqpConnectionHolder[] ResolveConnectionGroup(IDeviceIdentity deviceIdentity)
{
if (deviceIdentity.AuthenticationModel == AuthenticationModel.SasIndividual)
{
@@ -123,25 +122,26 @@ private AmqpConnectionHolder[] ResolveConnectionGroup(DeviceIdentity deviceIdent
else
{
string scope = deviceIdentity.IotHubConnectionString.SharedAccessKeyName;
- _amqpSasGroupedPool.TryGetValue(scope, out AmqpConnectionHolder[] amqpConnectionHolders);
+ GetAmqpSasGroupedPoolDictionary().TryGetValue(scope, out AmqpConnectionHolder[] amqpConnectionHolders);
if (amqpConnectionHolders == null)
{
amqpConnectionHolders = new AmqpConnectionHolder[deviceIdentity.AmqpTransportSettings.AmqpConnectionPoolSettings.MaxPoolSize];
- _amqpSasGroupedPool.Add(scope, amqpConnectionHolders);
+ GetAmqpSasGroupedPoolDictionary().Add(scope, amqpConnectionHolders);
}
return amqpConnectionHolders;
}
}
- private AmqpConnectionHolder ResolveConnectionByHashing(AmqpConnectionHolder[] pool, DeviceIdentity deviceIdentity)
+ private AmqpConnectionHolder ResolveConnectionByHashing(AmqpConnectionHolder[] pool, IDeviceIdentity deviceIdentity)
{
if (Logging.IsEnabled)
{
Logging.Enter(this, deviceIdentity, nameof(ResolveConnectionByHashing));
}
- int index = Math.Abs(deviceIdentity.GetHashCode()) % pool.Length;
+ int index = GetDeviceIdentityIndex(deviceIdentity, pool.Length);
+
if (pool[index] == null)
{
pool[index] = new AmqpConnectionHolder(deviceIdentity);
@@ -154,5 +154,12 @@ private AmqpConnectionHolder ResolveConnectionByHashing(AmqpConnectionHolder[] p
return pool[index];
}
+
+ private static int GetDeviceIdentityIndex(IDeviceIdentity deviceIdentity, int poolLength)
+ {
+ return deviceIdentity == null
+ ? throw new ArgumentNullException(nameof(deviceIdentity))
+ : Math.Abs(deviceIdentity.GetHashCode()) % poolLength;
+ }
}
}
diff --git a/iothub/device/src/Transport/Amqp/AmqpTransportHandler.cs b/iothub/device/src/Transport/Amqp/AmqpTransportHandler.cs
index df4ab761d6..116eb0cbf9 100644
--- a/iothub/device/src/Transport/Amqp/AmqpTransportHandler.cs
+++ b/iothub/device/src/Transport/Amqp/AmqpTransportHandler.cs
@@ -49,7 +49,7 @@ internal AmqpTransportHandler(
{
_operationTimeout = transportSettings.OperationTimeout;
_onDesiredStatePatchListener = onDesiredStatePatchReceivedCallback;
- var deviceIdentity = new DeviceIdentity(connectionString, transportSettings, context.Get(), context.Get());
+ IDeviceIdentity deviceIdentity = new DeviceIdentity(connectionString, transportSettings, context.Get(), context.Get());
_amqpUnit = AmqpUnitManager.GetInstance().CreateAmqpUnit(
deviceIdentity,
onMethodCallback,
diff --git a/iothub/device/src/Transport/Amqp/AmqpUnit.cs b/iothub/device/src/Transport/Amqp/AmqpUnit.cs
index 7487347cc2..04f809dd8c 100644
--- a/iothub/device/src/Transport/Amqp/AmqpUnit.cs
+++ b/iothub/device/src/Transport/Amqp/AmqpUnit.cs
@@ -16,7 +16,7 @@ namespace Microsoft.Azure.Devices.Client.Transport.AmqpIot
internal class AmqpUnit : IDisposable
{
// If the first argument is set to true, we are disconnecting gracefully via CloseAsync.
- private readonly DeviceIdentity _deviceIdentity;
+ private readonly IDeviceIdentity _deviceIdentity;
private readonly Func _onMethodCallback;
private readonly Action _twinMessageListener;
@@ -51,7 +51,7 @@ internal class AmqpUnit : IDisposable
private IAmqpAuthenticationRefresher _amqpAuthenticationRefresher;
public AmqpUnit(
- DeviceIdentity deviceIdentity,
+ IDeviceIdentity deviceIdentity,
IAmqpConnectionHolder amqpConnectionHolder,
Func onMethodCallback,
Action twinMessageListener,
@@ -70,7 +70,7 @@ public AmqpUnit(
Logging.Associate(this, _deviceIdentity, nameof(_deviceIdentity));
}
- internal DeviceIdentity GetDeviceIdentity()
+ internal IDeviceIdentity GetDeviceIdentity()
{
return _deviceIdentity;
}
diff --git a/iothub/device/src/Transport/Amqp/AmqpUnitManager.cs b/iothub/device/src/Transport/Amqp/AmqpUnitManager.cs
index f9d4723f84..f83cdda1f3 100644
--- a/iothub/device/src/Transport/Amqp/AmqpUnitManager.cs
+++ b/iothub/device/src/Transport/Amqp/AmqpUnitManager.cs
@@ -28,7 +28,7 @@ internal static AmqpUnitManager GetInstance()
}
public AmqpUnit CreateAmqpUnit(
- DeviceIdentity deviceIdentity,
+ IDeviceIdentity deviceIdentity,
Func onMethodCallback,
Action twinMessageListener,
Func onModuleMessageReceivedCallback,
@@ -47,9 +47,9 @@ public AmqpUnit CreateAmqpUnit(
public void RemoveAmqpUnit(AmqpUnit amqpUnit)
{
- amqpUnit.Dispose();
IAmqpUnitManager amqpConnectionPool = ResolveConnectionPool(amqpUnit.GetDeviceIdentity().IotHubConnectionString.HostName);
amqpConnectionPool.RemoveAmqpUnit(amqpUnit);
+ amqpUnit.Dispose();
}
private IAmqpUnitManager ResolveConnectionPool(string host)
diff --git a/iothub/device/src/Transport/Amqp/IAmqpConnectionHolder.cs b/iothub/device/src/Transport/Amqp/IAmqpConnectionHolder.cs
index 1fd4fa5fa4..5def753284 100644
--- a/iothub/device/src/Transport/Amqp/IAmqpConnectionHolder.cs
+++ b/iothub/device/src/Transport/Amqp/IAmqpConnectionHolder.cs
@@ -10,11 +10,11 @@ namespace Microsoft.Azure.Devices.Client.Transport.Amqp
{
internal interface IAmqpConnectionHolder : IDisposable
{
- Task OpenSessionAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken);
+ Task OpenSessionAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken);
Task EnsureConnectionAsync(CancellationToken cancellationToken);
- Task CreateRefresherAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken);
+ Task CreateRefresherAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken);
void Shutdown();
}
diff --git a/iothub/device/src/Transport/Amqp/IAmqpUnitManager.cs b/iothub/device/src/Transport/Amqp/IAmqpUnitManager.cs
index 509b5da002..46dbb412b5 100644
--- a/iothub/device/src/Transport/Amqp/IAmqpUnitManager.cs
+++ b/iothub/device/src/Transport/Amqp/IAmqpUnitManager.cs
@@ -12,7 +12,7 @@ namespace Microsoft.Azure.Devices.Client.Transport.Amqp
internal interface IAmqpUnitManager
{
AmqpUnit CreateAmqpUnit(
- DeviceIdentity deviceIdentity,
+ IDeviceIdentity deviceIdentity,
Func onMethodCallback,
Action twinMessageListener,
Func onModuleMessageReceivedCallback,
diff --git a/iothub/device/src/Transport/AmqpIot/AmqpIotConnection.cs b/iothub/device/src/Transport/AmqpIot/AmqpIotConnection.cs
index 36004aeb1d..b97c0fc8ce 100644
--- a/iothub/device/src/Transport/AmqpIot/AmqpIotConnection.cs
+++ b/iothub/device/src/Transport/AmqpIot/AmqpIotConnection.cs
@@ -84,7 +84,7 @@ internal async Task OpenSessionAsync(CancellationToken cancellat
}
}
- internal async Task CreateRefresherAsync(DeviceIdentity deviceIdentity, CancellationToken cancellationToken)
+ internal async Task CreateRefresherAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken)
{
if (_amqpConnection.IsClosing())
{
diff --git a/iothub/device/src/Transport/AmqpIot/AmqpIotSession.cs b/iothub/device/src/Transport/AmqpIot/AmqpIotSession.cs
index cd6c6f26c1..e03bfa93bc 100644
--- a/iothub/device/src/Transport/AmqpIot/AmqpIotSession.cs
+++ b/iothub/device/src/Transport/AmqpIot/AmqpIotSession.cs
@@ -58,7 +58,7 @@ internal bool IsClosing()
#region Telemetry links
internal async Task OpenTelemetrySenderLinkAsync(
- DeviceIdentity deviceIdentity,
+ IDeviceIdentity deviceIdentity,
CancellationToken cancellationToken)
{
return await OpenSendingAmqpLinkAsync(
@@ -75,7 +75,7 @@ internal async Task OpenTelemetrySenderLinkAsync(
}
internal async Task OpenMessageReceiverLinkAsync(
- DeviceIdentity deviceIdentity,
+ IDeviceIdentity deviceIdentity,
CancellationToken cancellationToken)
{
return await OpenReceivingAmqpLinkAsync(
@@ -96,7 +96,7 @@ internal async Task OpenMessageReceiverLinkAsync(
#region EventLink
internal async Task OpenEventsReceiverLinkAsync(
- DeviceIdentity deviceIdentity,
+ IDeviceIdentity deviceIdentity,
CancellationToken cancellationToken)
{
return await OpenReceivingAmqpLinkAsync(
@@ -117,7 +117,7 @@ internal async Task OpenEventsReceiverLinkAsync(
#region MethodLink
internal async Task OpenMethodsSenderLinkAsync(
- DeviceIdentity deviceIdentity,
+ IDeviceIdentity deviceIdentity,
string correlationIdSuffix,
CancellationToken cancellationToken)
{
@@ -135,7 +135,7 @@ internal async Task OpenMethodsSenderLinkAsync(
}
internal async Task OpenMethodsReceiverLinkAsync(
- DeviceIdentity deviceIdentity,
+ IDeviceIdentity deviceIdentity,
string correlationIdSuffix,
CancellationToken cancellationToken)
{
@@ -157,7 +157,7 @@ internal async Task OpenMethodsReceiverLinkAsync(
#region TwinLink
internal async Task OpenTwinReceiverLinkAsync(
- DeviceIdentity deviceIdentity,
+ IDeviceIdentity deviceIdentity,
string correlationIdSuffix,
CancellationToken cancellationToken)
{
@@ -175,7 +175,7 @@ internal async Task OpenTwinReceiverLinkAsync(
}
internal async Task OpenTwinSenderLinkAsync(
- DeviceIdentity deviceIdentity,
+ IDeviceIdentity deviceIdentity,
string correlationIdSuffix,
CancellationToken cancellationToken)
{
@@ -197,7 +197,7 @@ internal async Task OpenTwinSenderLinkAsync(
#region Common link handling
private static async Task OpenSendingAmqpLinkAsync(
- DeviceIdentity deviceIdentity,
+ IDeviceIdentity deviceIdentity,
AmqpSession amqpSession,
byte? senderSettleMode,
byte? receiverSettleMode,
@@ -278,7 +278,7 @@ private static async Task OpenSendingAmqpLinkAsync(
}
private static async Task OpenReceivingAmqpLinkAsync(
- DeviceIdentity deviceIdentity,
+ IDeviceIdentity deviceIdentity,
AmqpSession amqpSession,
byte? senderSettleMode,
byte? receiverSettleMode,
@@ -353,7 +353,7 @@ private static async Task OpenReceivingAmqpLinkAsync(
}
}
- private static string BuildLinkAddress(DeviceIdentity deviceIdentity, string deviceTemplate, string moduleTemplate)
+ private static string BuildLinkAddress(IDeviceIdentity deviceIdentity, string deviceTemplate, string moduleTemplate)
{
string path = string.IsNullOrEmpty(deviceIdentity.IotHubConnectionString.ModuleId)
? string.Format(
diff --git a/iothub/device/src/Transport/DeviceIdentity.cs b/iothub/device/src/Transport/DeviceIdentity.cs
index dc1505eac6..6b13877ca4 100644
--- a/iothub/device/src/Transport/DeviceIdentity.cs
+++ b/iothub/device/src/Transport/DeviceIdentity.cs
@@ -12,16 +12,20 @@ namespace Microsoft.Azure.Devices.Client.Transport
/// - connection string
/// - transport settings
///
- internal class DeviceIdentity
+ internal class DeviceIdentity : IDeviceIdentity
{
- internal IotHubConnectionString IotHubConnectionString { get; }
- internal AmqpTransportSettings AmqpTransportSettings { get; }
- internal ProductInfo ProductInfo { get; }
- internal AuthenticationModel AuthenticationModel { get; }
- internal string Audience { get; }
- internal ClientOptions Options { get; }
+ public IotHubConnectionString IotHubConnectionString { get; }
+ public AmqpTransportSettings AmqpTransportSettings { get; }
+ public ProductInfo ProductInfo { get; }
+ public AuthenticationModel AuthenticationModel { get; }
+ public string Audience { get; }
+ public ClientOptions Options { get; }
- internal DeviceIdentity(IotHubConnectionString iotHubConnectionString, AmqpTransportSettings amqpTransportSettings, ProductInfo productInfo, ClientOptions options)
+ internal DeviceIdentity(
+ IotHubConnectionString iotHubConnectionString,
+ AmqpTransportSettings amqpTransportSettings,
+ ProductInfo productInfo,
+ ClientOptions options)
{
IotHubConnectionString = iotHubConnectionString;
AmqpTransportSettings = amqpTransportSettings;
@@ -31,14 +35,9 @@ internal DeviceIdentity(IotHubConnectionString iotHubConnectionString, AmqpTrans
if (amqpTransportSettings.ClientCertificate == null)
{
Audience = CreateAudience(IotHubConnectionString);
- if (iotHubConnectionString.SharedAccessKeyName == null)
- {
- AuthenticationModel = AuthenticationModel.SasIndividual;
- }
- else
- {
- AuthenticationModel = AuthenticationModel.SasGrouped;
- }
+ AuthenticationModel = iotHubConnectionString.SharedAccessKeyName == null
+ ? AuthenticationModel.SasIndividual
+ : AuthenticationModel.SasGrouped;
}
else
{
@@ -61,7 +60,7 @@ private static string CreateAudience(IotHubConnectionString connectionString)
}
}
- internal bool IsPooling()
+ public bool IsPooling()
{
return (AuthenticationModel != AuthenticationModel.X509) && (AmqpTransportSettings?.AmqpConnectionPoolSettings?.Pooling ?? false);
}
diff --git a/iothub/device/src/Transport/IDeviceClientEndpointIdentityFactory.cs b/iothub/device/src/Transport/IDeviceClientEndpointIdentityFactory.cs
deleted file mode 100644
index b0d61a6686..0000000000
--- a/iothub/device/src/Transport/IDeviceClientEndpointIdentityFactory.cs
+++ /dev/null
@@ -1,18 +0,0 @@
-// Copyright (c) Microsoft. All rights reserved.
-// Licensed under the MIT license. See LICENSE file in the project root for full license information.
-
-using System;
-using System.Collections.Generic;
-using System.Text;
-using System.IO;
-
-namespace Microsoft.Azure.Devices.Client.Transport
-{
- ///
- /// Factory interface to create DeviceClientEndpointIdentity objects for Amqp transport layer
- ///
- internal interface IDeviceClientEndpointIdentityFactory
- {
- DeviceIdentity Create(IotHubConnectionString iotHubConnectionString, AmqpTransportSettings amqpTransportSettings, ProductInfo productInfo);
- }
-}
diff --git a/iothub/device/src/Transport/IDeviceIdentity.cs b/iothub/device/src/Transport/IDeviceIdentity.cs
new file mode 100644
index 0000000000..f03e181959
--- /dev/null
+++ b/iothub/device/src/Transport/IDeviceIdentity.cs
@@ -0,0 +1,17 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+namespace Microsoft.Azure.Devices.Client.Transport
+{
+ internal interface IDeviceIdentity
+ {
+ AuthenticationModel AuthenticationModel { get; }
+ AmqpTransportSettings AmqpTransportSettings { get; }
+ IotHubConnectionString IotHubConnectionString { get; }
+ ProductInfo ProductInfo { get; }
+ ClientOptions Options { get; }
+ string Audience { get; }
+
+ bool IsPooling();
+ }
+}
diff --git a/iothub/device/tests/Amqp/AmqpConnectionPoolTests.cs b/iothub/device/tests/Amqp/AmqpConnectionPoolTests.cs
new file mode 100644
index 0000000000..506aa2ff20
--- /dev/null
+++ b/iothub/device/tests/Amqp/AmqpConnectionPoolTests.cs
@@ -0,0 +1,74 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using FluentAssertions;
+using Microsoft.Azure.Devices.Client.Transport;
+using Microsoft.Azure.Devices.Client.Transport.Amqp;
+using Microsoft.Azure.Devices.Client.Transport.AmqpIot;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+
+namespace Microsoft.Azure.Devices.Client.Tests.Amqp
+{
+ [TestClass]
+ public class AmqpConnectionPoolTests
+ {
+ internal class AmqpConnectionPoolTest : AmqpConnectionPool
+ {
+ private readonly IDictionary _dictionaryToUse;
+
+ public AmqpConnectionPoolTest(IDictionary dictionaryToUse)
+ {
+ _dictionaryToUse = dictionaryToUse;
+ }
+
+ protected override IDictionary GetAmqpSasGroupedPoolDictionary()
+ {
+ return _dictionaryToUse;
+ }
+ }
+
+ [TestMethod]
+ public void AmqpConnectionPool_Add_Remove_ConnectionHolderIsRemoved()
+ {
+ string sharedAccessKeyName = "HubOwner";
+ uint poolSize = 10;
+ IDeviceIdentity testDevice = CreatePooledSasGroupedDeviceIdentity(sharedAccessKeyName, poolSize);
+ IDictionary injectedDictionary = new Dictionary();
+
+ AmqpConnectionPoolTest pool = new AmqpConnectionPoolTest(injectedDictionary);
+
+ AmqpUnit addedUnit = pool.CreateAmqpUnit(testDevice, null, null, null, null, null);
+
+ injectedDictionary[sharedAccessKeyName].Count().Should().Be((int)poolSize);
+
+ pool.RemoveAmqpUnit(addedUnit);
+
+ foreach (object item in injectedDictionary[sharedAccessKeyName])
+ {
+ item.Should().BeNull();
+ }
+ }
+
+ private IDeviceIdentity CreatePooledSasGroupedDeviceIdentity(string sharedAccessKeyName, uint poolSize)
+ {
+ Mock deviceIdentity = new Mock();
+
+ deviceIdentity.Setup(m => m.IsPooling()).Returns(true);
+ deviceIdentity.Setup(m => m.AuthenticationModel).Returns(AuthenticationModel.SasGrouped);
+ deviceIdentity.Setup(m => m.IotHubConnectionString).Returns(new IotHubConnectionString(sharedAccessKeyName: sharedAccessKeyName));
+ deviceIdentity.Setup(m => m.AmqpTransportSettings).Returns(new AmqpTransportSettings(TransportType.Amqp_Tcp_Only)
+ {
+ AmqpConnectionPoolSettings = new AmqpConnectionPoolSettings()
+ {
+ Pooling = true,
+ MaxPoolSize = poolSize,
+ }
+ });
+
+ return deviceIdentity.Object;
+ }
+ }
+}