Skip to content

Commit

Permalink
TokenExchangeManagedIdentitySource with async IO (Azure#38939)
Browse files Browse the repository at this point in the history
  • Loading branch information
christothes authored Sep 29, 2023
1 parent 0fa6301 commit 3c35fbb
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 12 deletions.
8 changes: 8 additions & 0 deletions sdk/identity/Azure.Identity/src/MsalConfidentialClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,19 @@ protected virtual async ValueTask<IConfidentialClientApplication> CreateClientCo

if (_assertionCallback != null)
{
if (_asyncAssertionCallback != null)
{
throw new InvalidOperationException($"Cannot set both {nameof(_assertionCallback)} and {nameof(_asyncAssertionCallback)}");
}
confClientBuilder.WithClientAssertion(_assertionCallback);
}

if (_asyncAssertionCallback != null)
{
if (_assertionCallback != null)
{
throw new InvalidOperationException($"Cannot set both {nameof(_assertionCallback)} and {nameof(_asyncAssertionCallback)}");
}
confClientBuilder.WithClientAssertion(_asyncAssertionCallback);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Text;
Expand All @@ -15,12 +16,13 @@ internal class TokenExchangeManagedIdentitySource : ManagedIdentitySource
{
private TokenFileCache _tokenFileCache;
private ClientAssertionCredential _clientAssertionCredential;
private static readonly int DefaultBufferSize = 4096;

private TokenExchangeManagedIdentitySource(CredentialPipeline pipeline, string tenantId, string clientId, string tokenFilePath)
: base(pipeline)
{
_tokenFileCache = new TokenFileCache(tokenFilePath);
_clientAssertionCredential = new ClientAssertionCredential(tenantId, clientId, _tokenFileCache.GetTokenFileContents, new ClientAssertionCredentialOptions { Pipeline = pipeline });
_clientAssertionCredential = new ClientAssertionCredential(tenantId, clientId, _tokenFileCache.GetTokenFileContentsAsync, new ClientAssertionCredentialOptions { Pipeline = pipeline });
}

public static ManagedIdentitySource TryCreate(ManagedIdentityClientOptions options)
Expand All @@ -47,13 +49,9 @@ protected override Request CreateRequest(string[] scopes)
throw new NotImplementedException();
}

// Ideally this class should handle I/O asynchronously, and have a design similar to AccessTokenCache in BearerTokenAuthenticationPolicy.
// However, MSAL currently only accepts sync callbacks for client assertions so this has been radically simplified in light of this. If MSAL
// were to add support for an async callback we should update this accordingly.
// See, https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/issues/2863
private class TokenFileCache
{
private readonly object _lock = new object();
private static SemaphoreSlim semaphore = new SemaphoreSlim(1, 1);
private readonly string _tokenFilePath;
private string _tokenFileContents;
private DateTimeOffset _refreshOn = DateTimeOffset.MinValue;
Expand All @@ -63,23 +61,83 @@ public TokenFileCache(string tokenFilePath)
_tokenFilePath = tokenFilePath;
}

public string GetTokenFileContents()
public async Task<string> GetTokenFileContentsAsync(CancellationToken cancellationToken)
{
if (_refreshOn <= DateTimeOffset.UtcNow)
{
lock (_lock)
try
{
if (_refreshOn <= DateTimeOffset.UtcNow)
await semaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
{
_tokenFileContents = File.ReadAllText(_tokenFilePath);
if (_refreshOn <= DateTimeOffset.UtcNow)
{
_tokenFileContents = await ReadAllTextAsync(_tokenFilePath).ConfigureAwait(false);

_refreshOn = DateTimeOffset.UtcNow.AddMinutes(5);
_refreshOn = DateTimeOffset.UtcNow.AddMinutes(5);
}
}
}
finally
{
semaphore.Release();
}
}

return _tokenFileContents;
}
}

// Since File.ReadAllTextAsync is not available in netstandard2.0, the below implementation is borrowed with some modifications from
// https://github.com/dotnet/runtime/blob/8bcd03c650a85d523d542715e4e2543251f1dfa5/src/libraries/System.Private.CoreLib/src/System/IO/File.cs#L863-L906
internal static Task<string> ReadAllTextAsync(string path, CancellationToken cancellationToken = default)
=> ReadAllTextAsync(path, Encoding.UTF8, cancellationToken);

internal static Task<string> ReadAllTextAsync(string path, Encoding encoding, CancellationToken cancellationToken = default(CancellationToken))
{
Argument.AssertNotNullOrEmpty(path, nameof(path));
Argument.AssertNotNull(encoding, nameof(encoding));

return cancellationToken.IsCancellationRequested
? Task.FromCanceled<string>(cancellationToken)
: InternalReadAllTextAsync(path, encoding, cancellationToken);
}

private static async Task<string> InternalReadAllTextAsync(string path, Encoding encoding, CancellationToken cancellationToken)
{
char[] buffer = null;
StreamReader sr = AsyncStreamReader(path, encoding);
try
{
cancellationToken.ThrowIfCancellationRequested();
buffer = ArrayPool<char>.Shared.Rent(sr.CurrentEncoding.GetMaxCharCount(DefaultBufferSize));
StringBuilder sb = new StringBuilder();
int totalRead = 0;
while (true)
{
int read = await sr.ReadAsync(buffer, totalRead, DefaultBufferSize - totalRead).ConfigureAwait(false);
if (read == 0)
{
return sb.ToString();
}

sb.Append(buffer, 0, read);
totalRead += read;
}
}
finally
{
sr.Dispose();
if (buffer != null)
{
ArrayPool<char>.Shared.Return(buffer);
}
}
}

// If we use the path-taking constructors, we won't have FileOptions.Asynchronous set and
// we will have asynchronous file access faked by the thread pool. We want the real thing.
private static StreamReader AsyncStreamReader(string path, Encoding encoding)
=> new StreamReader(
new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.Read, DefaultBufferSize, FileOptions.Asynchronous | FileOptions.SequentialScan),
encoding, detectEncodingFromByteOrderMarks: true);
}
}

0 comments on commit 3c35fbb

Please sign in to comment.