From 8fa8efb04805fc1c1bf862638217068a121a2205 Mon Sep 17 00:00:00 2001 From: Sindhu Nagesh Date: Wed, 10 Feb 2021 10:10:57 -0800 Subject: [PATCH] (service-client: Refactor and add implementation for token credential input) (#1781) --- common/src/service/IotHubConnectionString.cs | 62 +++-------------- iothub/service/src/IotHubConnection.cs | 28 ++++---- iothub/service/src/IotHubCredential.cs | 69 +++++++++++++++++++ iothub/service/src/IotHubTokenCredential.cs | 69 +++++++++++++++++++ .../src/Microsoft.Azure.Devices.csproj | 1 + .../ServiceClientConnectionStringTests.cs | 24 +++---- 6 files changed, 175 insertions(+), 78 deletions(-) create mode 100644 iothub/service/src/IotHubCredential.cs create mode 100644 iothub/service/src/IotHubTokenCredential.cs diff --git a/common/src/service/IotHubConnectionString.cs b/common/src/service/IotHubConnectionString.cs index be59119d68..a23c509351 100644 --- a/common/src/service/IotHubConnectionString.cs +++ b/common/src/service/IotHubConnectionString.cs @@ -11,12 +11,14 @@ namespace Microsoft.Azure.Devices { - internal sealed class IotHubConnectionString : IAuthorizationHeaderProvider, ICbsTokenProvider + /// + /// The properties required for authentication to IoT hub using a connection string. + /// + internal sealed class IotHubConnectionString : IotHubCredential { - private static readonly TimeSpan s_defaultTokenTimeToLive = TimeSpan.FromHours(1); - private const char UserSeparator = '@'; + private static readonly TimeSpan _tokenTimeToLive = TimeSpan.FromHours(1); - public IotHubConnectionString(IotHubConnectionStringBuilder builder) + public IotHubConnectionString(IotHubConnectionStringBuilder builder) : base(builder.HostName) { if (builder == null) { @@ -24,26 +26,14 @@ public IotHubConnectionString(IotHubConnectionStringBuilder builder) } Audience = builder.HostName; - HostName = string.IsNullOrEmpty(builder.GatewayHostName) ? builder.HostName : builder.GatewayHostName; SharedAccessKeyName = builder.SharedAccessKeyName; SharedAccessKey = builder.SharedAccessKey; SharedAccessSignature = builder.SharedAccessSignature; - IotHubName = builder.IotHubName; - HttpsEndpoint = new UriBuilder("https", HostName).Uri; - AmqpEndpoint = new UriBuilder(CommonConstants.AmqpsScheme, builder.HostName, AmqpConstants.DefaultSecurePort).Uri; DeviceId = builder.DeviceId; ModuleId = builder.ModuleId; GatewayHostName = builder.GatewayHostName; } - public string IotHubName { get; private set; } - - public string HostName { get; private set; } - - public Uri HttpsEndpoint { get; private set; } - - public Uri AmqpEndpoint { get; private set; } - public string Audience { get; private set; } public string SharedAccessKeyName { get; private set; } @@ -58,39 +48,17 @@ public IotHubConnectionString(IotHubConnectionStringBuilder builder) public string GatewayHostName { get; private set; } - public string GetUser() - { - var stringBuilder = new StringBuilder(); - stringBuilder.Append(SharedAccessKeyName); - stringBuilder.Append(UserSeparator); - stringBuilder.Append("sas."); - stringBuilder.Append("root."); - stringBuilder.Append(IotHubName); - - return stringBuilder.ToString(); - } - public string GetPassword() { - string password; - if (string.IsNullOrWhiteSpace(SharedAccessSignature)) - { - password = BuildToken(out _); - } - else - { - password = SharedAccessSignature; - } - - return password; + return string.IsNullOrWhiteSpace(SharedAccessSignature) ? BuildToken(out _) : SharedAccessSignature; } - public string GetAuthorizationHeader() + public override string GetAuthorizationHeader() { return GetPassword(); } - Task ICbsTokenProvider.GetTokenAsync(Uri namespaceAddress, string appliesTo, string[] requiredClaims) + public override Task GetTokenAsync(Uri namespaceAddress, string appliesTo, string[] requiredClaims) { string tokenValue; CbsToken token; @@ -108,16 +76,6 @@ Task ICbsTokenProvider.GetTokenAsync(Uri namespaceAddress, string appl return Task.FromResult(token); } - public Uri BuildLinkAddress(string path) - { - var builder = new UriBuilder(AmqpEndpoint) - { - Path = path, - }; - - return builder.Uri; - } - public static IotHubConnectionString Parse(string connectionString) { var builder = IotHubConnectionStringBuilder.Create(connectionString); @@ -130,7 +88,7 @@ private string BuildToken(out TimeSpan ttl) { KeyName = SharedAccessKeyName, Key = SharedAccessKey, - TimeToLive = s_defaultTokenTimeToLive, + TimeToLive = _tokenTimeToLive, Target = Audience }; diff --git a/iothub/service/src/IotHubConnection.cs b/iothub/service/src/IotHubConnection.cs index 8e7cebf434..5a17e2ede5 100644 --- a/iothub/service/src/IotHubConnection.cs +++ b/iothub/service/src/IotHubConnection.cs @@ -51,7 +51,7 @@ internal sealed class IotHubConnection : IDisposable private IOThreadTimer _refreshTokenTimer; #endif - public IotHubConnection(IotHubConnectionString connectionString, AccessRights accessRights, bool useWebSocketOnly, ServiceClientTransportSettings transportSettings) + public IotHubConnection(IotHubCredential credential, AccessRights accessRights, bool useWebSocketOnly, ServiceClientTransportSettings transportSettings) { #if !NET451 _refreshTokenTimer = new IOThreadTimerSlim(s => ((IotHubConnection)s).OnRefreshTokenAsync(), this); @@ -59,7 +59,7 @@ public IotHubConnection(IotHubConnectionString connectionString, AccessRights ac _refreshTokenTimer = new IOThreadTimer(s => ((IotHubConnection)s).OnRefreshTokenAsync(), this, false); #endif - ConnectionString = connectionString; + Credential = credential; _accessRights = accessRights; _faultTolerantSession = new FaultTolerantAmqpObject(CreateSessionAsync, CloseConnection); _useWebSocketOnly = useWebSocketOnly; @@ -71,7 +71,7 @@ internal IotHubConnection(Func> onCreate, Action(onCreate, onClose); } - internal IotHubConnectionString ConnectionString { get; private set; } + internal IotHubCredential Credential { get; private set; } public Task OpenAsync(TimeSpan timeout) { @@ -114,7 +114,7 @@ public async Task CreateSendingLinkAsync(string path, TimeSpan session = await _faultTolerantSession.GetOrCreateAsync(timeoutHelper.RemainingTime()).ConfigureAwait(false); } - Uri linkAddress = ConnectionString.BuildLinkAddress(path); + Uri linkAddress = Credential.BuildLinkAddress(path); var linkSettings = new AmqpLinkSettings { @@ -156,7 +156,7 @@ public async Task CreateReceivingLinkAsync(string path, TimeS session = await _faultTolerantSession.GetOrCreateAsync(timeoutHelper.RemainingTime()).ConfigureAwait(false); } - Uri linkAddress = ConnectionString.BuildLinkAddress(path); + Uri linkAddress = Credential.BuildLinkAddress(path); var linkSettings = new AmqpLinkSettings { @@ -306,7 +306,7 @@ private async Task CreateSessionAsync(TimeSpan timeout) { MaxFrameSize = AmqpConstants.DefaultMaxFrameSize, ContainerId = Guid.NewGuid().ToString("N", CultureInfo.InvariantCulture), // Use a human readable link name to help with debugging - HostName = ConnectionString.AmqpEndpoint.Host, + HostName = Credential.AmqpEndpoint.Host, }; var amqpConnection = new AmqpConnection(transport, amqpSettings, amqpConnectionSettings); @@ -430,7 +430,7 @@ private async Task CreateClientWebSocketTransportAsync(TimeSpan t try { var timeoutHelper = new TimeoutHelper(timeout); - var websocketUri = new Uri($"{ WebSocketConstants.Scheme }{ ConnectionString.HostName}:{ WebSocketConstants.SecurePort}{WebSocketConstants.UriSuffix}"); + var websocketUri = new Uri($"{ WebSocketConstants.Scheme }{ Credential.HostName}:{ WebSocketConstants.SecurePort}{WebSocketConstants.UriSuffix}"); Logging.Info(this, websocketUri, nameof(CreateClientWebSocketTransportAsync)); @@ -494,13 +494,13 @@ private TlsTransportSettings CreateTlsTransportSettings() { var tcpTransportSettings = new TcpTransportSettings { - Host = ConnectionString.HostName, - Port = ConnectionString.AmqpEndpoint.Port, + Host = Credential.HostName, + Port = Credential.AmqpEndpoint.Port, }; var tlsTransportSettings = new TlsTransportSettings(tcpTransportSettings) { - TargetHost = ConnectionString.HostName, + TargetHost = Credential.HostName, Certificate = null, // TODO: add client cert support CertificateValidationCallback = OnRemoteCertificateValidation }; @@ -545,12 +545,12 @@ private async Task SendCbsTokenAsync(AmqpCbsLink cbsLink, TimeSpan timeout) { Logging.Enter(this, cbsLink, timeout, nameof(SendCbsTokenAsync)); - string audience = ConnectionString.AmqpEndpoint.AbsoluteUri; - string resource = ConnectionString.AmqpEndpoint.AbsoluteUri; + string audience = Credential.AmqpEndpoint.AbsoluteUri; + string resource = Credential.AmqpEndpoint.AbsoluteUri; DateTime expiresAtUtc = await cbsLink .SendTokenAsync( - ConnectionString, - ConnectionString.AmqpEndpoint, + Credential, + Credential.AmqpEndpoint, audience, resource, AccessRightsHelper.AccessRightsToStringArray(_accessRights), diff --git a/iothub/service/src/IotHubCredential.cs b/iothub/service/src/IotHubCredential.cs new file mode 100644 index 0000000000..42ec547ae7 --- /dev/null +++ b/iothub/service/src/IotHubCredential.cs @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Azure.Amqp; +using Microsoft.Azure.Devices.Common; + +namespace Microsoft.Azure.Devices +{ + /// + /// The properties required for authentication to IoT hub that are independent of the authentication type. + /// + internal abstract class IotHubCredential + : IAuthorizationHeaderProvider, ICbsTokenProvider + { + private const string HostNameSeparator = "."; + private const string HttpsEndpointPrefix = "https"; + + // Azure.Core (used in IotHubTokenCredential) is not available in NET451. + // So we need this constructor for the build to pass. + protected IotHubCredential() + { + } + + protected IotHubCredential(string hostName) + { + HostName = hostName; + IotHubName = GetIotHubName(hostName); + AmqpEndpoint = new UriBuilder(CommonConstants.AmqpsScheme, HostName, AmqpConstants.DefaultSecurePort).Uri; + HttpsEndpoint = new UriBuilder(HttpsEndpointPrefix, HostName).Uri; + } + + public string IotHubName { get; protected set; } + + public string HostName { get; protected set; } + + public Uri HttpsEndpoint { get; protected set; } + + public Uri AmqpEndpoint { get; protected set; } + + public abstract string GetAuthorizationHeader(); + + public abstract Task GetTokenAsync(Uri namespaceAddress, string appliesTo, string[] requiredClaims); + + public Uri BuildLinkAddress(string path) + { + var builder = new UriBuilder(AmqpEndpoint) + { + Path = path, + }; + + return builder.Uri; + } + + private static string GetIotHubName(string hostName) + { + if (string.IsNullOrWhiteSpace(hostName)) + { + throw new ArgumentNullException($"{nameof(hostName)} is null or empty."); + } + + int index = hostName.IndexOf(HostNameSeparator, StringComparison.OrdinalIgnoreCase); + string iotHubName = index >= 0 ? hostName.Substring(0, index) : hostName; + return iotHubName; + } + } +} diff --git a/iothub/service/src/IotHubTokenCredential.cs b/iothub/service/src/IotHubTokenCredential.cs new file mode 100644 index 0000000000..bb2c1bdde5 --- /dev/null +++ b/iothub/service/src/IotHubTokenCredential.cs @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Azure.Amqp; +using System.Threading; + +#if !NET451 + +using Azure.Core; + +#endif + +namespace Microsoft.Azure.Devices +{ + /// + /// The properties required for authentication to IoT hub using a token credential. + /// + internal class IotHubTokenCredential : IotHubCredential + { +#if NET451 + + public IotHubTokenCredential() + { + throw new InvalidOperationException("nameof(TokenCredential) is not supported in NET451"); + } +#else + private const string _tokenType = "jwt"; + private readonly TokenCredential _credential; + + public IotHubTokenCredential(string hostName, TokenCredential credential) : base(hostName) + { + _credential = credential; + } + +#endif + + public override string GetAuthorizationHeader() + { +#if NET451 + throw new InvalidOperationException($"{nameof(GetAuthorizationHeader)} is not supported on NET451"); + +#else + AccessToken token = _credential.GetToken(new TokenRequestContext(), new CancellationToken()); + return $"Bearer {token.Token}"; + +#endif + } + +#pragma warning disable CS1998 // Disabled as we need to throw exception for NET 451. + + public async override Task GetTokenAsync(Uri namespaceAddress, string appliesTo, string[] requiredClaims) + { +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously +#if NET451 + throw new InvalidOperationException($"{nameof(GetTokenAsync)} is not supported on NET451"); + +#else + AccessToken token = await _credential.GetTokenAsync(new TokenRequestContext(), new CancellationToken()).ConfigureAwait(false); + return new CbsToken( + token.Token, + _tokenType, + token.ExpiresOn.UtcDateTime); +#endif + } + } +} diff --git a/iothub/service/src/Microsoft.Azure.Devices.csproj b/iothub/service/src/Microsoft.Azure.Devices.csproj index 67c8b026a0..d3498902cb 100644 --- a/iothub/service/src/Microsoft.Azure.Devices.csproj +++ b/iothub/service/src/Microsoft.Azure.Devices.csproj @@ -161,6 +161,7 @@ + diff --git a/iothub/service/tests/ConnectionString/ServiceClientConnectionStringTests.cs b/iothub/service/tests/ConnectionString/ServiceClientConnectionStringTests.cs index 1607a1a6a9..c542f86aea 100644 --- a/iothub/service/tests/ConnectionString/ServiceClientConnectionStringTests.cs +++ b/iothub/service/tests/ConnectionString/ServiceClientConnectionStringTests.cs @@ -11,7 +11,7 @@ namespace Microsoft.Azure.Devices.Api.Test.ConnectionString [TestCategory("Unit")] public class ServiceClientConnectionStringTests { - class TestAuthenticationMethod : IAuthenticationMethod + private class TestAuthenticationMethod : IAuthenticationMethod { public virtual IotHubConnectionStringBuilder Populate(IotHubConnectionStringBuilder iotHubConnectionStringBuilder) { @@ -29,7 +29,7 @@ public void ServiceClientConnectionStringDefaultScopeDefaultCredentialTypeTest() var serviceClient = (AmqpServiceClient)ServiceClient.CreateFromConnectionString(connectionString); Assert.IsNotNull(serviceClient.Connection); - Assert.IsNotNull(serviceClient.Connection.ConnectionString); + Assert.IsNotNull(serviceClient.Connection.Credential); } [TestMethod] @@ -39,7 +39,7 @@ public void ServiceClientConnectionStringIotHubScopeImplicitSharedAccessSignatur var serviceClient = (AmqpServiceClient)ServiceClient.CreateFromConnectionString(connectionString); Assert.IsNotNull(serviceClient.Connection); - Assert.IsNotNull(serviceClient.Connection.ConnectionString); + Assert.IsNotNull(serviceClient.Connection.Credential); } [TestMethod] @@ -49,7 +49,7 @@ public void ServiceClientConnectionStringIotHubScopeExplicitSharedAccessSignatur var serviceClient = (AmqpServiceClient)ServiceClient.CreateFromConnectionString(connectionString); Assert.IsNotNull(serviceClient.Connection); - Assert.IsNotNull(serviceClient.Connection.ConnectionString); + Assert.IsNotNull(serviceClient.Connection.Credential); } [TestMethod] @@ -59,7 +59,7 @@ public void ServiceClientConnectionStringIotHubScopeSharedAccessKeyCredentialTyp var serviceClient = (AmqpServiceClient)ServiceClient.CreateFromConnectionString(connectionString); Assert.IsNotNull(serviceClient.Connection); - Assert.IsNotNull(serviceClient.Connection.ConnectionString); + Assert.IsNotNull(serviceClient.Connection.Credential); } [TestMethod] @@ -69,7 +69,7 @@ public void ServiceClientConnectionStringDeviceScopeImplicitSharedAccessSignatur var serviceClient = (AmqpServiceClient)ServiceClient.CreateFromConnectionString(connectionString); Assert.IsNotNull(serviceClient.Connection); - Assert.IsNotNull(serviceClient.Connection.ConnectionString); + Assert.IsNotNull(serviceClient.Connection.Credential); } [TestMethod] @@ -79,7 +79,7 @@ public void ServiceClientConnectionStringDeviceScopeExplicitSharedAccessSignatur var serviceClient = (AmqpServiceClient)ServiceClient.CreateFromConnectionString(connectionString); Assert.IsNotNull(serviceClient.Connection); - Assert.IsNotNull(serviceClient.Connection.ConnectionString); + Assert.IsNotNull(serviceClient.Connection.Credential); } [TestMethod] @@ -89,7 +89,7 @@ public void ServiceClientConnectionStringDeviceScopeSharedAccessKeyCredentialTyp var serviceClient = (AmqpServiceClient)ServiceClient.CreateFromConnectionString(connectionString); Assert.IsNotNull(serviceClient.Connection); - Assert.IsNotNull(serviceClient.Connection.ConnectionString); + Assert.IsNotNull(serviceClient.Connection.Credential); } [TestMethod] @@ -117,7 +117,7 @@ public void ServiceClientIotHubConnectionStringBuilderTest() // Hostname without DNS is acceptable for localhost testing. iotHubConnectionStringBuilder.HostName = "adshgfvyregferuehfiuehr"; - + try { iotHubConnectionStringBuilder.HostName = "acme.azure-devices.net"; @@ -181,11 +181,11 @@ public void ServiceClient_ConnectionString_ModuleIdentity_SharedAccessKeyCredent string connectionString = "HostName=testhub.azure-devices-int.net;DeviceId=edgecapabledevice1;ModuleId=testModule;SharedAccessKey=dGVzdFN0cmluZzE=;GatewayHostName=edgehub1.ms.com"; var serviceClient = (AmqpServiceClient)ServiceClient.CreateFromConnectionString(connectionString); - Assert.IsNotNull(serviceClient.Connection); - IotHubConnectionString iotHubConnectionString = serviceClient.Connection.ConnectionString; + Assert.IsNotNull(serviceClient.Connection); + IotHubConnectionString iotHubConnectionString = (IotHubConnectionString)serviceClient.Connection.Credential; Assert.IsNotNull(iotHubConnectionString); Assert.AreEqual("testhub.azure-devices-int.net", iotHubConnectionString.Audience); - Assert.AreEqual("edgehub1.ms.com", iotHubConnectionString.HostName); + Assert.AreEqual("testhub.azure-devices-int.net", iotHubConnectionString.HostName); Assert.AreEqual("edgecapabledevice1", iotHubConnectionString.DeviceId); Assert.AreEqual("testModule", iotHubConnectionString.ModuleId); Assert.AreEqual("dGVzdFN0cmluZzE=", iotHubConnectionString.SharedAccessKey);