Skip to content

Commit

Permalink
remove inactive keys from sychronizer
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan committed Dec 4, 2024
1 parent 64b9805 commit 32116fa
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 43 deletions.
2 changes: 1 addition & 1 deletion src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ internal static ServiceHubDispatcher PrepareAndGetDispatcher(IAppBuilder builder
var synchronizer = configuration.Resolver.Resolve<IAccessKeySynchronizer>();
if (synchronizer == null)
{
synchronizer = new AccessKeySynchronizer(loggerFactory);
synchronizer = new AccessKeySynchronizer();
configuration.Resolver.Register(typeof(IAccessKeySynchronizer), () => synchronizer);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,26 @@
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

namespace Microsoft.Azure.SignalR;

internal sealed class AccessKeySynchronizer : IAccessKeySynchronizer, IDisposable
{
private readonly ConcurrentDictionary<MicrosoftEntraAccessKey, bool> _keyMap = new(ReferenceEqualityComparer.Instance);

private readonly ILogger<AccessKeySynchronizer> _logger;

private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.Zero, TimeSpan.FromMinutes(1));
private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.FromMinutes(1), TimeSpan.FromMinutes(1));

internal IEnumerable<MicrosoftEntraAccessKey> InitializedKeyList => _keyMap.Where(x => x.Key.Initialized).Select(x => x.Key);

public AccessKeySynchronizer(ILoggerFactory loggerFactory) : this(loggerFactory, true)
{
}

/// <summary>
/// Test only.
/// Test only
/// </summary>
internal AccessKeySynchronizer(ILoggerFactory loggerFactory, bool start)
/// <returns></returns>
internal int Count => _keyMap.Count;

public AccessKeySynchronizer()
{
if (start)
{
_ = UpdateAllAccessKeyAsync();
}
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<AccessKeySynchronizer>();
_ = UpdateAllAccessKeyTask();
}

public void AddServiceEndpoint(ServiceEndpoint endpoint)
Expand Down Expand Up @@ -65,25 +56,31 @@ public void UpdateServiceEndpoints(IEnumerable<ServiceEndpoint> endpoints)
/// <returns></returns>
internal bool ContainsKey(ServiceEndpoint e) => _keyMap.ContainsKey(e.AccessKey as MicrosoftEntraAccessKey);

/// <summary>
/// Test only
/// </summary>
/// <returns></returns>
internal int Count() => _keyMap.Count;
internal void UpdateAllAccessKey()
{
foreach (var key in InitializedKeyList)
{
if (key.IsActive)
{
var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout);
_ = key.UpdateAccessKeyAsync(source.Token);
}
else
{
_keyMap.TryRemove(key, out _);
}
}
}

