diff --git a/src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs b/src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs index 315797f8d..46cbbea1a 100644 --- a/src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs +++ b/src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs @@ -64,7 +64,7 @@ internal static ServiceHubDispatcher PrepareAndGetDispatcher(IAppBuilder builder var synchronizer = configuration.Resolver.Resolve(); if (synchronizer == null) { - synchronizer = new AccessKeySynchronizer(loggerFactory); + synchronizer = new AccessKeySynchronizer(); configuration.Resolver.Register(typeof(IAccessKeySynchronizer), () => synchronizer); } diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/AccessKeySynchronizer.cs b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/AccessKeySynchronizer.cs index 679d71b1b..c87a707ad 100644 --- a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/AccessKeySynchronizer.cs +++ b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/AccessKeySynchronizer.cs @@ -8,8 +8,6 @@ using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; namespace Microsoft.Azure.SignalR; @@ -17,26 +15,19 @@ internal sealed class AccessKeySynchronizer : IAccessKeySynchronizer, IDisposabl { private readonly ConcurrentDictionary _keyMap = new(ReferenceEqualityComparer.Instance); - private readonly ILogger _logger; - - private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.Zero, TimeSpan.FromMinutes(1)); + private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.FromMinutes(1), TimeSpan.FromMinutes(1)); internal IEnumerable InitializedKeyList => _keyMap.Where(x => x.Key.Initialized).Select(x => x.Key); - public AccessKeySynchronizer(ILoggerFactory loggerFactory) : this(loggerFactory, true) - { - } - /// - /// Test only. + /// Test only /// - internal AccessKeySynchronizer(ILoggerFactory loggerFactory, bool start) + /// + internal int Count => _keyMap.Count; + + public AccessKeySynchronizer() { - if (start) - { - _ = UpdateAllAccessKeyAsync(); - } - _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); + _ = UpdateAllAccessKeyTask(); } public void AddServiceEndpoint(ServiceEndpoint endpoint) @@ -65,13 +56,23 @@ public void UpdateServiceEndpoints(IEnumerable endpoints) /// internal bool ContainsKey(ServiceEndpoint e) => _keyMap.ContainsKey(e.AccessKey as MicrosoftEntraAccessKey); - /// - /// Test only - /// - /// - internal int Count() => _keyMap.Count; + internal void UpdateAllAccessKey() + { + foreach (var key in InitializedKeyList) + { + if (key.IsActive) + { + var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout); + _ = key.UpdateAccessKeyAsync(source.Token); + } + else + { + _keyMap.TryRemove(key, out _); + } + } + } - private async Task UpdateAllAccessKeyAsync() + private async Task UpdateAllAccessKeyTask() { using (_timer) { @@ -79,11 +80,7 @@ private async Task UpdateAllAccessKeyAsync() while (await _timer) { - foreach (var key in InitializedKeyList) - { - var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout); - _ = key.UpdateAccessKeyAsync(source.Token); - } + UpdateAllAccessKey(); } } } diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs index 76fe71afb..dcbc51684 100644 --- a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs +++ b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs @@ -50,15 +50,19 @@ internal class MicrosoftEntraAccessKey : IAccessKey private DateTime _updateAt = DateTime.MinValue; + private DateTime _lastUsedAt = DateTime.UtcNow; + private volatile string? _kid; private volatile byte[]? _keyBytes; public bool Initialized => _initializedTcs.Task.IsCompleted; + public bool IsActive => _lastUsedAt > DateTime.UtcNow - AccessKeyExpireTime; + public bool Available { - get => _isAuthorized && DateTime.UtcNow - _updateAt < AccessKeyExpireTime; + get => _isAuthorized && _updateAt > DateTime.UtcNow - AccessKeyExpireTime; private set { @@ -124,6 +128,8 @@ public async Task GenerateAccessTokenAsync(string audience, AccessTokenAlgorithm algorithm, CancellationToken ctoken = default) { + _lastUsedAt = DateTime.UtcNow; + if (!_initializedTcs.Task.IsCompleted) { var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout); @@ -155,6 +161,10 @@ internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default) { return; } + else if (!IsActive) + { + return; + } if (Interlocked.CompareExchange(ref _updateState, UpdateTaskRunning, UpdateTaskIdle) != UpdateTaskIdle) { diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeySynchronizerFacts.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeySynchronizerFacts.cs index 51521c619..3ec15b842 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeySynchronizerFacts.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeySynchronizerFacts.cs @@ -1,9 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; using Azure.Identity; using Microsoft.Azure.SignalR.Tests.Common; -using Microsoft.Extensions.Logging.Abstractions; using Xunit; namespace Microsoft.Azure.SignalR.Common.Tests.Auth; @@ -11,33 +14,59 @@ namespace Microsoft.Azure.SignalR.Common.Tests.Auth; public class AccessKeySynchronizerFacts { [Fact] - public void AddAndRemoveServiceEndpointsTest() + public async void AddRemoveServiceEndpointTest() { - var synchronizer = GetInstanceForTest(); + var synchronizer = new AccessKeySynchronizer(); + Assert.Equal(0, synchronizer.Count); + + var credential = new DefaultAzureCredential(); + var endpoint = new TestServiceEndpoint(credential); + synchronizer.AddServiceEndpoint(endpoint); + Assert.Equal(1, synchronizer.Count); + + var field = typeof(MicrosoftEntraAccessKey).GetField("_lastUsedAt", BindingFlags.NonPublic | BindingFlags.Instance); + + var key = Assert.IsType(endpoint.AccessKey); + var before = Assert.IsType(field.GetValue(key)); + + var source = new CancellationTokenSource(1000); + await Assert.ThrowsAsync(async () => await key.GenerateAccessTokenAsync("localhost", [], TimeSpan.FromHours(1), AccessTokenAlgorithm.HS256, source.Token)); + var after = Assert.IsType(field.GetValue(key)); + Assert.NotEqual(before, after); + + synchronizer.UpdateAllAccessKey(); + await Task.Delay(TimeSpan.FromSeconds(1)); + Assert.Equal(1, synchronizer.Count); + + key.UpdateAccessKey("foo", "bar"); + field.SetValue(key, DateTime.UtcNow - TimeSpan.FromHours(3)); + synchronizer.UpdateAllAccessKey(); + Assert.Equal(0, synchronizer.Count); + } + + [Fact] + public void HotReloadServiceEndpointTest() + { + var synchronizer = new AccessKeySynchronizer(); var credential = new DefaultAzureCredential(); var endpoint1 = new TestServiceEndpoint(credential); var endpoint2 = new TestServiceEndpoint(credential); - Assert.Equal(0, synchronizer.Count()); + Assert.Equal(0, synchronizer.Count); synchronizer.UpdateServiceEndpoints([endpoint1]); - Assert.Equal(1, synchronizer.Count()); + Assert.Equal(1, synchronizer.Count); synchronizer.UpdateServiceEndpoints([endpoint1, endpoint2]); Assert.Empty(synchronizer.InitializedKeyList); - Assert.Equal(2, synchronizer.Count()); + Assert.Equal(2, synchronizer.Count); Assert.True(synchronizer.ContainsKey(endpoint1)); Assert.True(synchronizer.ContainsKey(endpoint2)); synchronizer.UpdateServiceEndpoints([endpoint2]); - Assert.Equal(1, synchronizer.Count()); + Assert.Equal(1, synchronizer.Count); synchronizer.UpdateServiceEndpoints([]); - Assert.Equal(0, synchronizer.Count()); + Assert.Equal(0, synchronizer.Count); Assert.Empty(synchronizer.InitializedKeyList); } - - private static AccessKeySynchronizer GetInstanceForTest() - { - return new AccessKeySynchronizer(NullLoggerFactory.Instance, false); - } } diff --git a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs index f5582283f..621f9661f 100644 --- a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs +++ b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs @@ -25,7 +25,7 @@ public DefaultClientInvocationManager() NullLogger.Instance); var loggerFactory = new NullLoggerFactory(); var serviceEndpointManager = new ServiceEndpointManager( - new AccessKeySynchronizer(loggerFactory), + new AccessKeySynchronizer(), new TestOptionsMonitor(), loggerFactory );