From ca8746f7cf356d7f001de3d445fc6761dca886f5 Mon Sep 17 00:00:00 2001 From: dauinsight <145612907+dauinsight@users.noreply.github.com> Date: Fri, 16 Aug 2024 09:26:39 -0700 Subject: [PATCH] [5.1] Add | Cache TokenCredential objects to take advantage of token caching (#2380) (#2776) * Add | Cache TokenCredential objects to take advantage of token caching (dotnet#2380) --- .../ActiveDirectoryAuthenticationProvider.cs | 234 +++++++++++++++--- 1 file changed, 201 insertions(+), 33 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index 76c81e282e..f2fd1aaceb 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -25,14 +25,16 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro /// The purpose of this cache is to allow re-use of Access Tokens fetched for a user interactively or with any other mode /// to avoid interactive authentication request every-time, within application scope making use of MSAL's userTokenCache. /// - private static ConcurrentDictionary s_pcaMap - = new ConcurrentDictionary(); private static readonly MemoryCache s_accountPwCache = new(nameof(ActiveDirectoryAuthenticationProvider)); + private static readonly ConcurrentDictionary s_pcaMap = new(); + private static readonly ConcurrentDictionary s_tokenCredentialMap = new(); + private static SemaphoreSlim s_pcaMapModifierSemaphore = new(1, 1); + private static SemaphoreSlim s_tokenCredentialMapModifierSemaphore = new(1, 1); private static readonly int s_accountPwCacheTtlInHours = 2; private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient"; private static readonly string s_defaultScopeSuffix = "/.default"; private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name; - private readonly SqlClientLogger _logger = new SqlClientLogger(); + private readonly SqlClientLogger _logger = new(); private Func _deviceCodeFlowCallback; private ICustomWebUi _customWebUI = null; private readonly string _applicationClientId = ActiveDirectoryAuthentication.AdoClientId; @@ -66,6 +68,11 @@ public static void ClearUserTokenCache() { s_pcaMap.Clear(); } + + if (!s_tokenCredentialMap.IsEmpty) + { + s_tokenCredentialMap.Clear(); + } } /// @@ -145,38 +152,27 @@ public override async Task AcquireTokenAsync(SqlAuthenti * More information: https://docs.microsoft.com/azure/active-directory/develop/msal-client-application-configuration **/ - int seperatorIndex = parameters.Authority.LastIndexOf('/'); - string authority = parameters.Authority.Remove(seperatorIndex + 1); - string audience = parameters.Authority.Substring(seperatorIndex + 1); + int separatorIndex = parameters.Authority.LastIndexOf('/'); + string authority = parameters.Authority.Remove(separatorIndex + 1); + string audience = parameters.Authority.Substring(separatorIndex + 1); string clientId = string.IsNullOrWhiteSpace(parameters.UserId) ? null : parameters.UserId; if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDefault) { - DefaultAzureCredentialOptions defaultAzureCredentialOptions = new() - { - AuthorityHost = new Uri(authority), - SharedTokenCacheTenantId = audience, - VisualStudioCodeTenantId = audience, - VisualStudioTenantId = audience, - ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications. - }; - - // Optionally set clientId when available - if (clientId is not null) - { - defaultAzureCredentialOptions.ManagedIdentityClientId = clientId; - defaultAzureCredentialOptions.SharedTokenCacheUsername = clientId; - } - AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false); + // Cache DefaultAzureCredenial based on scope, authority, audience, and clientId + TokenCredentialKey tokenCredentialKey = new(typeof(DefaultAzureCredential), authority, scope, audience, clientId); + AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn); return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); } - TokenCredentialOptions tokenCredentialOptions = new TokenCredentialOptions() { AuthorityHost = new Uri(authority) }; + TokenCredentialOptions tokenCredentialOptions = new() { AuthorityHost = new Uri(authority) }; if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryMSI) { - AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false); + // Cache ManagedIdentityCredential based on scope, authority, and clientId + TokenCredentialKey tokenCredentialKey = new(typeof(ManagedIdentityCredential), authority, scope, string.Empty, clientId); + AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn); return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); } @@ -184,11 +180,12 @@ public override async Task AcquireTokenAsync(SqlAuthenti AuthenticationResult result = null; if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal) { - AccessToken accessToken = await new ClientSecretCredential(audience, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false); + // Cache ClientSecretCredential based on scope, authority, audience, and clientId + TokenCredentialKey tokenCredentialKey = new(typeof(ClientSecretCredential), authority, scope, audience, clientId); + AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, parameters.Password, tokenRequestContext, cts.Token).ConfigureAwait(false); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn); return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); } - /* * Today, MSAL.NET uses another redirect URI by default in desktop applications that run on Windows * (urn:ietf:wg:oauth:2.0:oob). In the future, we'll want to change this default, so we recommend @@ -204,7 +201,7 @@ public override async Task AcquireTokenAsync(SqlAuthenti redirectUri = "http://localhost"; } #endif - PublicClientAppKey pcaKey = new PublicClientAppKey(parameters.Authority, redirectUri, _applicationClientId + PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId #if NETFRAMEWORK , _iWin32WindowFunc #endif @@ -213,7 +210,7 @@ public override async Task AcquireTokenAsync(SqlAuthenti #endif ); - IPublicClientApplication app = GetPublicClientAppInstance(pcaKey); + IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false); if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated) { @@ -248,7 +245,7 @@ public override async Task AcquireTokenAsync(SqlAuthenti if (null != previousPw && previousPw is byte[] previousPwBytes && // Only get the cached token if the current password hash matches the previously used password hash - currPwHash.SequenceEqual(previousPwBytes)) + AreEqual(currPwHash, previousPwBytes)) { result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); } @@ -353,7 +350,7 @@ private static async Task AcquireTokenInteractiveDeviceFlo { if (authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive) { - CancellationTokenSource ctsInteractive = new CancellationTokenSource(); + CancellationTokenSource ctsInteractive = new(); #if NETCOREAPP /* * On .NET Core, MSAL will start the system browser as a separate process. MSAL does not have control over this browser, @@ -447,16 +444,69 @@ public Task AcquireAuthorizationCodeAsync(Uri authorizationUri, Uri redirec => _acquireAuthorizationCodeAsyncCallback.Invoke(authorizationUri, redirectUri, cancellationToken); } - private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey publicClientAppKey) + private async Task GetPublicClientAppInstanceAsync(PublicClientAppKey publicClientAppKey, CancellationToken cancellationToken) { if (!s_pcaMap.TryGetValue(publicClientAppKey, out IPublicClientApplication clientApplicationInstance)) { - clientApplicationInstance = CreateClientAppInstance(publicClientAppKey); - s_pcaMap.TryAdd(publicClientAppKey, clientApplicationInstance); + await s_pcaMapModifierSemaphore.WaitAsync(cancellationToken); + try + { + // Double-check in case another thread added it while we waited for the semaphore + if (!s_pcaMap.TryGetValue(publicClientAppKey, out clientApplicationInstance)) + { + clientApplicationInstance = CreateClientAppInstance(publicClientAppKey); + s_pcaMap.TryAdd(publicClientAppKey, clientApplicationInstance); + } + } + finally + { + s_pcaMapModifierSemaphore.Release(); + } } + return clientApplicationInstance; } + private static async Task GetTokenAsync(TokenCredentialKey tokenCredentialKey, string secret, + TokenRequestContext tokenRequestContext, CancellationToken cancellationToken) + { + if (!s_tokenCredentialMap.TryGetValue(tokenCredentialKey, out TokenCredentialData tokenCredentialInstance)) + { + await s_tokenCredentialMapModifierSemaphore.WaitAsync(cancellationToken); + try + { + // Double-check in case another thread added it while we waited for the semaphore + if (!s_tokenCredentialMap.TryGetValue(tokenCredentialKey, out tokenCredentialInstance)) + { + tokenCredentialInstance = CreateTokenCredentialInstance(tokenCredentialKey, secret); + s_tokenCredentialMap.TryAdd(tokenCredentialKey, tokenCredentialInstance); + } + } + finally + { + s_tokenCredentialMapModifierSemaphore.Release(); + } + } + + if (!AreEqual(tokenCredentialInstance._secretHash, GetHash(secret))) + { + // If the secret hash has changed, we need to remove the old token credential instance and create a new one. + await s_tokenCredentialMapModifierSemaphore.WaitAsync(cancellationToken); + try + { + s_tokenCredentialMap.TryRemove(tokenCredentialKey, out _); + tokenCredentialInstance = CreateTokenCredentialInstance(tokenCredentialKey, secret); + s_tokenCredentialMap.TryAdd(tokenCredentialKey, tokenCredentialInstance); + } + finally + { + s_tokenCredentialMapModifierSemaphore.Release(); + } + } + + return await tokenCredentialInstance._tokenCredential.GetTokenAsync(tokenRequestContext, cancellationToken); + } + private static string GetAccountPwCacheKey(SqlAuthenticationParameters parameters) { return parameters.Authority + "+" + parameters.UserId; @@ -470,6 +520,24 @@ private static byte[] GetHash(string input) return hashedBytes; } + private static bool AreEqual(byte[] a1, byte[] a2) + { + if (ReferenceEquals(a1, a2)) + { + return true; + } + else if (a1 is null || a2 is null) + { + return false; + } + else if (a1.Length != a2.Length) + { + return false; + } + + return a1.AsSpan().SequenceEqual(a2.AsSpan()); + } + private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey) { IPublicClientApplication publicClientApplication; @@ -513,6 +581,59 @@ private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publ return publicClientApplication; } + private static TokenCredentialData CreateTokenCredentialInstance(TokenCredentialKey tokenCredentialKey, string secret) + { + if (tokenCredentialKey._tokenCredentialType == typeof(DefaultAzureCredential)) + { + DefaultAzureCredentialOptions defaultAzureCredentialOptions = new() + { + AuthorityHost = new Uri(tokenCredentialKey._authority), + SharedTokenCacheTenantId = tokenCredentialKey._audience, + VisualStudioCodeTenantId = tokenCredentialKey._audience, + VisualStudioTenantId = tokenCredentialKey._audience, + ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications. + }; + + // Optionally set clientId when available + if (tokenCredentialKey._clientId is not null) + { + defaultAzureCredentialOptions.ManagedIdentityClientId = tokenCredentialKey._clientId; + defaultAzureCredentialOptions.SharedTokenCacheUsername = tokenCredentialKey._clientId; + defaultAzureCredentialOptions.WorkloadIdentityClientId = tokenCredentialKey._clientId; + } + + return new TokenCredentialData(new DefaultAzureCredential(defaultAzureCredentialOptions), GetHash(secret)); + } + + TokenCredentialOptions tokenCredentialOptions = new() { AuthorityHost = new Uri(tokenCredentialKey._authority) }; + + if (tokenCredentialKey._tokenCredentialType == typeof(ManagedIdentityCredential)) + { + return new TokenCredentialData(new ManagedIdentityCredential(tokenCredentialKey._clientId, tokenCredentialOptions), GetHash(secret)); + } + else if (tokenCredentialKey._tokenCredentialType == typeof(ClientSecretCredential)) + { + return new TokenCredentialData(new ClientSecretCredential(tokenCredentialKey._audience, tokenCredentialKey._clientId, secret, tokenCredentialOptions), GetHash(secret)); + } + else if (tokenCredentialKey._tokenCredentialType == typeof(WorkloadIdentityCredential)) + { + // The WorkloadIdentityCredentialOptions object initialization populates its instance members + // from the environment variables AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_FEDERATED_TOKEN_FILE, + // and AZURE_ADDITIONALLY_ALLOWED_TENANTS. AZURE_CLIENT_ID may be overridden by the User Id. + WorkloadIdentityCredentialOptions options = new() { AuthorityHost = new Uri(tokenCredentialKey._authority) }; + + if (tokenCredentialKey._clientId is not null) + { + options.ClientId = tokenCredentialKey._clientId; + } + + return new TokenCredentialData(new WorkloadIdentityCredential(options), GetHash(secret)); + } + + // This should never be reached, but if it is, throw an exception that will be noticed during development + throw new ArgumentException(nameof(ActiveDirectoryAuthenticationProvider)); + } + internal class PublicClientAppKey { public readonly string _authority; @@ -572,5 +693,52 @@ public override int GetHashCode() => Tuple.Create(_authority, _redirectUri, _app #endif ).GetHashCode(); } + + internal class TokenCredentialData + { + public TokenCredential _tokenCredential; + public byte[] _secretHash; + + public TokenCredentialData(TokenCredential tokenCredential, byte[] secretHash) + { + _tokenCredential = tokenCredential; + _secretHash = secretHash; + } + } + + internal class TokenCredentialKey + { + public readonly Type _tokenCredentialType; + public readonly string _authority; + public readonly string _scope; + public readonly string _audience; + public readonly string _clientId; + + public TokenCredentialKey(Type tokenCredentialType, string authority, string scope, string audience, string clientId) + { + _tokenCredentialType = tokenCredentialType; + _authority = authority; + _scope = scope; + _audience = audience; + _clientId = clientId; + } + + public override bool Equals(object obj) + { + if (obj != null && obj is TokenCredentialKey tcKey) + { + return string.CompareOrdinal(nameof(_tokenCredentialType), nameof(tcKey._tokenCredentialType)) == 0 + && string.CompareOrdinal(_authority, tcKey._authority) == 0 + && string.CompareOrdinal(_scope, tcKey._scope) == 0 + && string.CompareOrdinal(_audience, tcKey._audience) == 0 + && string.CompareOrdinal(_clientId, tcKey._clientId) == 0 + ; + } + return false; + } + + public override int GetHashCode() => Tuple.Create(_tokenCredentialType, _authority, _scope, _audience, _clientId).GetHashCode(); + } + } }