Skip to content

Commit

Permalink
return TokenCredential instead of AccessKey in parser
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan committed Dec 5, 2024
1 parent af7141a commit a4501de
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 108 deletions.
13 changes: 10 additions & 3 deletions src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ public ServiceEndpoint(string nameWithEndpointType, string connectionString) : t
(Name, EndpointType) = Parse(nameWithEndpointType);
}

private static IAccessKey BuildAccessKey(ParsedConnectionString parsed)
{
return string.IsNullOrEmpty(parsed.AccessKey)
? new MicrosoftEntraAccessKey(parsed.Endpoint, parsed.TokenCredential, parsed.ServerEndpoint)
: new AccessKey(parsed.Endpoint, parsed.AccessKey);
}

/// <summary>
/// Connection string constructor
/// </summary>
Expand All @@ -116,12 +123,12 @@ public ServiceEndpoint(string connectionString, EndpointType type = EndpointType
throw new ArgumentException($"'{nameof(connectionString)}' cannot be null or whitespace.", nameof(connectionString));
}
ConnectionString = connectionString;

var result = ConnectionStringParser.Parse(connectionString);
EndpointType = type;
Name = name;

_accessKey = result.AccessKey;
var result = ConnectionStringParser.Parse(connectionString);

_accessKey = BuildAccessKey(result);
_serviceEndpoint = result.Endpoint;
_clientEndpoint = result.ClientEndpoint;
_serverEndpoint = result.ServerEndpoint;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Text.RegularExpressions;
using Azure.Core;
using Azure.Identity;

namespace Microsoft.Azure.SignalR;
Expand Down Expand Up @@ -39,6 +40,7 @@ internal static class ConnectionStringParser

private const string TypeAzure = "azure";

[Obsolete]
private const string TypeAzureAD = "aad";

private const string TypeAzureApp = "azure.app";
Expand Down Expand Up @@ -107,14 +109,9 @@ internal static ParsedConnectionString Parse(string connectionString)
// parse and validate port.
if (dict.TryGetValue(PortProperty, out var s))
{
if (int.TryParse(s, out var port) && port > 0 && port <= 0xFFFF)
{
builder.Port = port;
}
else
{
throw new ArgumentException(InvalidPortValue, nameof(port));
}
builder.Port = int.TryParse(s, out var port) && port > 0 && port <= 0xFFFF
? port
: throw new ArgumentException(InvalidPortValue, nameof(port));
}

Uri? clientEndpointUri = null;
Expand All @@ -140,19 +137,22 @@ internal static ParsedConnectionString Parse(string connectionString)

// try building accesskey.
dict.TryGetValue(AuthTypeProperty, out var type);
var accessKey = type?.ToLower() switch
{
TypeAzureAD => BuildAzureADAccessKey(builder.Uri, serverEndpointUri, dict),
TypeAzure => BuildAzureAccessKey(builder.Uri, serverEndpointUri, dict),
TypeAzureApp => BuildAzureAppAccessKey(builder.Uri, serverEndpointUri, dict),
TypeAzureMsi => BuildAzureMsiAccessKey(builder.Uri, serverEndpointUri, dict),
_ => BuildAccessKey(builder.Uri, dict),
var tokenCredential = type?.ToLower() switch
{
TypeAzureApp => BuildApplicationCredential(dict),
TypeAzureMsi => BuildManagedIdentityCredential(dict),
#pragma warning disable CS0612 // Type or member is obsolete
TypeAzureAD => BuildAzureTokenCredential(dict),
#pragma warning restore CS0612 // Type or member is obsolete
_ => new DefaultAzureCredential(),
};

return new ParsedConnectionString(builder.Uri)
dict.TryGetValue(AccessKeyProperty, out var accessKey);

return new ParsedConnectionString(builder.Uri, tokenCredential)
{
ClientEndpoint = clientEndpointUri,
AccessKey = accessKey,
ClientEndpoint = clientEndpointUri,
ServerEndpoint = serverEndpointUri
};
}
Expand All @@ -163,19 +163,20 @@ private static bool TryCreateEndpointUri(string endpoint, out Uri? uriResult)
&& (uriResult.Scheme == Uri.UriSchemeHttp || uriResult.Scheme == Uri.UriSchemeHttps);
}

