Skip to content

Commit

Permalink
Fix Azure#40174: Azure.Extensions.AspNetCore.DataProtection.Keys Heav…
Browse files Browse the repository at this point in the history
…y load causes blocked threads (Azure#40914)

* Fix 40174: Azure.Extensions.AspNetCore.DataProtection.Keys Heavy load causes blocked threads

* Address PR comment
  • Loading branch information
AlexanderSher authored Jan 24, 2024
1 parent 4db7862 commit 2945277
Show file tree
Hide file tree
Showing 11 changed files with 281 additions and 258 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,68 @@ public AzureKeyVaultConfigurationProvider(SecretClient client, AzureKeyVaultConf
/// <summary>
/// Load secrets into this provider.
/// </summary>
public override void Load() => LoadAsync().GetAwaiter().GetResult();
public override void Load()
{
var secretPages = _client.GetPropertiesOfSecrets();

using var secretLoader = new ParallelSecretLoader(_client);
var newLoadedSecrets = new Dictionary<string, KeyVaultSecret>();
var oldLoadedSecrets = Interlocked.Exchange(ref _loadedSecrets, null);

foreach (var secret in secretPages)
{
AddSecretToLoader(secret, oldLoadedSecrets, newLoadedSecrets, secretLoader);
}

var loadedSecret = secretLoader.WaitForAll();
UpdateSecrets(loadedSecret, newLoadedSecrets, oldLoadedSecrets);

// schedule a polling task only if none exists and a valid delay is specified
if (_pollingTask == null && _reloadInterval != null)
{
_pollingTask = PollForSecretChangesAsync();
}
}

private void AddSecretToLoader(SecretProperties secret, Dictionary<string, KeyVaultSecret> oldLoadedSecrets, Dictionary<string, KeyVaultSecret> newLoadedSecrets, ParallelSecretLoader secretLoader)
{
if (!_manager.Load(secret) || secret.Enabled != true)
{
return;
}

var secretId = secret.Name;
if (oldLoadedSecrets != null && oldLoadedSecrets.TryGetValue(secretId, out var existingSecret) && IsUpToDate(existingSecret, secret))
{
oldLoadedSecrets.Remove(secretId);
newLoadedSecrets.Add(secretId, existingSecret);
}
else
{
secretLoader.AddSecretToLoad(secret.Name);
}
}

private void UpdateSecrets(Response<KeyVaultSecret>[] loadedSecret, Dictionary<string, KeyVaultSecret> newLoadedSecrets, Dictionary<string, KeyVaultSecret> oldLoadedSecrets)
{
foreach (var secretBundle in loadedSecret)
{
newLoadedSecrets.Add(secretBundle.Value.Name, secretBundle);
}

_loadedSecrets = newLoadedSecrets;

// Reload is needed if we are loading secrets that were not loaded before or
// secret that was loaded previously is not available anymore
if (loadedSecret.Any() || oldLoadedSecrets?.Any() == true)
{
Data = _manager.GetData(newLoadedSecrets.Values);
if (oldLoadedSecrets != null)
{
OnReload();
}
}
}

private async Task PollForSecretChangesAsync()
{
Expand All @@ -79,57 +140,17 @@ internal virtual Task WaitForReload()

private async Task LoadAsync()
{
var secretPages = _client.GetPropertiesOfSecretsAsync();

using var secretLoader = new ParallelSecretLoader(_client);
var newLoadedSecrets = new Dictionary<string, KeyVaultSecret>();
var oldLoadedSecrets = Interlocked.Exchange(ref _loadedSecrets, null);

await foreach (var secret in secretPages.ConfigureAwait(false))
{
if (!_manager.Load(secret) || secret.Enabled != true)
{
continue;
}

var secretId = secret.Name;
if (oldLoadedSecrets != null &&
oldLoadedSecrets.TryGetValue(secretId, out var existingSecret) &&
IsUpToDate(existingSecret, secret))
{
oldLoadedSecrets.Remove(secretId);
newLoadedSecrets.Add(secretId, existingSecret);
}
else
{
secretLoader.Add(secret.Name);
}
}

var loadedSecret = await secretLoader.WaitForAll().ConfigureAwait(false);
foreach (var secretBundle in loadedSecret)
{
newLoadedSecrets.Add(secretBundle.Value.Name, secretBundle);
}

_loadedSecrets = newLoadedSecrets;

// Reload is needed if we are loading secrets that were not loaded before or
// secret that was loaded previously is not available anymore
if (loadedSecret.Any() || oldLoadedSecrets?.Any() == true)
await foreach (var secret in _client.GetPropertiesOfSecretsAsync().ConfigureAwait(false))
{
Data = _manager.GetData(newLoadedSecrets.Values);
if (oldLoadedSecrets != null)
{
OnReload();
}
AddSecretToLoader(secret, oldLoadedSecrets, newLoadedSecrets, secretLoader);
}

// schedule a polling task only if none exists and a valid delay is specified
if (_pollingTask == null && _reloadInterval != null)
{
_pollingTask = PollForSecretChangesAsync();
}
var loadedSecret = await secretLoader.WaitForAllAsync().ConfigureAwait(false);
UpdateSecrets(loadedSecret, newLoadedSecrets, oldLoadedSecrets);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Azure.Security.KeyVault.Secrets;
Expand All @@ -23,12 +24,12 @@ public ParallelSecretLoader(SecretClient client)
_tasks = new List<Task<Response<KeyVaultSecret>>>();
}

public void Add(string secretName)
public void AddSecretToLoad(string secretName)
{
_tasks.Add(GetSecret(secretName));
_tasks.Add(Task.Run(() => GetSecretAsync(secretName)));
}

private async Task<Response<KeyVaultSecret>> GetSecret(string secretName)
private async Task<Response<KeyVaultSecret>> GetSecretAsync(string secretName)
{
await _semaphore.WaitAsync().ConfigureAwait(false);
try
Expand All @@ -41,11 +42,14 @@ private async Task<Response<KeyVaultSecret>> GetSecret(string secretName)
}
}

public Task<Response<KeyVaultSecret>[]> WaitForAll()
public Response<KeyVaultSecret>[] WaitForAll()
{
return Task.WhenAll(_tasks);
Task.WaitAll(_tasks.Select(t => (Task)t).ToArray());
return _tasks.Select(t => t.Result).ToArray();
}

public Task<Response<KeyVaultSecret>[]> WaitForAllAsync() => Task.WhenAll(_tasks);

public void Dispose()
{
_semaphore?.Dispose();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ private void SetPages(Mock<SecretClient> mock, Func<string, Task> getSecretCallb
var pagesOfProperties = pages.Select(
page => page.Select(secret => secret.Properties).ToArray()).ToArray();

mock.Setup(m => m.GetPropertiesOfSecrets(default)).Returns(new MockPageable(pagesOfProperties));
mock.Setup(m => m.GetPropertiesOfSecretsAsync(default)).Returns(new MockAsyncPageable(pagesOfProperties));

foreach (var page in pages)
Expand All @@ -49,6 +50,24 @@ private void SetPages(Mock<SecretClient> mock, Func<string, Task> getSecretCallb
}
}

private class MockPageable : Pageable<SecretProperties>
{
private readonly SecretProperties[][] _pages;

public MockPageable(SecretProperties[][] pages)
{
_pages = pages;
}

public override IEnumerable<Page<SecretProperties>> AsPages(string continuationToken = null, int? pageSizeHint = null)
{
foreach (var page in _pages)
{
yield return Page<SecretProperties>.FromValues(page, null, Mock.Of<Response>());
}
}
}

private class MockAsyncPageable : AsyncPageable<SecretProperties>
{
private readonly SecretProperties[][] _pages;
Expand Down Expand Up @@ -353,7 +372,7 @@ public async Task SupportsReloadOnEnabledChange()
await provider.Wait();

SetPages(client,
new[]
new[]
{
CreateSecret("Secret1", "Value2"),
CreateSecret("Secret2", "Value2", enabled: false)
Expand Down
Loading

0 comments on commit 2945277

Please sign in to comment.