From 2945277d2ae03be166c9f3b5d4ea3daec77efc12 Mon Sep 17 00:00:00 2001 From: Alexander Sher Date: Tue, 23 Jan 2024 21:46:15 -0500 Subject: [PATCH] Fix #40174: Azure.Extensions.AspNetCore.DataProtection.Keys Heavy load causes blocked threads (#40914) * Fix 40174: Azure.Extensions.AspNetCore.DataProtection.Keys Heavy load causes blocked threads * Address PR comment --- .../src/AzureKeyVaultConfigurationProvider.cs | 111 ++++++---- .../src/ParallelSecretLoader.cs | 14 +- .../tests/AzureKeyVaultConfigurationTests.cs | 21 +- .../src/AzureBlobXmlRepository.cs | 196 ++++++++---------- .../tests/AzureBlobXmlRepositoryTests.cs | 20 +- .../src/AzureKeyVaultXmlDecryptor.cs | 9 +- .../src/AzureKeyVaultXmlEncryptor.cs | 9 +- .../tests/AzureKeyVaultXmlEncryptorTests.cs | 16 +- .../src/AzureClientFactoryBuilder.cs | 10 +- .../src/Internal/ClientOptionsFactory.cs | 1 - .../src/Internal/ClientRegistration.cs | 132 ++++++------ 11 files changed, 281 insertions(+), 258 deletions(-) diff --git a/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/src/AzureKeyVaultConfigurationProvider.cs b/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/src/AzureKeyVaultConfigurationProvider.cs index 9d8a3ba0b3d72..8b359d2b25e17 100644 --- a/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/src/AzureKeyVaultConfigurationProvider.cs +++ b/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/src/AzureKeyVaultConfigurationProvider.cs @@ -53,7 +53,68 @@ public AzureKeyVaultConfigurationProvider(SecretClient client, AzureKeyVaultConf /// /// Load secrets into this provider. /// - public override void Load() => LoadAsync().GetAwaiter().GetResult(); + public override void Load() + { + var secretPages = _client.GetPropertiesOfSecrets(); + + using var secretLoader = new ParallelSecretLoader(_client); + var newLoadedSecrets = new Dictionary(); + var oldLoadedSecrets = Interlocked.Exchange(ref _loadedSecrets, null); + + foreach (var secret in secretPages) + { + AddSecretToLoader(secret, oldLoadedSecrets, newLoadedSecrets, secretLoader); + } + + var loadedSecret = secretLoader.WaitForAll(); + UpdateSecrets(loadedSecret, newLoadedSecrets, oldLoadedSecrets); + + // schedule a polling task only if none exists and a valid delay is specified + if (_pollingTask == null && _reloadInterval != null) + { + _pollingTask = PollForSecretChangesAsync(); + } + } + + private void AddSecretToLoader(SecretProperties secret, Dictionary oldLoadedSecrets, Dictionary newLoadedSecrets, ParallelSecretLoader secretLoader) + { + if (!_manager.Load(secret) || secret.Enabled != true) + { + return; + } + + var secretId = secret.Name; + if (oldLoadedSecrets != null && oldLoadedSecrets.TryGetValue(secretId, out var existingSecret) && IsUpToDate(existingSecret, secret)) + { + oldLoadedSecrets.Remove(secretId); + newLoadedSecrets.Add(secretId, existingSecret); + } + else + { + secretLoader.AddSecretToLoad(secret.Name); + } + } + + private void UpdateSecrets(Response[] loadedSecret, Dictionary newLoadedSecrets, Dictionary oldLoadedSecrets) + { + foreach (var secretBundle in loadedSecret) + { + newLoadedSecrets.Add(secretBundle.Value.Name, secretBundle); + } + + _loadedSecrets = newLoadedSecrets; + + // Reload is needed if we are loading secrets that were not loaded before or + // secret that was loaded previously is not available anymore + if (loadedSecret.Any() || oldLoadedSecrets?.Any() == true) + { + Data = _manager.GetData(newLoadedSecrets.Values); + if (oldLoadedSecrets != null) + { + OnReload(); + } + } + } private async Task PollForSecretChangesAsync() { @@ -79,57 +140,17 @@ internal virtual Task WaitForReload() private async Task LoadAsync() { - var secretPages = _client.GetPropertiesOfSecretsAsync(); - using var secretLoader = new ParallelSecretLoader(_client); var newLoadedSecrets = new Dictionary(); var oldLoadedSecrets = Interlocked.Exchange(ref _loadedSecrets, null); - await foreach (var secret in secretPages.ConfigureAwait(false)) - { - if (!_manager.Load(secret) || secret.Enabled != true) - { - continue; - } - - var secretId = secret.Name; - if (oldLoadedSecrets != null && - oldLoadedSecrets.TryGetValue(secretId, out var existingSecret) && - IsUpToDate(existingSecret, secret)) - { - oldLoadedSecrets.Remove(secretId); - newLoadedSecrets.Add(secretId, existingSecret); - } - else - { - secretLoader.Add(secret.Name); - } - } - - var loadedSecret = await secretLoader.WaitForAll().ConfigureAwait(false); - foreach (var secretBundle in loadedSecret) - { - newLoadedSecrets.Add(secretBundle.Value.Name, secretBundle); - } - - _loadedSecrets = newLoadedSecrets; - - // Reload is needed if we are loading secrets that were not loaded before or - // secret that was loaded previously is not available anymore - if (loadedSecret.Any() || oldLoadedSecrets?.Any() == true) + await foreach (var secret in _client.GetPropertiesOfSecretsAsync().ConfigureAwait(false)) { - Data = _manager.GetData(newLoadedSecrets.Values); - if (oldLoadedSecrets != null) - { - OnReload(); - } + AddSecretToLoader(secret, oldLoadedSecrets, newLoadedSecrets, secretLoader); } - // schedule a polling task only if none exists and a valid delay is specified - if (_pollingTask == null && _reloadInterval != null) - { - _pollingTask = PollForSecretChangesAsync(); - } + var loadedSecret = await secretLoader.WaitForAllAsync().ConfigureAwait(false); + UpdateSecrets(loadedSecret, newLoadedSecrets, oldLoadedSecrets); } /// diff --git a/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/src/ParallelSecretLoader.cs b/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/src/ParallelSecretLoader.cs index b01970052e5f5..c9267e70e18d8 100644 --- a/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/src/ParallelSecretLoader.cs +++ b/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/src/ParallelSecretLoader.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Azure.Security.KeyVault.Secrets; @@ -23,12 +24,12 @@ public ParallelSecretLoader(SecretClient client) _tasks = new List>>(); } - public void Add(string secretName) + public void AddSecretToLoad(string secretName) { - _tasks.Add(GetSecret(secretName)); + _tasks.Add(Task.Run(() => GetSecretAsync(secretName))); } - private async Task> GetSecret(string secretName) + private async Task> GetSecretAsync(string secretName) { await _semaphore.WaitAsync().ConfigureAwait(false); try @@ -41,11 +42,14 @@ private async Task> GetSecret(string secretName) } } - public Task[]> WaitForAll() + public Response[] WaitForAll() { - return Task.WhenAll(_tasks); + Task.WaitAll(_tasks.Select(t => (Task)t).ToArray()); + return _tasks.Select(t => t.Result).ToArray(); } + public Task[]> WaitForAllAsync() => Task.WhenAll(_tasks); + public void Dispose() { _semaphore?.Dispose(); diff --git a/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/tests/AzureKeyVaultConfigurationTests.cs b/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/tests/AzureKeyVaultConfigurationTests.cs index 5946ee5686825..0496d174abbe3 100644 --- a/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/tests/AzureKeyVaultConfigurationTests.cs +++ b/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/tests/AzureKeyVaultConfigurationTests.cs @@ -32,6 +32,7 @@ private void SetPages(Mock mock, Func getSecretCallb var pagesOfProperties = pages.Select( page => page.Select(secret => secret.Properties).ToArray()).ToArray(); + mock.Setup(m => m.GetPropertiesOfSecrets(default)).Returns(new MockPageable(pagesOfProperties)); mock.Setup(m => m.GetPropertiesOfSecretsAsync(default)).Returns(new MockAsyncPageable(pagesOfProperties)); foreach (var page in pages) @@ -49,6 +50,24 @@ private void SetPages(Mock mock, Func getSecretCallb } } + private class MockPageable : Pageable + { + private readonly SecretProperties[][] _pages; + + public MockPageable(SecretProperties[][] pages) + { + _pages = pages; + } + + public override IEnumerable> AsPages(string continuationToken = null, int? pageSizeHint = null) + { + foreach (var page in _pages) + { + yield return Page.FromValues(page, null, Mock.Of()); + } + } + } + private class MockAsyncPageable : AsyncPageable { private readonly SecretProperties[][] _pages; @@ -353,7 +372,7 @@ public async Task SupportsReloadOnEnabledChange() await provider.Wait(); SetPages(client, - new[] + new[] { CreateSecret("Secret1", "Value2"), CreateSecret("Secret2", "Value2", enabled: false) diff --git a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Blobs/src/AzureBlobXmlRepository.cs b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Blobs/src/AzureBlobXmlRepository.cs index 6a2ecb2bce959..c892daefefe15 100644 --- a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Blobs/src/AzureBlobXmlRepository.cs +++ b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Blobs/src/AzureBlobXmlRepository.cs @@ -50,51 +50,11 @@ public AzureBlobXmlRepository(BlobClient blobClient) /// public IReadOnlyCollection GetAllElements() { - // Shunt the work onto a ThreadPool thread so that it's independent of any - // existing sync context or other potentially deadlock-causing items. - - var elements = Task.Run(() => GetAllElementsAsync()).GetAwaiter().GetResult(); - return new ReadOnlyCollection(elements); - } - - /// - public void StoreElement(XElement element, string friendlyName) - { - if (element == null) - { - throw new ArgumentNullException(nameof(element)); - } - - // Shunt the work onto a ThreadPool thread so that it's independent of any - // existing sync context or other potentially deadlock-causing items. - - Task.Run(() => StoreElementAsync(element)).GetAwaiter().GetResult(); - } - - private static XDocument CreateDocumentFromBlobData(BlobData blobData) - { - if (blobData == null || blobData.BlobContents.Length == 0) - { - return new XDocument(new XElement(RepositoryElementName)); - } - - using var memoryStream = new MemoryStream(blobData.BlobContents); - - var xmlReaderSettings = new XmlReaderSettings() - { - DtdProcessing = DtdProcessing.Prohibit, - IgnoreProcessingInstructions = true, - }; - - using (var xmlReader = XmlReader.Create(memoryStream, xmlReaderSettings)) - { - return XDocument.Load(xmlReader); - } - } - - private async Task> GetAllElementsAsync() - { - var data = await GetLatestDataAsync().ConfigureAwait(false); + // Fixes for #40174 + // Original `Task.Run(() => GetAllElementsAsync()).GetAwaiter().GetResult();` blocks the thread in ThreadPool until task is completed, + // then runs first part of the task on another ThreadPool thread and schedules continuation to run in ThreadPool thread again. + // If too many calls of GetAllElements() happens before any continuation is executed, all threads in ThreadPool will become blocked + var data = GetLatestData(); // The document will look like this: // @@ -107,68 +67,17 @@ private async Task> GetAllElementsAsync() // We want to return the first-level child elements to our caller. var doc = CreateDocumentFromBlobData(data); - return doc.Root.Elements().ToList(); + return new ReadOnlyCollection(doc.Root.Elements().ToList()); } - private async Task GetLatestDataAsync() + /// + public void StoreElement(XElement element, string friendlyName) { - // Set the appropriate AccessCondition based on what we believe the latest - // file contents to be, then make the request. - - var latestCachedData = Volatile.Read(ref _cachedBlobData); // local ref so field isn't mutated under our feet - var requestCondition = (latestCachedData != null) - ? new BlobRequestConditions() { IfNoneMatch = latestCachedData.ETag } - : null; - - try - { - using (var memoryStream = new MemoryStream()) - { - var response = await _blobClient.DownloadToAsync( - destination: memoryStream, - conditions: requestCondition).ConfigureAwait(false); - - if (response.Status == 304) - { - // 304 Not Modified - // Thrown when we already have the latest cached data. - // This isn't an error; we'll return our cached copy of the data. - return latestCachedData; - } - - // At this point, our original cache either didn't exist or was outdated. - // We'll update it now and return the updated value - latestCachedData = new BlobData() - { - BlobContents = memoryStream.ToArray(), - ETag = response.Headers.ETag - }; - } - Volatile.Write(ref _cachedBlobData, latestCachedData); - } - catch (RequestFailedException ex) when (ex.Status == 404) + if (element == null) { - // 404 Not Found - // Thrown when no file exists in storage. - // This isn't an error; we'll delete our cached copy of data. - - latestCachedData = null; - Volatile.Write(ref _cachedBlobData, latestCachedData); + throw new ArgumentNullException(nameof(element)); } - return latestCachedData; - } - - private int GetRandomizedBackoffPeriod() - { - // returns a TimeSpan in the range [0.8, 1.0) * ConflictBackoffPeriod - // not used for crypto purposes - var multiplier = 0.8 + (_random.NextDouble() * 0.2); - return (int) (multiplier * ConflictBackoffPeriod.Ticks); - } - - private async Task StoreElementAsync(XElement element) - { // holds the last error in case we need to rethrow it ExceptionDispatchInfo lastError = null; @@ -178,14 +87,14 @@ private async Task StoreElementAsync(XElement element) { // If multiple conflicts occurred, wait a small period of time before retrying // the operation so that other writers can make forward progress. - await Task.Delay(GetRandomizedBackoffPeriod()).ConfigureAwait(false); + Thread.Sleep(GetRandomizedBackoffPeriod()); } if (i > 0) { // If at least one conflict occurred, make sure we have an up-to-date // view of the blob contents. - await GetLatestDataAsync().ConfigureAwait(false); + GetLatestData(); } // Merge the new element into the document. If no document exists, @@ -217,10 +126,7 @@ private async Task StoreElementAsync(XElement element) try { // Send the request up to the server. - var response = await _blobClient.UploadAsync( - serializedDoc, - httpHeaders: _blobHttpHeaders, - conditions: requestConditions).ConfigureAwait(false); + var response = _blobClient.Upload(serializedDoc, httpHeaders: _blobHttpHeaders, conditions: requestConditions); // If we got this far, success! // We can update the cached view of the remote contents. @@ -255,6 +161,82 @@ private async Task StoreElementAsync(XElement element) lastError.Throw(); } + private static XDocument CreateDocumentFromBlobData(BlobData blobData) + { + if (blobData == null || blobData.BlobContents.Length == 0) + { + return new XDocument(new XElement(RepositoryElementName)); + } + + using var memoryStream = new MemoryStream(blobData.BlobContents); + + var xmlReaderSettings = new XmlReaderSettings() + { + DtdProcessing = DtdProcessing.Prohibit, + IgnoreProcessingInstructions = true, + }; + + using (var xmlReader = XmlReader.Create(memoryStream, xmlReaderSettings)) + { + return XDocument.Load(xmlReader); + } + } + + private BlobData GetLatestData() + { + // Set the appropriate AccessCondition based on what we believe the latest + // file contents to be, then make the request. + + var latestCachedData = Volatile.Read(ref _cachedBlobData); // local ref so field isn't mutated under our feet + var requestCondition = (latestCachedData != null) + ? new BlobRequestConditions() { IfNoneMatch = latestCachedData.ETag } + : null; + + try + { + using (var memoryStream = new MemoryStream()) + { + var response = _blobClient.DownloadTo(destination: memoryStream, conditions: requestCondition); + + if (response.Status == 304) + { + // 304 Not Modified + // Thrown when we already have the latest cached data. + // This isn't an error; we'll return our cached copy of the data. + return latestCachedData; + } + + // At this point, our original cache either didn't exist or was outdated. + // We'll update it now and return the updated value + latestCachedData = new BlobData + { + BlobContents = memoryStream.ToArray(), + ETag = response.Headers.ETag + }; + } + Volatile.Write(ref _cachedBlobData, latestCachedData); + } + catch (RequestFailedException ex) when (ex.Status == 404) + { + // 404 Not Found + // Thrown when no file exists in storage. + // This isn't an error; we'll delete our cached copy of data. + + latestCachedData = null; + Volatile.Write(ref _cachedBlobData, latestCachedData); + } + + return latestCachedData; + } + + private int GetRandomizedBackoffPeriod() + { + // returns a TimeSpan in the range [0.8, 1.0) * ConflictBackoffPeriod + // not used for crypto purposes + var multiplier = 0.8 + (_random.NextDouble() * 0.2); + return (int) (multiplier * ConflictBackoffPeriod.Ticks); + } + private sealed class BlobData { internal byte[] BlobContents; diff --git a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Blobs/tests/AzureBlobXmlRepositoryTests.cs b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Blobs/tests/AzureBlobXmlRepositoryTests.cs index 9a8ae4d4be1d2..d8b4013373e7a 100644 --- a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Blobs/tests/AzureBlobXmlRepositoryTests.cs +++ b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Blobs/tests/AzureBlobXmlRepositoryTests.cs @@ -30,7 +30,7 @@ public void StoreCreatesBlobWhenNotExist() var mock = new Mock(); - mock.Setup(c => c.UploadAsync( + mock.Setup(c => c.Upload( It.IsAny(), It.IsAny(), It.IsAny>(), @@ -39,7 +39,7 @@ public void StoreCreatesBlobWhenNotExist() It.IsAny(), It.IsAny(), It.IsAny())) - .Returns(async (Stream strm, BlobHttpHeaders headers, IDictionary metaData, BlobRequestConditions conditions, IProgress progress, AccessTier? access, StorageTransferOptions transfer, CancellationToken token) => + .Returns((Stream strm, BlobHttpHeaders headers, IDictionary metaData, BlobRequestConditions conditions, IProgress progress, AccessTier? access, StorageTransferOptions transfer, CancellationToken token) => { using var memoryStream = new MemoryStream(); strm.CopyTo(memoryStream); @@ -47,8 +47,6 @@ public void StoreCreatesBlobWhenNotExist() uploadConditions = conditions; contentType = headers?.ContentType; - await Task.Yield(); - var mockResponse = new Mock>(); var blobContentInfo = BlobsModelFactory.BlobContentInfo(ETag.All, DateTimeOffset.Now.AddDays(-1), Array.Empty(), "", 1); @@ -73,15 +71,15 @@ public void StoreUpdatesWhenExistsAndNewerExists() var mock = new Mock(); - mock.Setup(c => c.DownloadToAsync( + mock.Setup(c => c.DownloadTo( It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) - .Returns(async (Stream target, BlobRequestConditions conditions, StorageTransferOptions options, CancellationToken token) => + .Returns((Stream target, BlobRequestConditions conditions, StorageTransferOptions options, CancellationToken token) => { var data = GetEnvelopedContent(""); - await target.WriteAsync(data, 0, data.Length); + target.Write(data, 0, data.Length); var response = new MockResponse(200); response.AddHeader(new HttpHeader("ETag", "*")); @@ -89,7 +87,7 @@ public void StoreUpdatesWhenExistsAndNewerExists() }) .Verifiable(); - mock.Setup(c => c.UploadAsync( + mock.Setup(c => c.Upload( It.IsAny(), It.IsAny(), It.IsAny>(), @@ -101,7 +99,7 @@ public void StoreUpdatesWhenExistsAndNewerExists() .Throws(new RequestFailedException(status: 412, message: "")) .Verifiable(); - mock.Setup(c => c.UploadAsync( + mock.Setup(c => c.Upload( It.IsAny(), It.IsAny(), It.IsAny>(), @@ -110,14 +108,12 @@ public void StoreUpdatesWhenExistsAndNewerExists() It.IsAny(), It.IsAny(), It.IsAny())) - .Returns(async (Stream strm, BlobHttpHeaders headers, IDictionary metaData, BlobRequestConditions conditions, IProgress progress, AccessTier? access, StorageTransferOptions transfer, CancellationToken token) => + .Returns((Stream strm, BlobHttpHeaders headers, IDictionary metaData, BlobRequestConditions conditions, IProgress progress, AccessTier? access, StorageTransferOptions transfer, CancellationToken token) => { using var memoryStream = new MemoryStream(); strm.CopyTo(memoryStream); bytes = memoryStream.ToArray(); - await Task.Yield(); - var mockResponse = new Mock>(); var blobContentInfo = BlobsModelFactory.BlobContentInfo(ETag.All, DateTimeOffset.Now.AddDays(-1), Array.Empty(), "", 1); mockResponse.Setup(c => c.Value).Returns(blobContentInfo); diff --git a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/src/AzureKeyVaultXmlDecryptor.cs b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/src/AzureKeyVaultXmlDecryptor.cs index b3fe150878f5f..9bdd4ba213f8f 100644 --- a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/src/AzureKeyVaultXmlDecryptor.cs +++ b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/src/AzureKeyVaultXmlDecryptor.cs @@ -23,11 +23,6 @@ public AzureKeyVaultXmlDecryptor(IServiceProvider serviceProvider) } public XElement Decrypt(XElement encryptedElement) - { - return Task.Run(() => DecryptAsync(encryptedElement)).GetAwaiter().GetResult(); - } - - private async Task DecryptAsync(XElement encryptedElement) { var kid = (string)encryptedElement.Element("kid"); var symmetricKey = Convert.FromBase64String((string)encryptedElement.Element("key")); @@ -35,8 +30,8 @@ private async Task DecryptAsync(XElement encryptedElement) var encryptedValue = Convert.FromBase64String((string)encryptedElement.Element("value")); - var key = await _client.ResolveAsync(kid).ConfigureAwait(false); - var result = await key.UnwrapKeyAsync(AzureKeyVaultXmlEncryptor.DefaultKeyEncryption, symmetricKey).ConfigureAwait(false); + var key = _client.Resolve(kid); + var result = key.UnwrapKey(AzureKeyVaultXmlEncryptor.DefaultKeyEncryption, symmetricKey); byte[] decryptedValue; using (var symmetricAlgorithm = AzureKeyVaultXmlEncryptor.DefaultSymmetricAlgorithmFactory()) diff --git a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/src/AzureKeyVaultXmlEncryptor.cs b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/src/AzureKeyVaultXmlEncryptor.cs index dd50987797713..400f4b0ce5f39 100644 --- a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/src/AzureKeyVaultXmlEncryptor.cs +++ b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/src/AzureKeyVaultXmlEncryptor.cs @@ -35,11 +35,6 @@ internal AzureKeyVaultXmlEncryptor(IKeyEncryptionKeyResolver client, string keyI } public EncryptedXmlInfo Encrypt(XElement plaintextElement) - { - return Task.Run(() => EncryptAsync(plaintextElement)).GetAwaiter().GetResult(); - } - - private async Task EncryptAsync(XElement plaintextElement) { byte[] value; using (var memoryStream = new MemoryStream()) @@ -62,8 +57,8 @@ private async Task EncryptAsync(XElement plaintextElement) encryptedValue = encryptor.TransformFinalBlock(value, 0, value.Length); } - var key = await _client.ResolveAsync(_keyId).ConfigureAwait(false); - var wrappedKey = await key.WrapKeyAsync(DefaultKeyEncryption, symmetricKey).ConfigureAwait(false); + var key = _client.Resolve(_keyId); + var wrappedKey = key.WrapKey(DefaultKeyEncryption, symmetricKey); var element = new XElement("encryptedKey", new XComment(" This key is encrypted with Azure Key Vault. "), diff --git a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/tests/AzureKeyVaultXmlEncryptorTests.cs b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/tests/AzureKeyVaultXmlEncryptorTests.cs index f66a45e7337c1..4f711bae57e4e 100644 --- a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/tests/AzureKeyVaultXmlEncryptorTests.cs +++ b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/tests/AzureKeyVaultXmlEncryptorTests.cs @@ -19,15 +19,15 @@ public class AzureKeyVaultXmlEncryptorTests public void UsesKeyVaultToEncryptKey() { var keyMock = new Mock(MockBehavior.Strict); - keyMock.Setup(client => client.WrapKeyAsync("RSA-OAEP", It.IsAny>(), default)) - .ReturnsAsync((string _, ReadOnlyMemory data, CancellationToken __) => data.ToArray().Reverse().ToArray()) + keyMock.Setup(client => client.WrapKey("RSA-OAEP", It.IsAny>(), default)) + .Returns((string _, ReadOnlyMemory data, CancellationToken __) => data.ToArray().Reverse().ToArray()) .Verifiable(); keyMock.SetupGet(client => client.KeyId).Returns("KeyId"); var mock = new Mock(); - mock.Setup(client => client.ResolveAsync("key", default)) - .ReturnsAsync((string _, CancellationToken __) => keyMock.Object) + mock.Setup(client => client.Resolve("key", default)) + .Returns((string _, CancellationToken __) => keyMock.Object) .Verifiable(); var encryptor = new AzureKeyVaultXmlEncryptor(mock.Object, "key", new MockNumberGenerator()); @@ -50,13 +50,13 @@ public void UsesKeyVaultToEncryptKey() public void UsesKeyVaultToDecryptKey() { var keyMock = new Mock(MockBehavior.Strict); - keyMock.Setup(client => client.UnwrapKeyAsync("RSA-OAEP", It.IsAny>(), default)) - .ReturnsAsync((string _, ReadOnlyMemory data, CancellationToken __) => data.ToArray().Reverse().ToArray()) + keyMock.Setup(client => client.UnwrapKey("RSA-OAEP", It.IsAny>(), default)) + .Returns((string _, ReadOnlyMemory data, CancellationToken __) => data.ToArray().Reverse().ToArray()) .Verifiable(); var mock = new Mock(); - mock.Setup(client => client.ResolveAsync("KeyId", default)) - .ReturnsAsync((string _, CancellationToken __) => keyMock.Object) + mock.Setup(client => client.Resolve("KeyId", default)) + .Returns((string _, CancellationToken __) => keyMock.Object) .Verifiable(); var serviceCollection = new ServiceCollection(); diff --git a/sdk/extensions/Microsoft.Extensions.Azure/src/AzureClientFactoryBuilder.cs b/sdk/extensions/Microsoft.Extensions.Azure/src/AzureClientFactoryBuilder.cs index 2aa538a91dca0..29edcbf2b7e40 100644 --- a/sdk/extensions/Microsoft.Extensions.Azure/src/AzureClientFactoryBuilder.cs +++ b/sdk/extensions/Microsoft.Extensions.Azure/src/AzureClientFactoryBuilder.cs @@ -75,7 +75,7 @@ public AzureClientFactoryBuilder ConfigureDefaults(ActionThis instance. public AzureClientFactoryBuilder ConfigureDefaults(IConfiguration configuration) { - ConfigureDefaults(options => configuration.Bind(options)); + ConfigureDefaults(configuration.Bind); var credentialsFromConfig = ClientFactory.CreateCredential(configuration); @@ -161,18 +161,14 @@ public IAzureClientBuilder AddClient(Func< private IAzureClientBuilder RegisterClientFactory(Func clientFactory, bool requiresCredential) where TOptions : class { - var clientRegistration = new ClientRegistration(DefaultClientName, (provider, options, credential) => clientFactory((TOptions) options, credential, provider)); - clientRegistration.RequiresTokenCredential = requiresCredential; - + var clientRegistration = new ClientRegistration(DefaultClientName, requiresCredential, (provider, options, credential) => clientFactory((TOptions) options, credential, provider)); _serviceCollection.AddSingleton(clientRegistration); _serviceCollection.TryAddSingleton(typeof(IConfigureOptions>), typeof(DefaultCredentialClientOptionsSetup)); _serviceCollection.TryAddSingleton(typeof(IOptionsMonitor), typeof(ClientOptionsMonitor)); _serviceCollection.TryAddSingleton(typeof(ClientOptionsFactory), typeof(ClientOptionsFactory)); _serviceCollection.TryAddSingleton(typeof(IAzureClientFactory), typeof(AzureClientFactory)); - _serviceCollection.TryAddSingleton( - typeof(TClient), - provider => provider.GetService>().CreateClient(DefaultClientName)); + _serviceCollection.TryAddSingleton(typeof(TClient), provider => provider.GetService>().CreateClient(DefaultClientName)); return new AzureClientBuilder(clientRegistration, _serviceCollection); } diff --git a/sdk/extensions/Microsoft.Extensions.Azure/src/Internal/ClientOptionsFactory.cs b/sdk/extensions/Microsoft.Extensions.Azure/src/Internal/ClientOptionsFactory.cs index 1808866fe2ada..fe5f353ecda21 100644 --- a/sdk/extensions/Microsoft.Extensions.Azure/src/Internal/ClientOptionsFactory.cs +++ b/sdk/extensions/Microsoft.Extensions.Azure/src/Internal/ClientOptionsFactory.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; -using System.Reflection; using Azure.Core; using Microsoft.Extensions.Options; diff --git a/sdk/extensions/Microsoft.Extensions.Azure/src/Internal/ClientRegistration.cs b/sdk/extensions/Microsoft.Extensions.Azure/src/Internal/ClientRegistration.cs index 89fa8da1c2bad..1f3705f1f3281 100644 --- a/sdk/extensions/Microsoft.Extensions.Azure/src/Internal/ClientRegistration.cs +++ b/sdk/extensions/Microsoft.Extensions.Azure/src/Internal/ClientRegistration.cs @@ -12,46 +12,38 @@ internal class ClientRegistration : IDisposable, IAsyncDisposable { public string Name { get; set; } public object Version { get; set; } - public bool RequiresTokenCredential { get; set; } + private readonly bool _requiresCredential; private readonly Func _factory; - private readonly object _cacheLock = new object(); - private readonly bool _asyncDisposable; - private readonly bool _disposable; + private readonly object _cacheLock = new(); private bool _clientInitialized; private TClient _cachedClient; private ExceptionDispatchInfo _cachedException; - public ClientRegistration(string name, Func factory) + public ClientRegistration(string name, bool requiresCredential, Func factory) { Name = name; + _requiresCredential = requiresCredential; _factory = factory; - - _asyncDisposable = typeof(IAsyncDisposable).IsAssignableFrom(typeof(TClient)); - _disposable = typeof(IDisposable).IsAssignableFrom(typeof(TClient)); } public TClient GetClient(IServiceProvider serviceProvider, object options, TokenCredential tokenCredential) { - _cachedException?.Throw(); - - if (_clientInitialized) + if (TryGetCachedClientOrThrow(out var cachedClient)) { - return _cachedClient; + return cachedClient; } lock (_cacheLock) { - _cachedException?.Throw(); - - if (_clientInitialized) + if (TryGetCachedClientOrThrow(out cachedClient)) { - return _cachedClient; + return cachedClient; } - if (RequiresTokenCredential && tokenCredential == null) + if (_requiresCredential && tokenCredential == null) { throw new InvalidOperationException("Client registration requires a TokenCredential. Configure it using UseCredential method."); } @@ -71,63 +63,87 @@ public TClient GetClient(IServiceProvider serviceProvider, object options, Token } } - public async ValueTask DisposeAsync() + /// + /// Client registration can be in one of 4 states: + /// - _clientInitialized is false, _cachedException is null: _factory has never been called, client hasn't been initialized yet + /// - _clientInitialized is false, _cachedException is not null: _factory has been called, but exception has happened. Client isn't been initialized. + /// - _clientInitialized is true, _cachedClient is not null: client has been initialized and is not disposed yet + /// - _clientInitialized is true, _cachedClient is null: client has been initialized and is disposed + /// + /// + /// True if client is initialized and not disposed, otherwise false. + private bool TryGetCachedClientOrThrow(out TClient cachedClient) { + _cachedException?.Throw(); if (_clientInitialized) { - if (_asyncDisposable) - { - IAsyncDisposable disposableClient; + cachedClient = _cachedClient ?? throw new ObjectDisposedException(nameof(ClientRegistration)); + return true; + } - lock (_cacheLock) - { - if (!_clientInitialized) - { - return; - } + cachedClient = default; + return false; + } - disposableClient = (IAsyncDisposable)_cachedClient; + public async ValueTask DisposeAsync() + { + if (!_clientInitialized) + { + return; + } - _cachedClient = default; - _clientInitialized = false; - } + if (_cachedClient is null) + { + return; + } - await disposableClient.DisposeAsync().ConfigureAwait(false); - } - else if (_disposable) - { - Dispose(); - } + IDisposable disposableClient; + IAsyncDisposable asyncDisposableClient; + lock (_cacheLock) + { + disposableClient = _cachedClient as IDisposable; + asyncDisposableClient = _cachedClient as IAsyncDisposable; + _cachedClient = default; + } + + if (asyncDisposableClient is not null) + { + await asyncDisposableClient.DisposeAsync().ConfigureAwait(false); + } + else if (disposableClient is not null) + { + disposableClient.Dispose(); } } public void Dispose() { - if (_clientInitialized) + if (!_clientInitialized) { - if (_disposable) - { - IDisposable disposableClient; - - lock (_cacheLock) - { - if (!_clientInitialized) - { - return; - } + return; + } - disposableClient = (IDisposable)_cachedClient; + if (_cachedClient is null) + { + return; + } - _cachedClient = default; - _clientInitialized = false; - } + IDisposable disposableClient; + IAsyncDisposable asyncDisposableClient; + lock (_cacheLock) + { + disposableClient = _cachedClient as IDisposable; + asyncDisposableClient = _cachedClient as IAsyncDisposable; + _cachedClient = default; + } - disposableClient.Dispose(); - } - else if (_asyncDisposable) - { - DisposeAsync().GetAwaiter().GetResult(); - } + if (disposableClient is not null) + { + disposableClient.Dispose(); + } + else if (asyncDisposableClient is not null) + { + asyncDisposableClient.DisposeAsync().GetAwaiter().GetResult(); } } }