diff --git a/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/src/AzureKeyVaultConfigurationProvider.cs b/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/src/AzureKeyVaultConfigurationProvider.cs
index 9d8a3ba0b3d72..8b359d2b25e17 100644
--- a/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/src/AzureKeyVaultConfigurationProvider.cs
+++ b/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/src/AzureKeyVaultConfigurationProvider.cs
@@ -53,7 +53,68 @@ public AzureKeyVaultConfigurationProvider(SecretClient client, AzureKeyVaultConf
///
/// Load secrets into this provider.
///
- public override void Load() => LoadAsync().GetAwaiter().GetResult();
+ public override void Load()
+ {
+ var secretPages = _client.GetPropertiesOfSecrets();
+
+ using var secretLoader = new ParallelSecretLoader(_client);
+ var newLoadedSecrets = new Dictionary();
+ var oldLoadedSecrets = Interlocked.Exchange(ref _loadedSecrets, null);
+
+ foreach (var secret in secretPages)
+ {
+ AddSecretToLoader(secret, oldLoadedSecrets, newLoadedSecrets, secretLoader);
+ }
+
+ var loadedSecret = secretLoader.WaitForAll();
+ UpdateSecrets(loadedSecret, newLoadedSecrets, oldLoadedSecrets);
+
+ // schedule a polling task only if none exists and a valid delay is specified
+ if (_pollingTask == null && _reloadInterval != null)
+ {
+ _pollingTask = PollForSecretChangesAsync();
+ }
+ }
+
+ private void AddSecretToLoader(SecretProperties secret, Dictionary oldLoadedSecrets, Dictionary newLoadedSecrets, ParallelSecretLoader secretLoader)
+ {
+ if (!_manager.Load(secret) || secret.Enabled != true)
+ {
+ return;
+ }
+
+ var secretId = secret.Name;
+ if (oldLoadedSecrets != null && oldLoadedSecrets.TryGetValue(secretId, out var existingSecret) && IsUpToDate(existingSecret, secret))
+ {
+ oldLoadedSecrets.Remove(secretId);
+ newLoadedSecrets.Add(secretId, existingSecret);
+ }
+ else
+ {
+ secretLoader.AddSecretToLoad(secret.Name);
+ }
+ }
+
+ private void UpdateSecrets(Response[] loadedSecret, Dictionary newLoadedSecrets, Dictionary oldLoadedSecrets)
+ {
+ foreach (var secretBundle in loadedSecret)
+ {
+ newLoadedSecrets.Add(secretBundle.Value.Name, secretBundle);
+ }
+
+ _loadedSecrets = newLoadedSecrets;
+
+ // Reload is needed if we are loading secrets that were not loaded before or
+ // secret that was loaded previously is not available anymore
+ if (loadedSecret.Any() || oldLoadedSecrets?.Any() == true)
+ {
+ Data = _manager.GetData(newLoadedSecrets.Values);
+ if (oldLoadedSecrets != null)
+ {
+ OnReload();
+ }
+ }
+ }
private async Task PollForSecretChangesAsync()
{
@@ -79,57 +140,17 @@ internal virtual Task WaitForReload()
private async Task LoadAsync()
{
- var secretPages = _client.GetPropertiesOfSecretsAsync();
-
using var secretLoader = new ParallelSecretLoader(_client);
var newLoadedSecrets = new Dictionary();
var oldLoadedSecrets = Interlocked.Exchange(ref _loadedSecrets, null);
- await foreach (var secret in secretPages.ConfigureAwait(false))
- {
- if (!_manager.Load(secret) || secret.Enabled != true)
- {
- continue;
- }
-
- var secretId = secret.Name;
- if (oldLoadedSecrets != null &&
- oldLoadedSecrets.TryGetValue(secretId, out var existingSecret) &&
- IsUpToDate(existingSecret, secret))
- {
- oldLoadedSecrets.Remove(secretId);
- newLoadedSecrets.Add(secretId, existingSecret);
- }
- else
- {
- secretLoader.Add(secret.Name);
- }
- }
-
- var loadedSecret = await secretLoader.WaitForAll().ConfigureAwait(false);
- foreach (var secretBundle in loadedSecret)
- {
- newLoadedSecrets.Add(secretBundle.Value.Name, secretBundle);
- }
-
- _loadedSecrets = newLoadedSecrets;
-
- // Reload is needed if we are loading secrets that were not loaded before or
- // secret that was loaded previously is not available anymore
- if (loadedSecret.Any() || oldLoadedSecrets?.Any() == true)
+ await foreach (var secret in _client.GetPropertiesOfSecretsAsync().ConfigureAwait(false))
{
- Data = _manager.GetData(newLoadedSecrets.Values);
- if (oldLoadedSecrets != null)
- {
- OnReload();
- }
+ AddSecretToLoader(secret, oldLoadedSecrets, newLoadedSecrets, secretLoader);
}
- // schedule a polling task only if none exists and a valid delay is specified
- if (_pollingTask == null && _reloadInterval != null)
- {
- _pollingTask = PollForSecretChangesAsync();
- }
+ var loadedSecret = await secretLoader.WaitForAllAsync().ConfigureAwait(false);
+ UpdateSecrets(loadedSecret, newLoadedSecrets, oldLoadedSecrets);
}
///
diff --git a/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/src/ParallelSecretLoader.cs b/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/src/ParallelSecretLoader.cs
index b01970052e5f5..c9267e70e18d8 100644
--- a/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/src/ParallelSecretLoader.cs
+++ b/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/src/ParallelSecretLoader.cs
@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Azure.Security.KeyVault.Secrets;
@@ -23,12 +24,12 @@ public ParallelSecretLoader(SecretClient client)
_tasks = new List>>();
}
- public void Add(string secretName)
+ public void AddSecretToLoad(string secretName)
{
- _tasks.Add(GetSecret(secretName));
+ _tasks.Add(Task.Run(() => GetSecretAsync(secretName)));
}
- private async Task> GetSecret(string secretName)
+ private async Task> GetSecretAsync(string secretName)
{
await _semaphore.WaitAsync().ConfigureAwait(false);
try
@@ -41,11 +42,14 @@ private async Task> GetSecret(string secretName)
}
}
- public Task[]> WaitForAll()
+ public Response[] WaitForAll()
{
- return Task.WhenAll(_tasks);
+ Task.WaitAll(_tasks.Select(t => (Task)t).ToArray());
+ return _tasks.Select(t => t.Result).ToArray();
}
+ public Task[]> WaitForAllAsync() => Task.WhenAll(_tasks);
+
public void Dispose()
{
_semaphore?.Dispose();
diff --git a/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/tests/AzureKeyVaultConfigurationTests.cs b/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/tests/AzureKeyVaultConfigurationTests.cs
index 5946ee5686825..0496d174abbe3 100644
--- a/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/tests/AzureKeyVaultConfigurationTests.cs
+++ b/sdk/extensions/Azure.Extensions.AspNetCore.Configuration.Secrets/tests/AzureKeyVaultConfigurationTests.cs
@@ -32,6 +32,7 @@ private void SetPages(Mock mock, Func getSecretCallb
var pagesOfProperties = pages.Select(
page => page.Select(secret => secret.Properties).ToArray()).ToArray();
+ mock.Setup(m => m.GetPropertiesOfSecrets(default)).Returns(new MockPageable(pagesOfProperties));
mock.Setup(m => m.GetPropertiesOfSecretsAsync(default)).Returns(new MockAsyncPageable(pagesOfProperties));
foreach (var page in pages)
@@ -49,6 +50,24 @@ private void SetPages(Mock mock, Func getSecretCallb
}
}
+ private class MockPageable : Pageable
+ {
+ private readonly SecretProperties[][] _pages;
+
+ public MockPageable(SecretProperties[][] pages)
+ {
+ _pages = pages;
+ }
+
+ public override IEnumerable> AsPages(string continuationToken = null, int? pageSizeHint = null)
+ {
+ foreach (var page in _pages)
+ {
+ yield return Page.FromValues(page, null, Mock.Of());
+ }
+ }
+ }
+
private class MockAsyncPageable : AsyncPageable
{
private readonly SecretProperties[][] _pages;
@@ -353,7 +372,7 @@ public async Task SupportsReloadOnEnabledChange()
await provider.Wait();
SetPages(client,
- new[]
+ new[]
{
CreateSecret("Secret1", "Value2"),
CreateSecret("Secret2", "Value2", enabled: false)
diff --git a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Blobs/src/AzureBlobXmlRepository.cs b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Blobs/src/AzureBlobXmlRepository.cs
index 6a2ecb2bce959..c892daefefe15 100644
--- a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Blobs/src/AzureBlobXmlRepository.cs
+++ b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Blobs/src/AzureBlobXmlRepository.cs
@@ -50,51 +50,11 @@ public AzureBlobXmlRepository(BlobClient blobClient)
///
public IReadOnlyCollection GetAllElements()
{
- // Shunt the work onto a ThreadPool thread so that it's independent of any
- // existing sync context or other potentially deadlock-causing items.
-
- var elements = Task.Run(() => GetAllElementsAsync()).GetAwaiter().GetResult();
- return new ReadOnlyCollection(elements);
- }
-
- ///
- public void StoreElement(XElement element, string friendlyName)
- {
- if (element == null)
- {
- throw new ArgumentNullException(nameof(element));
- }
-
- // Shunt the work onto a ThreadPool thread so that it's independent of any
- // existing sync context or other potentially deadlock-causing items.
-
- Task.Run(() => StoreElementAsync(element)).GetAwaiter().GetResult();
- }
-
- private static XDocument CreateDocumentFromBlobData(BlobData blobData)
- {
- if (blobData == null || blobData.BlobContents.Length == 0)
- {
- return new XDocument(new XElement(RepositoryElementName));
- }
-
- using var memoryStream = new MemoryStream(blobData.BlobContents);
-
- var xmlReaderSettings = new XmlReaderSettings()
- {
- DtdProcessing = DtdProcessing.Prohibit,
- IgnoreProcessingInstructions = true,
- };
-
- using (var xmlReader = XmlReader.Create(memoryStream, xmlReaderSettings))
- {
- return XDocument.Load(xmlReader);
- }
- }
-
- private async Task> GetAllElementsAsync()
- {
- var data = await GetLatestDataAsync().ConfigureAwait(false);
+ // Fixes for #40174
+ // Original `Task.Run(() => GetAllElementsAsync()).GetAwaiter().GetResult();` blocks the thread in ThreadPool until task is completed,
+ // then runs first part of the task on another ThreadPool thread and schedules continuation to run in ThreadPool thread again.
+ // If too many calls of GetAllElements() happens before any continuation is executed, all threads in ThreadPool will become blocked
+ var data = GetLatestData();
// The document will look like this:
//
@@ -107,68 +67,17 @@ private async Task> GetAllElementsAsync()
// We want to return the first-level child elements to our caller.
var doc = CreateDocumentFromBlobData(data);
- return doc.Root.Elements().ToList();
+ return new ReadOnlyCollection(doc.Root.Elements().ToList());
}
- private async Task GetLatestDataAsync()
+ ///
+ public void StoreElement(XElement element, string friendlyName)
{
- // Set the appropriate AccessCondition based on what we believe the latest
- // file contents to be, then make the request.
-
- var latestCachedData = Volatile.Read(ref _cachedBlobData); // local ref so field isn't mutated under our feet
- var requestCondition = (latestCachedData != null)
- ? new BlobRequestConditions() { IfNoneMatch = latestCachedData.ETag }
- : null;
-
- try
- {
- using (var memoryStream = new MemoryStream())
- {
- var response = await _blobClient.DownloadToAsync(
- destination: memoryStream,
- conditions: requestCondition).ConfigureAwait(false);
-
- if (response.Status == 304)
- {
- // 304 Not Modified
- // Thrown when we already have the latest cached data.
- // This isn't an error; we'll return our cached copy of the data.
- return latestCachedData;
- }
-
- // At this point, our original cache either didn't exist or was outdated.
- // We'll update it now and return the updated value
- latestCachedData = new BlobData()
- {
- BlobContents = memoryStream.ToArray(),
- ETag = response.Headers.ETag
- };
- }
- Volatile.Write(ref _cachedBlobData, latestCachedData);
- }
- catch (RequestFailedException ex) when (ex.Status == 404)
+ if (element == null)
{
- // 404 Not Found
- // Thrown when no file exists in storage.
- // This isn't an error; we'll delete our cached copy of data.
-
- latestCachedData = null;
- Volatile.Write(ref _cachedBlobData, latestCachedData);
+ throw new ArgumentNullException(nameof(element));
}
- return latestCachedData;
- }
-
- private int GetRandomizedBackoffPeriod()
- {
- // returns a TimeSpan in the range [0.8, 1.0) * ConflictBackoffPeriod
- // not used for crypto purposes
- var multiplier = 0.8 + (_random.NextDouble() * 0.2);
- return (int) (multiplier * ConflictBackoffPeriod.Ticks);
- }
-
- private async Task StoreElementAsync(XElement element)
- {
// holds the last error in case we need to rethrow it
ExceptionDispatchInfo lastError = null;
@@ -178,14 +87,14 @@ private async Task StoreElementAsync(XElement element)
{
// If multiple conflicts occurred, wait a small period of time before retrying
// the operation so that other writers can make forward progress.
- await Task.Delay(GetRandomizedBackoffPeriod()).ConfigureAwait(false);
+ Thread.Sleep(GetRandomizedBackoffPeriod());
}
if (i > 0)
{
// If at least one conflict occurred, make sure we have an up-to-date
// view of the blob contents.
- await GetLatestDataAsync().ConfigureAwait(false);
+ GetLatestData();
}
// Merge the new element into the document. If no document exists,
@@ -217,10 +126,7 @@ private async Task StoreElementAsync(XElement element)
try
{
// Send the request up to the server.
- var response = await _blobClient.UploadAsync(
- serializedDoc,
- httpHeaders: _blobHttpHeaders,
- conditions: requestConditions).ConfigureAwait(false);
+ var response = _blobClient.Upload(serializedDoc, httpHeaders: _blobHttpHeaders, conditions: requestConditions);
// If we got this far, success!
// We can update the cached view of the remote contents.
@@ -255,6 +161,82 @@ private async Task StoreElementAsync(XElement element)
lastError.Throw();
}
+ private static XDocument CreateDocumentFromBlobData(BlobData blobData)
+ {
+ if (blobData == null || blobData.BlobContents.Length == 0)
+ {
+ return new XDocument(new XElement(RepositoryElementName));
+ }
+
+ using var memoryStream = new MemoryStream(blobData.BlobContents);
+
+ var xmlReaderSettings = new XmlReaderSettings()
+ {
+ DtdProcessing = DtdProcessing.Prohibit,
+ IgnoreProcessingInstructions = true,
+ };
+
+ using (var xmlReader = XmlReader.Create(memoryStream, xmlReaderSettings))
+ {
+ return XDocument.Load(xmlReader);
+ }
+ }
+
+ private BlobData GetLatestData()
+ {
+ // Set the appropriate AccessCondition based on what we believe the latest
+ // file contents to be, then make the request.
+
+ var latestCachedData = Volatile.Read(ref _cachedBlobData); // local ref so field isn't mutated under our feet
+ var requestCondition = (latestCachedData != null)
+ ? new BlobRequestConditions() { IfNoneMatch = latestCachedData.ETag }
+ : null;
+
+ try
+ {
+ using (var memoryStream = new MemoryStream())
+ {
+ var response = _blobClient.DownloadTo(destination: memoryStream, conditions: requestCondition);
+
+ if (response.Status == 304)
+ {
+ // 304 Not Modified
+ // Thrown when we already have the latest cached data.
+ // This isn't an error; we'll return our cached copy of the data.
+ return latestCachedData;
+ }
+
+ // At this point, our original cache either didn't exist or was outdated.
+ // We'll update it now and return the updated value
+ latestCachedData = new BlobData
+ {
+ BlobContents = memoryStream.ToArray(),
+ ETag = response.Headers.ETag
+ };
+ }
+ Volatile.Write(ref _cachedBlobData, latestCachedData);
+ }
+ catch (RequestFailedException ex) when (ex.Status == 404)
+ {
+ // 404 Not Found
+ // Thrown when no file exists in storage.
+ // This isn't an error; we'll delete our cached copy of data.
+
+ latestCachedData = null;
+ Volatile.Write(ref _cachedBlobData, latestCachedData);
+ }
+
+ return latestCachedData;
+ }
+
+ private int GetRandomizedBackoffPeriod()
+ {
+ // returns a TimeSpan in the range [0.8, 1.0) * ConflictBackoffPeriod
+ // not used for crypto purposes
+ var multiplier = 0.8 + (_random.NextDouble() * 0.2);
+ return (int) (multiplier * ConflictBackoffPeriod.Ticks);
+ }
+
private sealed class BlobData
{
internal byte[] BlobContents;
diff --git a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Blobs/tests/AzureBlobXmlRepositoryTests.cs b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Blobs/tests/AzureBlobXmlRepositoryTests.cs
index 9a8ae4d4be1d2..d8b4013373e7a 100644
--- a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Blobs/tests/AzureBlobXmlRepositoryTests.cs
+++ b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Blobs/tests/AzureBlobXmlRepositoryTests.cs
@@ -30,7 +30,7 @@ public void StoreCreatesBlobWhenNotExist()
var mock = new Mock();
- mock.Setup(c => c.UploadAsync(
+ mock.Setup(c => c.Upload(
It.IsAny(),
It.IsAny(),
It.IsAny>(),
@@ -39,7 +39,7 @@ public void StoreCreatesBlobWhenNotExist()
It.IsAny(),
It.IsAny(),
It.IsAny()))
- .Returns(async (Stream strm, BlobHttpHeaders headers, IDictionary metaData, BlobRequestConditions conditions, IProgress progress, AccessTier? access, StorageTransferOptions transfer, CancellationToken token) =>
+ .Returns((Stream strm, BlobHttpHeaders headers, IDictionary metaData, BlobRequestConditions conditions, IProgress progress, AccessTier? access, StorageTransferOptions transfer, CancellationToken token) =>
{
using var memoryStream = new MemoryStream();
strm.CopyTo(memoryStream);
@@ -47,8 +47,6 @@ public void StoreCreatesBlobWhenNotExist()
uploadConditions = conditions;
contentType = headers?.ContentType;
- await Task.Yield();
-
var mockResponse = new Mock>();
var blobContentInfo = BlobsModelFactory.BlobContentInfo(ETag.All, DateTimeOffset.Now.AddDays(-1), Array.Empty(), "", 1);
@@ -73,15 +71,15 @@ public void StoreUpdatesWhenExistsAndNewerExists()
var mock = new Mock();
- mock.Setup(c => c.DownloadToAsync(
+ mock.Setup(c => c.DownloadTo(
It.IsAny(),
It.IsAny(),
It.IsAny(),
It.IsAny()))
- .Returns(async (Stream target, BlobRequestConditions conditions, StorageTransferOptions options, CancellationToken token) =>
+ .Returns((Stream target, BlobRequestConditions conditions, StorageTransferOptions options, CancellationToken token) =>
{
var data = GetEnvelopedContent("");
- await target.WriteAsync(data, 0, data.Length);
+ target.Write(data, 0, data.Length);
var response = new MockResponse(200);
response.AddHeader(new HttpHeader("ETag", "*"));
@@ -89,7 +87,7 @@ public void StoreUpdatesWhenExistsAndNewerExists()
})
.Verifiable();
- mock.Setup(c => c.UploadAsync(
+ mock.Setup(c => c.Upload(
It.IsAny(),
It.IsAny(),
It.IsAny>(),
@@ -101,7 +99,7 @@ public void StoreUpdatesWhenExistsAndNewerExists()
.Throws(new RequestFailedException(status: 412, message: ""))
.Verifiable();
- mock.Setup(c => c.UploadAsync(
+ mock.Setup(c => c.Upload(
It.IsAny(),
It.IsAny(),
It.IsAny>(),
@@ -110,14 +108,12 @@ public void StoreUpdatesWhenExistsAndNewerExists()
It.IsAny(),
It.IsAny(),
It.IsAny()))
- .Returns(async (Stream strm, BlobHttpHeaders headers, IDictionary metaData, BlobRequestConditions conditions, IProgress progress, AccessTier? access, StorageTransferOptions transfer, CancellationToken token) =>
+ .Returns((Stream strm, BlobHttpHeaders headers, IDictionary metaData, BlobRequestConditions conditions, IProgress progress, AccessTier? access, StorageTransferOptions transfer, CancellationToken token) =>
{
using var memoryStream = new MemoryStream();
strm.CopyTo(memoryStream);
bytes = memoryStream.ToArray();
- await Task.Yield();
-
var mockResponse = new Mock>();
var blobContentInfo = BlobsModelFactory.BlobContentInfo(ETag.All, DateTimeOffset.Now.AddDays(-1), Array.Empty(), "", 1);
mockResponse.Setup(c => c.Value).Returns(blobContentInfo);
diff --git a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/src/AzureKeyVaultXmlDecryptor.cs b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/src/AzureKeyVaultXmlDecryptor.cs
index b3fe150878f5f..9bdd4ba213f8f 100644
--- a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/src/AzureKeyVaultXmlDecryptor.cs
+++ b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/src/AzureKeyVaultXmlDecryptor.cs
@@ -23,11 +23,6 @@ public AzureKeyVaultXmlDecryptor(IServiceProvider serviceProvider)
}
public XElement Decrypt(XElement encryptedElement)
- {
- return Task.Run(() => DecryptAsync(encryptedElement)).GetAwaiter().GetResult();
- }
-
- private async Task DecryptAsync(XElement encryptedElement)
{
var kid = (string)encryptedElement.Element("kid");
var symmetricKey = Convert.FromBase64String((string)encryptedElement.Element("key"));
@@ -35,8 +30,8 @@ private async Task DecryptAsync(XElement encryptedElement)
var encryptedValue = Convert.FromBase64String((string)encryptedElement.Element("value"));
- var key = await _client.ResolveAsync(kid).ConfigureAwait(false);
- var result = await key.UnwrapKeyAsync(AzureKeyVaultXmlEncryptor.DefaultKeyEncryption, symmetricKey).ConfigureAwait(false);
+ var key = _client.Resolve(kid);
+ var result = key.UnwrapKey(AzureKeyVaultXmlEncryptor.DefaultKeyEncryption, symmetricKey);
byte[] decryptedValue;
using (var symmetricAlgorithm = AzureKeyVaultXmlEncryptor.DefaultSymmetricAlgorithmFactory())
diff --git a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/src/AzureKeyVaultXmlEncryptor.cs b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/src/AzureKeyVaultXmlEncryptor.cs
index dd50987797713..400f4b0ce5f39 100644
--- a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/src/AzureKeyVaultXmlEncryptor.cs
+++ b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/src/AzureKeyVaultXmlEncryptor.cs
@@ -35,11 +35,6 @@ internal AzureKeyVaultXmlEncryptor(IKeyEncryptionKeyResolver client, string keyI
}
public EncryptedXmlInfo Encrypt(XElement plaintextElement)
- {
- return Task.Run(() => EncryptAsync(plaintextElement)).GetAwaiter().GetResult();
- }
-
- private async Task EncryptAsync(XElement plaintextElement)
{
byte[] value;
using (var memoryStream = new MemoryStream())
@@ -62,8 +57,8 @@ private async Task EncryptAsync(XElement plaintextElement)
encryptedValue = encryptor.TransformFinalBlock(value, 0, value.Length);
}
- var key = await _client.ResolveAsync(_keyId).ConfigureAwait(false);
- var wrappedKey = await key.WrapKeyAsync(DefaultKeyEncryption, symmetricKey).ConfigureAwait(false);
+ var key = _client.Resolve(_keyId);
+ var wrappedKey = key.WrapKey(DefaultKeyEncryption, symmetricKey);
var element = new XElement("encryptedKey",
new XComment(" This key is encrypted with Azure Key Vault. "),
diff --git a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/tests/AzureKeyVaultXmlEncryptorTests.cs b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/tests/AzureKeyVaultXmlEncryptorTests.cs
index f66a45e7337c1..4f711bae57e4e 100644
--- a/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/tests/AzureKeyVaultXmlEncryptorTests.cs
+++ b/sdk/extensions/Azure.Extensions.AspNetCore.DataProtection.Keys/tests/AzureKeyVaultXmlEncryptorTests.cs
@@ -19,15 +19,15 @@ public class AzureKeyVaultXmlEncryptorTests
public void UsesKeyVaultToEncryptKey()
{
var keyMock = new Mock(MockBehavior.Strict);
- keyMock.Setup(client => client.WrapKeyAsync("RSA-OAEP", It.IsAny>(), default))
- .ReturnsAsync((string _, ReadOnlyMemory data, CancellationToken __) => data.ToArray().Reverse().ToArray())
+ keyMock.Setup(client => client.WrapKey("RSA-OAEP", It.IsAny>(), default))
+ .Returns((string _, ReadOnlyMemory data, CancellationToken __) => data.ToArray().Reverse().ToArray())
.Verifiable();
keyMock.SetupGet(client => client.KeyId).Returns("KeyId");
var mock = new Mock();
- mock.Setup(client => client.ResolveAsync("key", default))
- .ReturnsAsync((string _, CancellationToken __) => keyMock.Object)
+ mock.Setup(client => client.Resolve("key", default))
+ .Returns((string _, CancellationToken __) => keyMock.Object)
.Verifiable();
var encryptor = new AzureKeyVaultXmlEncryptor(mock.Object, "key", new MockNumberGenerator());
@@ -50,13 +50,13 @@ public void UsesKeyVaultToEncryptKey()
public void UsesKeyVaultToDecryptKey()
{
var keyMock = new Mock(MockBehavior.Strict);
- keyMock.Setup(client => client.UnwrapKeyAsync("RSA-OAEP", It.IsAny>(), default))
- .ReturnsAsync((string _, ReadOnlyMemory data, CancellationToken __) => data.ToArray().Reverse().ToArray())
+ keyMock.Setup(client => client.UnwrapKey("RSA-OAEP", It.IsAny>(), default))
+ .Returns((string _, ReadOnlyMemory data, CancellationToken __) => data.ToArray().Reverse().ToArray())
.Verifiable();
var mock = new Mock();
- mock.Setup(client => client.ResolveAsync("KeyId", default))
- .ReturnsAsync((string _, CancellationToken __) => keyMock.Object)
+ mock.Setup(client => client.Resolve("KeyId", default))
+ .Returns((string _, CancellationToken __) => keyMock.Object)
.Verifiable();
var serviceCollection = new ServiceCollection();
diff --git a/sdk/extensions/Microsoft.Extensions.Azure/src/AzureClientFactoryBuilder.cs b/sdk/extensions/Microsoft.Extensions.Azure/src/AzureClientFactoryBuilder.cs
index 2aa538a91dca0..29edcbf2b7e40 100644
--- a/sdk/extensions/Microsoft.Extensions.Azure/src/AzureClientFactoryBuilder.cs
+++ b/sdk/extensions/Microsoft.Extensions.Azure/src/AzureClientFactoryBuilder.cs
@@ -75,7 +75,7 @@ public AzureClientFactoryBuilder ConfigureDefaults(ActionThis instance.
public AzureClientFactoryBuilder ConfigureDefaults(IConfiguration configuration)
{
- ConfigureDefaults(options => configuration.Bind(options));
+ ConfigureDefaults(configuration.Bind);
var credentialsFromConfig = ClientFactory.CreateCredential(configuration);
@@ -161,18 +161,14 @@ public IAzureClientBuilder AddClient(Func<
private IAzureClientBuilder RegisterClientFactory(Func clientFactory, bool requiresCredential) where TOptions : class
{
- var clientRegistration = new ClientRegistration(DefaultClientName, (provider, options, credential) => clientFactory((TOptions) options, credential, provider));
- clientRegistration.RequiresTokenCredential = requiresCredential;
-
+ var clientRegistration = new ClientRegistration(DefaultClientName, requiresCredential, (provider, options, credential) => clientFactory((TOptions) options, credential, provider));
_serviceCollection.AddSingleton(clientRegistration);
_serviceCollection.TryAddSingleton(typeof(IConfigureOptions>), typeof(DefaultCredentialClientOptionsSetup));
_serviceCollection.TryAddSingleton(typeof(IOptionsMonitor), typeof(ClientOptionsMonitor));
_serviceCollection.TryAddSingleton(typeof(ClientOptionsFactory), typeof(ClientOptionsFactory));
_serviceCollection.TryAddSingleton(typeof(IAzureClientFactory), typeof(AzureClientFactory));
- _serviceCollection.TryAddSingleton(
- typeof(TClient),
- provider => provider.GetService>().CreateClient(DefaultClientName));
+ _serviceCollection.TryAddSingleton(typeof(TClient), provider => provider.GetService>().CreateClient(DefaultClientName));
return new AzureClientBuilder(clientRegistration, _serviceCollection);
}
diff --git a/sdk/extensions/Microsoft.Extensions.Azure/src/Internal/ClientOptionsFactory.cs b/sdk/extensions/Microsoft.Extensions.Azure/src/Internal/ClientOptionsFactory.cs
index 1808866fe2ada..fe5f353ecda21 100644
--- a/sdk/extensions/Microsoft.Extensions.Azure/src/Internal/ClientOptionsFactory.cs
+++ b/sdk/extensions/Microsoft.Extensions.Azure/src/Internal/ClientOptionsFactory.cs
@@ -3,7 +3,6 @@
using System;
using System.Collections.Generic;
-using System.Reflection;
using Azure.Core;
using Microsoft.Extensions.Options;
diff --git a/sdk/extensions/Microsoft.Extensions.Azure/src/Internal/ClientRegistration.cs b/sdk/extensions/Microsoft.Extensions.Azure/src/Internal/ClientRegistration.cs
index 89fa8da1c2bad..1f3705f1f3281 100644
--- a/sdk/extensions/Microsoft.Extensions.Azure/src/Internal/ClientRegistration.cs
+++ b/sdk/extensions/Microsoft.Extensions.Azure/src/Internal/ClientRegistration.cs
@@ -12,46 +12,38 @@ internal class ClientRegistration : IDisposable, IAsyncDisposable
{
public string Name { get; set; }
public object Version { get; set; }
- public bool RequiresTokenCredential { get; set; }
+ private readonly bool _requiresCredential;
private readonly Func _factory;
- private readonly object _cacheLock = new object();
- private readonly bool _asyncDisposable;
- private readonly bool _disposable;
+ private readonly object _cacheLock = new();
private bool _clientInitialized;
private TClient _cachedClient;
private ExceptionDispatchInfo _cachedException;
- public ClientRegistration(string name, Func factory)
+ public ClientRegistration(string name, bool requiresCredential, Func factory)
{
Name = name;
+ _requiresCredential = requiresCredential;
_factory = factory;
-
- _asyncDisposable = typeof(IAsyncDisposable).IsAssignableFrom(typeof(TClient));
- _disposable = typeof(IDisposable).IsAssignableFrom(typeof(TClient));
}
public TClient GetClient(IServiceProvider serviceProvider, object options, TokenCredential tokenCredential)
{
- _cachedException?.Throw();
-
- if (_clientInitialized)
+ if (TryGetCachedClientOrThrow(out var cachedClient))
{
- return _cachedClient;
+ return cachedClient;
}
lock (_cacheLock)
{
- _cachedException?.Throw();
-
- if (_clientInitialized)
+ if (TryGetCachedClientOrThrow(out cachedClient))
{
- return _cachedClient;
+ return cachedClient;
}
- if (RequiresTokenCredential && tokenCredential == null)
+ if (_requiresCredential && tokenCredential == null)
{
throw new InvalidOperationException("Client registration requires a TokenCredential. Configure it using UseCredential method.");
}
@@ -71,63 +63,87 @@ public TClient GetClient(IServiceProvider serviceProvider, object options, Token
}
}
- public async ValueTask DisposeAsync()
+ ///
+ /// Client registration can be in one of 4 states:
+ /// - _clientInitialized is false, _cachedException is null: _factory has never been called, client hasn't been initialized yet
+ /// - _clientInitialized is false, _cachedException is not null: _factory has been called, but exception has happened. Client isn't been initialized.
+ /// - _clientInitialized is true, _cachedClient is not null: client has been initialized and is not disposed yet
+ /// - _clientInitialized is true, _cachedClient is null: client has been initialized and is disposed
+ ///
+ ///
+ /// True if client is initialized and not disposed, otherwise false.
+ private bool TryGetCachedClientOrThrow(out TClient cachedClient)
{
+ _cachedException?.Throw();
if (_clientInitialized)
{
- if (_asyncDisposable)
- {
- IAsyncDisposable disposableClient;
+ cachedClient = _cachedClient ?? throw new ObjectDisposedException(nameof(ClientRegistration));
+ return true;
+ }
- lock (_cacheLock)
- {
- if (!_clientInitialized)
- {
- return;
- }
+ cachedClient = default;
+ return false;
+ }
- disposableClient = (IAsyncDisposable)_cachedClient;
+ public async ValueTask DisposeAsync()
+ {
+ if (!_clientInitialized)
+ {
+ return;
+ }
- _cachedClient = default;
- _clientInitialized = false;
- }
+ if (_cachedClient is null)
+ {
+ return;
+ }
- await disposableClient.DisposeAsync().ConfigureAwait(false);
- }
- else if (_disposable)
- {
- Dispose();
- }
+ IDisposable disposableClient;
+ IAsyncDisposable asyncDisposableClient;
+ lock (_cacheLock)
+ {
+ disposableClient = _cachedClient as IDisposable;
+ asyncDisposableClient = _cachedClient as IAsyncDisposable;
+ _cachedClient = default;
+ }
+
+ if (asyncDisposableClient is not null)
+ {
+ await asyncDisposableClient.DisposeAsync().ConfigureAwait(false);
+ }
+ else if (disposableClient is not null)
+ {
+ disposableClient.Dispose();
}
}
public void Dispose()
{
- if (_clientInitialized)
+ if (!_clientInitialized)
{
- if (_disposable)
- {
- IDisposable disposableClient;
-
- lock (_cacheLock)
- {
- if (!_clientInitialized)
- {
- return;
- }
+ return;
+ }
- disposableClient = (IDisposable)_cachedClient;
+ if (_cachedClient is null)
+ {
+ return;
+ }
- _cachedClient = default;
- _clientInitialized = false;
- }
+ IDisposable disposableClient;
+ IAsyncDisposable asyncDisposableClient;
+ lock (_cacheLock)
+ {
+ disposableClient = _cachedClient as IDisposable;
+ asyncDisposableClient = _cachedClient as IAsyncDisposable;
+ _cachedClient = default;
+ }
- disposableClient.Dispose();
- }
- else if (_asyncDisposable)
- {
- DisposeAsync().GetAwaiter().GetResult();
- }
+ if (disposableClient is not null)
+ {
+ disposableClient.Dispose();
+ }
+ else if (asyncDisposableClient is not null)
+ {
+ asyncDisposableClient.DisposeAsync().GetAwaiter().GetResult();
}
}
}