Skip to content

Commit

Permalink
(service-client: Refactor and add implementation for token credential…
Browse files Browse the repository at this point in the history
… input) (#1781)
  • Loading branch information
vinagesh authored Feb 10, 2021
1 parent 5b0e738 commit 8fa8efb
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 78 deletions.
62 changes: 10 additions & 52 deletions common/src/service/IotHubConnectionString.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,29 @@

namespace Microsoft.Azure.Devices
{
internal sealed class IotHubConnectionString : IAuthorizationHeaderProvider, ICbsTokenProvider
/// <summary>
/// The properties required for authentication to IoT hub using a connection string.
/// </summary>
internal sealed class IotHubConnectionString : IotHubCredential
{
private static readonly TimeSpan s_defaultTokenTimeToLive = TimeSpan.FromHours(1);
private const char UserSeparator = '@';
private static readonly TimeSpan _tokenTimeToLive = TimeSpan.FromHours(1);

public IotHubConnectionString(IotHubConnectionStringBuilder builder)
public IotHubConnectionString(IotHubConnectionStringBuilder builder) : base(builder.HostName)
{
if (builder == null)
{
throw new ArgumentNullException(nameof(builder));
}

Audience = builder.HostName;
HostName = string.IsNullOrEmpty(builder.GatewayHostName) ? builder.HostName : builder.GatewayHostName;
SharedAccessKeyName = builder.SharedAccessKeyName;
SharedAccessKey = builder.SharedAccessKey;
SharedAccessSignature = builder.SharedAccessSignature;
IotHubName = builder.IotHubName;
HttpsEndpoint = new UriBuilder("https", HostName).Uri;
AmqpEndpoint = new UriBuilder(CommonConstants.AmqpsScheme, builder.HostName, AmqpConstants.DefaultSecurePort).Uri;
DeviceId = builder.DeviceId;
ModuleId = builder.ModuleId;
GatewayHostName = builder.GatewayHostName;
}

public string IotHubName { get; private set; }

public string HostName { get; private set; }

public Uri HttpsEndpoint { get; private set; }

public Uri AmqpEndpoint { get; private set; }

public string Audience { get; private set; }

public string SharedAccessKeyName { get; private set; }
Expand All @@ -58,39 +48,17 @@ public IotHubConnectionString(IotHubConnectionStringBuilder builder)

public string GatewayHostName { get; private set; }

public string GetUser()
{
var stringBuilder = new StringBuilder();
stringBuilder.Append(SharedAccessKeyName);
stringBuilder.Append(UserSeparator);
stringBuilder.Append("sas.");
stringBuilder.Append("root.");
stringBuilder.Append(IotHubName);

return stringBuilder.ToString();
}

public string GetPassword()
{
string password;
if (string.IsNullOrWhiteSpace(SharedAccessSignature))
{
password = BuildToken(out _);
}
else
{
password = SharedAccessSignature;
}

return password;
return string.IsNullOrWhiteSpace(SharedAccessSignature) ? BuildToken(out _) : SharedAccessSignature;
}

public string GetAuthorizationHeader()
public override string GetAuthorizationHeader()
{
return GetPassword();
}

Task<CbsToken> ICbsTokenProvider.GetTokenAsync(Uri namespaceAddress, string appliesTo, string[] requiredClaims)
public override Task<CbsToken> GetTokenAsync(Uri namespaceAddress, string appliesTo, string[] requiredClaims)
{
string tokenValue;
CbsToken token;
Expand All @@ -108,16 +76,6 @@ Task<CbsToken> ICbsTokenProvider.GetTokenAsync(Uri namespaceAddress, string appl
return Task.FromResult(token);
}

public Uri BuildLinkAddress(string path)
{
var builder = new UriBuilder(AmqpEndpoint)
{
Path = path,
};

return builder.Uri;
}

public static IotHubConnectionString Parse(string connectionString)
{
var builder = IotHubConnectionStringBuilder.Create(connectionString);
Expand All @@ -130,7 +88,7 @@ private string BuildToken(out TimeSpan ttl)
{
KeyName = SharedAccessKeyName,
Key = SharedAccessKey,
TimeToLive = s_defaultTokenTimeToLive,
TimeToLive = _tokenTimeToLive,
Target = Audience
};

Expand Down
28 changes: 14 additions & 14 deletions iothub/service/src/IotHubConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ internal sealed class IotHubConnection : IDisposable
private IOThreadTimer _refreshTokenTimer;
#endif

public IotHubConnection(IotHubConnectionString connectionString, AccessRights accessRights, bool useWebSocketOnly, ServiceClientTransportSettings transportSettings)
public IotHubConnection(IotHubCredential credential, AccessRights accessRights, bool useWebSocketOnly, ServiceClientTransportSettings transportSettings)
{
#if !NET451
_refreshTokenTimer = new IOThreadTimerSlim(s => ((IotHubConnection)s).OnRefreshTokenAsync(), this);
#else
_refreshTokenTimer = new IOThreadTimer(s => ((IotHubConnection)s).OnRefreshTokenAsync(), this, false);
#endif

ConnectionString = connectionString;
Credential = credential;
_accessRights = accessRights;
_faultTolerantSession = new FaultTolerantAmqpObject<AmqpSession>(CreateSessionAsync, CloseConnection);
_useWebSocketOnly = useWebSocketOnly;
Expand All @@ -71,7 +71,7 @@ internal IotHubConnection(Func<TimeSpan, Task<AmqpSession>> onCreate, Action<Amq
_faultTolerantSession = new FaultTolerantAmqpObject<AmqpSession>(onCreate, onClose);
}

internal IotHubConnectionString ConnectionString { get; private set; }
internal IotHubCredential Credential { get; private set; }

public Task OpenAsync(TimeSpan timeout)
{
Expand Down Expand Up @@ -114,7 +114,7 @@ public async Task<SendingAmqpLink> CreateSendingLinkAsync(string path, TimeSpan
session = await _faultTolerantSession.GetOrCreateAsync(timeoutHelper.RemainingTime()).ConfigureAwait(false);
}

Uri linkAddress = ConnectionString.BuildLinkAddress(path);
Uri linkAddress = Credential.BuildLinkAddress(path);

var linkSettings = new AmqpLinkSettings
{
Expand Down Expand Up @@ -156,7 +156,7 @@ public async Task<ReceivingAmqpLink> CreateReceivingLinkAsync(string path, TimeS
session = await _faultTolerantSession.GetOrCreateAsync(timeoutHelper.RemainingTime()).ConfigureAwait(false);
}

Uri linkAddress = ConnectionString.BuildLinkAddress(path);
Uri linkAddress = Credential.BuildLinkAddress(path);

var linkSettings = new AmqpLinkSettings
{
Expand Down Expand Up @@ -306,7 +306,7 @@ private async Task<AmqpSession> CreateSessionAsync(TimeSpan timeout)
{
MaxFrameSize = AmqpConstants.DefaultMaxFrameSize,
ContainerId = Guid.NewGuid().ToString("N", CultureInfo.InvariantCulture), // Use a human readable link name to help with debugging
HostName = ConnectionString.AmqpEndpoint.Host,
HostName = Credential.AmqpEndpoint.Host,
};

var amqpConnection = new AmqpConnection(transport, amqpSettings, amqpConnectionSettings);
Expand Down Expand Up @@ -430,7 +430,7 @@ private async Task<TransportBase> CreateClientWebSocketTransportAsync(TimeSpan t
try
{
var timeoutHelper = new TimeoutHelper(timeout);
var websocketUri = new Uri($"{ WebSocketConstants.Scheme }{ ConnectionString.HostName}:{ WebSocketConstants.SecurePort}{WebSocketConstants.UriSuffix}");
var websocketUri = new Uri($"{ WebSocketConstants.Scheme }{ Credential.HostName}:{ WebSocketConstants.SecurePort}{WebSocketConstants.UriSuffix}");

Logging.Info(this, websocketUri, nameof(CreateClientWebSocketTransportAsync));

Expand Down Expand Up @@ -494,13 +494,13 @@ private TlsTransportSettings CreateTlsTransportSettings()
{
var tcpTransportSettings = new TcpTransportSettings
{
Host = ConnectionString.HostName,
Port = ConnectionString.AmqpEndpoint.Port,
Host = Credential.HostName,
Port = Credential.AmqpEndpoint.Port,
};

var tlsTransportSettings = new TlsTransportSettings(tcpTransportSettings)
{
TargetHost = ConnectionString.HostName,
TargetHost = Credential.HostName,
Certificate = null, // TODO: add client cert support
CertificateValidationCallback = OnRemoteCertificateValidation
};
Expand Down Expand Up @@ -545,12 +545,12 @@ private async Task SendCbsTokenAsync(AmqpCbsLink cbsLink, TimeSpan timeout)
{
Logging.Enter(this, cbsLink, timeout, nameof(SendCbsTokenAsync));

string audience = ConnectionString.AmqpEndpoint.AbsoluteUri;
string resource = ConnectionString.AmqpEndpoint.AbsoluteUri;
string audience = Credential.AmqpEndpoint.AbsoluteUri;
string resource = Credential.AmqpEndpoint.AbsoluteUri;
DateTime expiresAtUtc = await cbsLink
.SendTokenAsync(
ConnectionString,
ConnectionString.AmqpEndpoint,
Credential,
Credential.AmqpEndpoint,
audience,
resource,
AccessRightsHelper.AccessRightsToStringArray(_accessRights),
Expand Down
69 changes: 69 additions & 0 deletions iothub/service/src/IotHubCredential.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Azure.Amqp;
using Microsoft.Azure.Devices.Common;

namespace Microsoft.Azure.Devices
{
/// <summary>
/// The properties required for authentication to IoT hub that are independent of the authentication type.
/// </summary>
internal abstract class IotHubCredential
: IAuthorizationHeaderProvider, ICbsTokenProvider
{
private const string HostNameSeparator = ".";
private const string HttpsEndpointPrefix = "https";

// Azure.Core (used in IotHubTokenCredential) is not available in NET451.
// So we need this constructor for the build to pass.
protected IotHubCredential()
{
}

protected IotHubCredential(string hostName)
{
HostName = hostName;
IotHubName = GetIotHubName(hostName);
AmqpEndpoint = new UriBuilder(CommonConstants.AmqpsScheme, HostName, AmqpConstants.DefaultSecurePort).Uri;
HttpsEndpoint = new UriBuilder(HttpsEndpointPrefix, HostName).Uri;
}

public string IotHubName { get; protected set; }

public string HostName { get; protected set; }

public Uri HttpsEndpoint { get; protected set; }

public Uri AmqpEndpoint { get; protected set; }

public abstract string GetAuthorizationHeader();

public abstract Task<CbsToken> GetTokenAsync(Uri namespaceAddress, string appliesTo, string[] requiredClaims);

public Uri BuildLinkAddress(string path)
{
var builder = new UriBuilder(AmqpEndpoint)
{
Path = path,
};

return builder.Uri;
}

private static string GetIotHubName(string hostName)
{
if (string.IsNullOrWhiteSpace(hostName))
{
throw new ArgumentNullException($"{nameof(hostName)} is null or empty.");
}

int index = hostName.IndexOf(HostNameSeparator, StringComparison.OrdinalIgnoreCase);
string iotHubName = index >= 0 ? hostName.Substring(0, index) : hostName;
return iotHubName;
}
}
}
69 changes: 69 additions & 0 deletions iothub/service/src/IotHubTokenCredential.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Azure.Amqp;
using System.Threading;

#if !NET451

using Azure.Core;

#endif

namespace Microsoft.Azure.Devices
{
/// <summary>
/// The properties required for authentication to IoT hub using a token credential.
/// </summary>
internal class IotHubTokenCredential : IotHubCredential
{
#if NET451

public IotHubTokenCredential()
{
throw new InvalidOperationException("nameof(TokenCredential) is not supported in NET451");
}
#else
private const string _tokenType = "jwt";
private readonly TokenCredential _credential;

public IotHubTokenCredential(string hostName, TokenCredential credential) : base(hostName)
{
_credential = credential;
}

#endif

public override string GetAuthorizationHeader()
{
#if NET451
throw new InvalidOperationException($"{nameof(GetAuthorizationHeader)} is not supported on NET451");

#else
AccessToken token = _credential.GetToken(new TokenRequestContext(), new CancellationToken());
return $"Bearer {token.Token}";

#endif
}

#pragma warning disable CS1998 // Disabled as we need to throw exception for NET 451.

public async override Task<CbsToken> GetTokenAsync(Uri namespaceAddress, string appliesTo, string[] requiredClaims)
{
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
#if NET451
throw new InvalidOperationException($"{nameof(GetTokenAsync)} is not supported on NET451");

#else
AccessToken token = await _credential.GetTokenAsync(new TokenRequestContext(), new CancellationToken()).ConfigureAwait(false);
return new CbsToken(
token.Token,
_tokenType,
token.ExpiresOn.UtcDateTime);
#endif
}
}
}
1 change: 1 addition & 0 deletions iothub/service/src/Microsoft.Azure.Devices.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@
<PackageReference Include="System.Diagnostics.TraceSource" Version="4.3.0" />
<PackageReference Include="System.Diagnostics.Contracts" Version="4.3.0" />
<PackageReference Include="Microsoft.Rest.ClientRuntime" Version="2.3.21" />
<PackageReference Include="Azure.Core" Version="1.9.0" />
</ItemGroup>

<ItemGroup>
Expand Down
Loading

0 comments on commit 8fa8efb

Please sign in to comment.