private async Task UpdateAllAccessKeyAsync()
private async Task UpdateAllAccessKeyTask()
{
using (_timer)
{
_timer.Start();

while (await _timer)
{
foreach (var key in InitializedKeyList)
{
var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout);
_ = key.UpdateAccessKeyAsync(source.Token);
}
UpdateAllAccessKey();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,19 @@ internal class MicrosoftEntraAccessKey : IAccessKey

private DateTime _updateAt = DateTime.MinValue;

private DateTime _lastUsedAt = DateTime.UtcNow;

private volatile string? _kid;

private volatile byte[]? _keyBytes;

public bool Initialized => _initializedTcs.Task.IsCompleted;

public bool IsActive => _lastUsedAt > DateTime.UtcNow - AccessKeyExpireTime;

public bool Available
{
get => _isAuthorized && DateTime.UtcNow - _updateAt < AccessKeyExpireTime;
get => _isAuthorized && _updateAt > DateTime.UtcNow - AccessKeyExpireTime;

private set
{
Expand Down Expand Up @@ -124,6 +128,8 @@ public async Task<string> GenerateAccessTokenAsync(string audience,
AccessTokenAlgorithm algorithm,
CancellationToken ctoken = default)
{
_lastUsedAt = DateTime.UtcNow;

if (!_initializedTcs.Task.IsCompleted)
{
var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout);
Expand Down Expand Up @@ -155,6 +161,10 @@ internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)
{
return;
}
else if (!IsActive)
{
return;
}

if (Interlocked.CompareExchange(ref _updateState, UpdateTaskRunning, UpdateTaskIdle) != UpdateTaskIdle)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,43 +1,72 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Azure.Identity;
using Microsoft.Azure.SignalR.Tests.Common;
using Microsoft.Extensions.Logging.Abstractions;
using Xunit;

namespace Microsoft.Azure.SignalR.Common.Tests.Auth;

public class AccessKeySynchronizerFacts
{
[Fact]
public void AddAndRemoveServiceEndpointsTest()
public async void AddRemoveServiceEndpointTest()
{
var synchronizer = GetInstanceForTest();
var synchronizer = new AccessKeySynchronizer();
Assert.Equal(0, synchronizer.Count);

var credential = new DefaultAzureCredential();
var endpoint = new TestServiceEndpoint(credential);
synchronizer.AddServiceEndpoint(endpoint);
Assert.Equal(1, synchronizer.Count);

var field = typeof(MicrosoftEntraAccessKey).GetField("_lastUsedAt", BindingFlags.NonPublic | BindingFlags.Instance);

var key = Assert.IsType<MicrosoftEntraAccessKey>(endpoint.AccessKey);
var before = Assert.IsType<DateTime>(field.GetValue(key));

var source = new CancellationTokenSource(1000);
await Assert.ThrowsAsync<TaskCanceledException>(async () => await key.GenerateAccessTokenAsync("localhost", [], TimeSpan.FromHours(1), AccessTokenAlgorithm.HS256, source.Token));
var after = Assert.IsType<DateTime>(field.GetValue(key));
Assert.NotEqual(before, after);

synchronizer.UpdateAllAccessKey();
await Task.Delay(TimeSpan.FromSeconds(1));
Assert.Equal(1, synchronizer.Count);

key.UpdateAccessKey("foo", "bar");
field.SetValue(key, DateTime.UtcNow - TimeSpan.FromHours(3));
synchronizer.UpdateAllAccessKey();
Assert.Equal(0, synchronizer.Count);
}

[Fact]
public void HotReloadServiceEndpointTest()
{
var synchronizer = new AccessKeySynchronizer();

var credential = new DefaultAzureCredential();
var endpoint1 = new TestServiceEndpoint(credential);
var endpoint2 = new TestServiceEndpoint(credential);

Assert.Equal(0, synchronizer.Count());
Assert.Equal(0, synchronizer.Count);
synchronizer.UpdateServiceEndpoints([endpoint1]);
Assert.Equal(1, synchronizer.Count());
Assert.Equal(1, synchronizer.Count);
synchronizer.UpdateServiceEndpoints([endpoint1, endpoint2]);
Assert.Empty(synchronizer.InitializedKeyList);

Assert.Equal(2, synchronizer.Count());
Assert.Equal(2, synchronizer.Count);
Assert.True(synchronizer.ContainsKey(endpoint1));
Assert.True(synchronizer.ContainsKey(endpoint2));

synchronizer.UpdateServiceEndpoints([endpoint2]);
Assert.Equal(1, synchronizer.Count());
Assert.Equal(1, synchronizer.Count);
synchronizer.UpdateServiceEndpoints([]);
Assert.Equal(0, synchronizer.Count());
Assert.Equal(0, synchronizer.Count);
Assert.Empty(synchronizer.InitializedKeyList);
}

private static AccessKeySynchronizer GetInstanceForTest()
{
return new AccessKeySynchronizer(NullLoggerFactory.Instance, false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public DefaultClientInvocationManager()
NullLogger<DefaultHubProtocolResolver>.Instance);
var loggerFactory = new NullLoggerFactory();
var serviceEndpointManager = new ServiceEndpointManager(
new AccessKeySynchronizer(loggerFactory),
new AccessKeySynchronizer(),
new TestOptionsMonitor(),
loggerFactory
);
Expand Down

0 comments on commit 32116fa

Please sign in to comment.