From 3c35fbb843d2528d124a209bba9be63a1482947e Mon Sep 17 00:00:00 2001 From: Christopher Scott Date: Fri, 29 Sep 2023 17:00:12 -0500 Subject: [PATCH] TokenExchangeManagedIdentitySource with async IO (#38939) --- .../src/MsalConfidentialClient.cs | 8 ++ .../src/TokenExchangeManagedIdentitySource.cs | 82 ++++++++++++++++--- 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/sdk/identity/Azure.Identity/src/MsalConfidentialClient.cs b/sdk/identity/Azure.Identity/src/MsalConfidentialClient.cs index ce2bb99db6c75..10dde39557cbc 100644 --- a/sdk/identity/Azure.Identity/src/MsalConfidentialClient.cs +++ b/sdk/identity/Azure.Identity/src/MsalConfidentialClient.cs @@ -104,11 +104,19 @@ protected virtual async ValueTask CreateClientCo if (_assertionCallback != null) { + if (_asyncAssertionCallback != null) + { + throw new InvalidOperationException($"Cannot set both {nameof(_assertionCallback)} and {nameof(_asyncAssertionCallback)}"); + } confClientBuilder.WithClientAssertion(_assertionCallback); } if (_asyncAssertionCallback != null) { + if (_assertionCallback != null) + { + throw new InvalidOperationException($"Cannot set both {nameof(_assertionCallback)} and {nameof(_asyncAssertionCallback)}"); + } confClientBuilder.WithClientAssertion(_asyncAssertionCallback); } diff --git a/sdk/identity/Azure.Identity/src/TokenExchangeManagedIdentitySource.cs b/sdk/identity/Azure.Identity/src/TokenExchangeManagedIdentitySource.cs index 5a3eb0ac6da1a..f3716c0a06691 100644 --- a/sdk/identity/Azure.Identity/src/TokenExchangeManagedIdentitySource.cs +++ b/sdk/identity/Azure.Identity/src/TokenExchangeManagedIdentitySource.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System; +using System.Buffers; using System.Collections.Generic; using System.IO; using System.Text; @@ -15,12 +16,13 @@ internal class TokenExchangeManagedIdentitySource : ManagedIdentitySource { private TokenFileCache _tokenFileCache; private ClientAssertionCredential _clientAssertionCredential; + private static readonly int DefaultBufferSize = 4096; private TokenExchangeManagedIdentitySource(CredentialPipeline pipeline, string tenantId, string clientId, string tokenFilePath) : base(pipeline) { _tokenFileCache = new TokenFileCache(tokenFilePath); - _clientAssertionCredential = new ClientAssertionCredential(tenantId, clientId, _tokenFileCache.GetTokenFileContents, new ClientAssertionCredentialOptions { Pipeline = pipeline }); + _clientAssertionCredential = new ClientAssertionCredential(tenantId, clientId, _tokenFileCache.GetTokenFileContentsAsync, new ClientAssertionCredentialOptions { Pipeline = pipeline }); } public static ManagedIdentitySource TryCreate(ManagedIdentityClientOptions options) @@ -47,13 +49,9 @@ protected override Request CreateRequest(string[] scopes) throw new NotImplementedException(); } - // Ideally this class should handle I/O asynchronously, and have a design similar to AccessTokenCache in BearerTokenAuthenticationPolicy. - // However, MSAL currently only accepts sync callbacks for client assertions so this has been radically simplified in light of this. If MSAL - // were to add support for an async callback we should update this accordingly. - // See, https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/issues/2863 private class TokenFileCache { - private readonly object _lock = new object(); + private static SemaphoreSlim semaphore = new SemaphoreSlim(1, 1); private readonly string _tokenFilePath; private string _tokenFileContents; private DateTimeOffset _refreshOn = DateTimeOffset.MinValue; @@ -63,23 +61,83 @@ public TokenFileCache(string tokenFilePath) _tokenFilePath = tokenFilePath; } - public string GetTokenFileContents() + public async Task GetTokenFileContentsAsync(CancellationToken cancellationToken) { if (_refreshOn <= DateTimeOffset.UtcNow) { - lock (_lock) + try { - if (_refreshOn <= DateTimeOffset.UtcNow) + await semaphore.WaitAsync(cancellationToken).ConfigureAwait(false); { - _tokenFileContents = File.ReadAllText(_tokenFilePath); + if (_refreshOn <= DateTimeOffset.UtcNow) + { + _tokenFileContents = await ReadAllTextAsync(_tokenFilePath).ConfigureAwait(false); - _refreshOn = DateTimeOffset.UtcNow.AddMinutes(5); + _refreshOn = DateTimeOffset.UtcNow.AddMinutes(5); + } } } + finally + { + semaphore.Release(); + } } - return _tokenFileContents; } } + + // Since File.ReadAllTextAsync is not available in netstandard2.0, the below implementation is borrowed with some modifications from + // https://github.com/dotnet/runtime/blob/8bcd03c650a85d523d542715e4e2543251f1dfa5/src/libraries/System.Private.CoreLib/src/System/IO/File.cs#L863-L906 + internal static Task ReadAllTextAsync(string path, CancellationToken cancellationToken = default) + => ReadAllTextAsync(path, Encoding.UTF8, cancellationToken); + + internal static Task ReadAllTextAsync(string path, Encoding encoding, CancellationToken cancellationToken = default(CancellationToken)) + { + Argument.AssertNotNullOrEmpty(path, nameof(path)); + Argument.AssertNotNull(encoding, nameof(encoding)); + + return cancellationToken.IsCancellationRequested + ? Task.FromCanceled(cancellationToken) + : InternalReadAllTextAsync(path, encoding, cancellationToken); + } + + private static async Task InternalReadAllTextAsync(string path, Encoding encoding, CancellationToken cancellationToken) + { + char[] buffer = null; + StreamReader sr = AsyncStreamReader(path, encoding); + try + { + cancellationToken.ThrowIfCancellationRequested(); + buffer = ArrayPool.Shared.Rent(sr.CurrentEncoding.GetMaxCharCount(DefaultBufferSize)); + StringBuilder sb = new StringBuilder(); + int totalRead = 0; + while (true) + { + int read = await sr.ReadAsync(buffer, totalRead, DefaultBufferSize - totalRead).ConfigureAwait(false); + if (read == 0) + { + return sb.ToString(); + } + + sb.Append(buffer, 0, read); + totalRead += read; + } + } + finally + { + sr.Dispose(); + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + } + } + } + + // If we use the path-taking constructors, we won't have FileOptions.Asynchronous set and + // we will have asynchronous file access faked by the thread pool. We want the real thing. + private static StreamReader AsyncStreamReader(string path, Encoding encoding) + => new StreamReader( + new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.Read, DefaultBufferSize, FileOptions.Asynchronous | FileOptions.SequentialScan), + encoding, detectEncodingFromByteOrderMarks: true); } }