private static IAccessKey BuildAzureADAccessKey(Uri uri, Uri? serverEndpointUri, Dictionary<string, string> dict)
[Obsolete]
private static TokenCredential BuildAzureTokenCredential(Dictionary<string, string> dict)
{
if (dict.TryGetValue(ClientIdProperty, out var clientId))
{
if (dict.TryGetValue(TenantIdProperty, out var tenantId))
{
if (dict.TryGetValue(ClientSecretProperty, out var clientSecret))
{
return new MicrosoftEntraAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri);
return new ClientSecretCredential(tenantId, clientId, clientSecret);
}
else if (dict.TryGetValue(ClientCertProperty, out var clientCertPath))
{
return new MicrosoftEntraAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri);
return new ClientCertificateCredential(tenantId, clientId, clientCertPath);
}
else
{
Expand All @@ -184,28 +185,16 @@ private static IAccessKey BuildAzureADAccessKey(Uri uri, Uri? serverEndpointUri,
}
else
{
return new MicrosoftEntraAccessKey(uri, new ManagedIdentityCredential(clientId), serverEndpointUri);
return new ManagedIdentityCredential(clientId);
}
}
else
{
return new MicrosoftEntraAccessKey(uri, new ManagedIdentityCredential(), serverEndpointUri);
return new ManagedIdentityCredential();
}
}

private static IAccessKey BuildAccessKey(Uri uri, Dictionary<string, string> dict)
{
return dict.TryGetValue(AccessKeyProperty, out var key)
? new AccessKey(uri, key)
: throw new ArgumentException(MissingAccessKeyProperty, AccessKeyProperty);
}

private static IAccessKey BuildAzureAccessKey(Uri uri, Uri? serverEndpointUri, Dictionary<string, string> dict)
{
return new MicrosoftEntraAccessKey(uri, new DefaultAzureCredential(), serverEndpointUri);
}

private static IAccessKey BuildAzureAppAccessKey(Uri uri, Uri? serverEndpointUri, Dictionary<string, string> dict)
private static TokenCredential BuildApplicationCredential(Dictionary<string, string> dict)
{
if (!dict.TryGetValue(ClientIdProperty, out var clientId))
{
Expand All @@ -219,20 +208,20 @@ private static IAccessKey BuildAzureAppAccessKey(Uri uri, Uri? serverEndpointUri

if (dict.TryGetValue(ClientSecretProperty, out var clientSecret))
{
return new MicrosoftEntraAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri);
return new ClientSecretCredential(tenantId, clientId, clientSecret);
}
else if (dict.TryGetValue(ClientCertProperty, out var clientCertPath))
{
return new MicrosoftEntraAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri);
return new ClientCertificateCredential(tenantId, clientId, clientCertPath);
}
throw new ArgumentException(MissingClientSecretProperty, ClientSecretProperty);
}

private static IAccessKey BuildAzureMsiAccessKey(Uri uri, Uri? serverEndpointUri, Dictionary<string, string> dict)
private static TokenCredential BuildManagedIdentityCredential(Dictionary<string, string> dict)
{
return dict.TryGetValue(ClientIdProperty, out var clientId)
? new MicrosoftEntraAccessKey(uri, new ManagedIdentityCredential(clientId), serverEndpointUri)
: new MicrosoftEntraAccessKey(uri, new ManagedIdentityCredential(), serverEndpointUri);
? new ManagedIdentityCredential(clientId)
: new ManagedIdentityCredential();
}

private static Dictionary<string, string> ToDictionary(string connectionString)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using Azure.Core;

namespace Microsoft.Azure.SignalR;

Expand All @@ -11,14 +12,17 @@ internal class ParsedConnectionString
{
internal Uri Endpoint { get; }

internal IAccessKey? AccessKey { get; set; }
internal string? AccessKey { get; init; }

internal Uri? ClientEndpoint { get; set; }
internal TokenCredential TokenCredential { get; }

internal Uri? ServerEndpoint { get; set; }
internal Uri? ClientEndpoint { get; init; }

public ParsedConnectionString(Uri endpoint)
internal Uri? ServerEndpoint { get; init; }

public ParsedConnectionString(Uri endpoint, TokenCredential tokenCredential)
{
Endpoint = endpoint;
TokenCredential = tokenCredential;
}
}
Loading

0 comments on commit a4501de

Please sign in to comment.