diff --git a/src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs b/src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs index 315797f8d..24326b66a 100644 --- a/src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs +++ b/src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs @@ -61,14 +61,7 @@ internal static ServiceHubDispatcher PrepareAndGetDispatcher(IAppBuilder builder configuration.Resolver.Register(typeof(IServerNameProvider), () => serverNameProvider); } - var synchronizer = configuration.Resolver.Resolve(); - if (synchronizer == null) - { - synchronizer = new AccessKeySynchronizer(loggerFactory); - configuration.Resolver.Register(typeof(IAccessKeySynchronizer), () => synchronizer); - } - - var endpoint = new ServiceEndpointManager(synchronizer, options, loggerFactory); + var endpoint = new ServiceEndpointManager(options, loggerFactory); configuration.Resolver.Register(typeof(IServiceEndpointManager), () => endpoint); var requestIdProvider = configuration.Resolver.Resolve(); diff --git a/src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointManager.cs b/src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointManager.cs index 95ca88b14..f54155d18 100644 --- a/src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointManager.cs +++ b/src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointManager.cs @@ -9,16 +9,12 @@ internal class ServiceEndpointManager : ServiceEndpointManagerBase { private readonly ServiceOptions _options; - private readonly IAccessKeySynchronizer _synchronizer; - - public ServiceEndpointManager(IAccessKeySynchronizer synchronizer, - ServiceOptions options, + public ServiceEndpointManager(ServiceOptions options, ILoggerFactory loggerFactory) : base(options, loggerFactory?.CreateLogger()) { _options = options; - _synchronizer = synchronizer; } public override IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint endpoint) @@ -27,7 +23,6 @@ public override IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint end { return null; } - _synchronizer.AddServiceEndpoint(endpoint); return new ServiceEndpointProvider(endpoint, _options); } } diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/AccessKeySynchronizer.cs b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/AccessKeySynchronizer.cs deleted file mode 100644 index c1cfcce41..000000000 --- a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/AccessKeySynchronizer.cs +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; - -namespace Microsoft.Azure.SignalR; - -internal sealed class AccessKeySynchronizer : IAccessKeySynchronizer, IDisposable -{ - private readonly ConcurrentDictionary _keyMap = new(ReferenceEqualityComparer.Instance); - - private readonly ILogger _logger; - - private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.Zero, TimeSpan.FromMinutes(1)); - - internal IEnumerable InitializedKeyList => _keyMap.Where(x => x.Key.Initialized).Select(x => x.Key); - - public AccessKeySynchronizer(ILoggerFactory loggerFactory) : this(loggerFactory, true) - { - } - - /// - /// Test only. - /// - internal AccessKeySynchronizer(ILoggerFactory loggerFactory, bool start) - { - if (start) - { - _ = UpdateAllAccessKeyAsync(); - } - _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); - } - - public void AddServiceEndpoint(ServiceEndpoint endpoint) - { - if (endpoint.AccessKey is MicrosoftEntraAccessKey key) - { - _keyMap.TryAdd(key, true); - } - } - - public void Dispose() => _timer.Stop(); - - public void UpdateServiceEndpoints(IEnumerable endpoints) - { - _keyMap.Clear(); - foreach (var endpoint in endpoints) - { - AddServiceEndpoint(endpoint); - } - } - - /// - /// Test only - /// - /// - /// - internal bool ContainsKey(ServiceEndpoint e) => _keyMap.ContainsKey(e.AccessKey as MicrosoftEntraAccessKey); - - /// - /// Test only - /// - /// - internal int Count() => _keyMap.Count; - - private async Task UpdateAllAccessKeyAsync() - { - using (_timer) - { - _timer.Start(); - - while (await _timer) - { - foreach (var key in InitializedKeyList) - { - _ = key.UpdateAccessKeyAsync(); - } - } - } - } - - private sealed class ReferenceEqualityComparer : IEqualityComparer - { - internal static readonly ReferenceEqualityComparer Instance = new ReferenceEqualityComparer(); - - private ReferenceEqualityComparer() - { - } - - public bool Equals(MicrosoftEntraAccessKey x, MicrosoftEntraAccessKey y) - { - return ReferenceEquals(x, y); - } - - public int GetHashCode(MicrosoftEntraAccessKey obj) - { - return RuntimeHelpers.GetHashCode(obj); - } - } -} diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/IAccessKeySynchronizer.cs b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/IAccessKeySynchronizer.cs deleted file mode 100644 index 0bf4fc255..000000000 --- a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/IAccessKeySynchronizer.cs +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using System.Collections.Generic; - -namespace Microsoft.Azure.SignalR; - -internal interface IAccessKeySynchronizer -{ - public void AddServiceEndpoint(ServiceEndpoint endpoint); - - public void UpdateServiceEndpoints(IEnumerable endpoints); -} diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs index 37592c15b..385ae06ae 100644 --- a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs +++ b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs @@ -24,14 +24,14 @@ internal class MicrosoftEntraAccessKey : IAccessKey { internal static readonly TimeSpan GetAccessKeyTimeout = TimeSpan.FromSeconds(100); - private const int UpdateTaskIdle = 0; - - private const int UpdateTaskRunning = 1; - private const int GetAccessKeyMaxRetryTimes = 3; private const int GetMicrosoftEntraTokenMaxRetryTimes = 3; + private readonly object _lock = new object(); + + private volatile TaskCompletionSource _updateTaskSource; + private static readonly TokenRequestContext DefaultRequestContext = new TokenRequestContext(new string[] { Constants.AsrsDefaultScope }); private static readonly TimeSpan GetAccessKeyInterval = TimeSpan.FromMinutes(55); @@ -40,12 +40,8 @@ internal class MicrosoftEntraAccessKey : IAccessKey private static readonly TimeSpan AccessKeyExpireTime = TimeSpan.FromMinutes(120); - private readonly TaskCompletionSource _initializedTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); - private readonly IHttpClientFactory _httpClientFactory; - private volatile int _updateState = 0; - private volatile bool _isAuthorized = false; private DateTime _updateAt = DateTime.MinValue; @@ -54,8 +50,6 @@ internal class MicrosoftEntraAccessKey : IAccessKey private volatile byte[]? _keyBytes; - public bool Initialized => _initializedTcs.Task.IsCompleted; - public bool NeedRefresh => DateTime.UtcNow - _updateAt > (Available ? GetAccessKeyInterval : GetAccessKeyIntervalUnavailable); public bool Available @@ -66,11 +60,10 @@ private set { if (value) { - LastException = null; + LastException = new Exception("The access key has expired."); } _updateAt = DateTime.UtcNow; _isAuthorized = value; - _initializedTcs.TrySetResult(null); } } @@ -80,7 +73,7 @@ private set public byte[] KeyBytes => _keyBytes ?? throw new ArgumentNullException(nameof(KeyBytes)); - internal Exception? LastException { get; private set; } + internal Exception LastException { get; private set; } = new Exception("The access key has not been initialized."); internal string GetAccessKeyUrl { get; } @@ -95,6 +88,9 @@ public MicrosoftEntraAccessKey(Uri serverEndpoint, TokenCredential = credential; _httpClientFactory = httpClientFactory ?? HttpClientFactory.Instance; + + _updateTaskSource = new(TaskCreationOptions.RunContinuationsAsynchronously); + _updateTaskSource.TrySetResult(false); } public virtual async Task GetMicrosoftEntraTokenAsync(CancellationToken ctoken = default) @@ -121,13 +117,22 @@ public async Task GenerateAccessTokenAsync(string audience, AccessTokenAlgorithm algorithm, CancellationToken ctoken = default) { - if (!_initializedTcs.Task.IsCompleted) + var updateTask = Task.CompletedTask; + if (NeedRefresh) { - _ = UpdateAccessKeyAsync(); + updateTask = UpdateAccessKeyAsync(); } - await _initializedTcs.Task.OrCancelAsync(ctoken, "The access key initialization timed out."); - + if (!Available) + { + try + { + await updateTask.OrCancelAsync(ctoken); + } + catch (OperationCanceledException) + { + } + } return Available ? AuthUtility.GenerateAccessToken(KeyBytes, Kid, audience, claims, lifetime, algorithm) : throw new AzureSignalRAccessTokenNotAuthorizedException(TokenCredential, GetExceptionMessage(LastException), LastException); @@ -138,27 +143,41 @@ internal void UpdateAccessKey(string kid, string keyStr) _keyBytes = Encoding.UTF8.GetBytes(keyStr); _kid = kid; Available = true; - } - internal async Task UpdateAccessKeyAsync() - { - if (!NeedRefresh) + lock (_lock) { - return; + _updateTaskSource.TrySetResult(true); } + } - if (Interlocked.CompareExchange(ref _updateState, UpdateTaskRunning, UpdateTaskIdle) != UpdateTaskIdle) + internal async Task UpdateAccessKeyAsync() + { + TaskCompletionSource tcs; + lock (_lock) { - return; + if (!_updateTaskSource.Task.IsCompleted) + { + tcs = _updateTaskSource; + } + else + { + _updateTaskSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _ = UpdateAccessKeyInternalAsync(_updateTaskSource); + tcs = _updateTaskSource; + } } + await tcs.Task; + } + private async Task UpdateAccessKeyInternalAsync(TaskCompletionSource tcs) + { for (var i = 0; i < GetAccessKeyMaxRetryTimes; i++) { - var source = new CancellationTokenSource(GetAccessKeyTimeout); + using var source = new CancellationTokenSource(GetAccessKeyTimeout); try { await UpdateAccessKeyInternalAsync(source.Token); - Interlocked.Exchange(ref _updateState, UpdateTaskIdle); + tcs.TrySetResult(true); return; } catch (OperationCanceledException e) @@ -168,14 +187,7 @@ internal async Task UpdateAccessKeyAsync() catch (Exception e) { LastException = e; - try - { - await Task.Delay(GetAccessKeyRetryInterval); // retry after interval. - } - catch (OperationCanceledException) - { - break; - } + await Task.Delay(GetAccessKeyRetryInterval); // retry after interval. } } @@ -184,15 +196,15 @@ internal async Task UpdateAccessKeyAsync() // Update the status only when it becomes "not available" due to expiration to refresh updateAt. Available = false; } - Interlocked.Exchange(ref _updateState, UpdateTaskIdle); + tcs.TrySetResult(false); } - private static string GetExceptionMessage(Exception? exception) + private static string GetExceptionMessage(Exception exception) { return exception switch { AzureSignalRUnauthorizedException => AzureSignalRUnauthorizedException.ErrorMessageMicrosoftEntra, - _ => exception?.Message ?? "The access key has expired.", + _ => exception.Message, }; } diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/RestApiAccessTokenGenerator.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/RestApiAccessTokenGenerator.cs index 5726c6ce9..d8d60f96f 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/RestApiAccessTokenGenerator.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/RestApiAccessTokenGenerator.cs @@ -36,10 +36,7 @@ public Task Generate(string audience, TimeSpan? lifetime = null) { return key.GetMicrosoftEntraTokenAsync(); } - - return _accessKey.GenerateAccessTokenAsync(audience, - _claims, - lifetime ?? Constants.Periods.DefaultAccessTokenLifetime, - DefaultAlgorithm); + var time = lifetime ?? Constants.Periods.DefaultAccessTokenLifetime; + return _accessKey.GenerateAccessTokenAsync(audience, _claims, time, DefaultAlgorithm); } } diff --git a/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs b/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs index 2e4b699b7..48a4fdacc 100644 --- a/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs +++ b/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs @@ -105,7 +105,6 @@ private static ISignalRServerBuilder AddAzureSignalRCore(this ISignalRServerBuil .AddSingleton(typeof(AzureSignalRMarkerService)) .AddSingleton() .AddSingleton() - .AddSingleton() .AddSingleton(typeof(NegotiateHandler<>)); // If a custom router is added, do not add the default router diff --git a/src/Microsoft.Azure.SignalR/EndpointProvider/ServiceEndpointManager.cs b/src/Microsoft.Azure.SignalR/EndpointProvider/ServiceEndpointManager.cs index 543442791..ccd421c4d 100644 --- a/src/Microsoft.Azure.SignalR/EndpointProvider/ServiceEndpointManager.cs +++ b/src/Microsoft.Azure.SignalR/EndpointProvider/ServiceEndpointManager.cs @@ -19,16 +19,12 @@ internal class ServiceEndpointManager : ServiceEndpointManagerBase private readonly TimeSpan _scaleTimeout; - private readonly IAccessKeySynchronizer _synchronizer; - - public ServiceEndpointManager(IAccessKeySynchronizer synchronizer, - IOptionsMonitor optionsMonitor, + public ServiceEndpointManager(IOptionsMonitor optionsMonitor, ILoggerFactory loggerFactory) : base(optionsMonitor.CurrentValue, loggerFactory.CreateLogger()) { _options = optionsMonitor.CurrentValue; _logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory)); - _synchronizer = synchronizer; optionsMonitor.OnChange(OnChange); _scaleTimeout = _options.ServiceScaleTimeout; @@ -40,8 +36,6 @@ public override IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint end { return null; } - - _synchronizer.AddServiceEndpoint(endpoint); return new ServiceEndpointProvider(endpoint, _options); } @@ -53,7 +47,6 @@ private void OnChange(ServiceOptions options) private Task ReloadServiceEndpointsAsync(IEnumerable serviceEndpoints) { - _synchronizer.UpdateServiceEndpoints(serviceEndpoints); return ReloadServiceEndpointsAsync(serviceEndpoints, _scaleTimeout); } diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeySynchronizerFacts.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeySynchronizerFacts.cs deleted file mode 100644 index 51521c619..000000000 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeySynchronizerFacts.cs +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using Azure.Identity; -using Microsoft.Azure.SignalR.Tests.Common; -using Microsoft.Extensions.Logging.Abstractions; -using Xunit; - -namespace Microsoft.Azure.SignalR.Common.Tests.Auth; - -public class AccessKeySynchronizerFacts -{ - [Fact] - public void AddAndRemoveServiceEndpointsTest() - { - var synchronizer = GetInstanceForTest(); - - var credential = new DefaultAzureCredential(); - var endpoint1 = new TestServiceEndpoint(credential); - var endpoint2 = new TestServiceEndpoint(credential); - - Assert.Equal(0, synchronizer.Count()); - synchronizer.UpdateServiceEndpoints([endpoint1]); - Assert.Equal(1, synchronizer.Count()); - synchronizer.UpdateServiceEndpoints([endpoint1, endpoint2]); - Assert.Empty(synchronizer.InitializedKeyList); - - Assert.Equal(2, synchronizer.Count()); - Assert.True(synchronizer.ContainsKey(endpoint1)); - Assert.True(synchronizer.ContainsKey(endpoint2)); - - synchronizer.UpdateServiceEndpoints([endpoint2]); - Assert.Equal(1, synchronizer.Count()); - synchronizer.UpdateServiceEndpoints([]); - Assert.Equal(0, synchronizer.Count()); - Assert.Empty(synchronizer.InitializedKeyList); - } - - private static AccessKeySynchronizer GetInstanceForTest() - { - return new AccessKeySynchronizer(NullLoggerFactory.Instance, false); - } -} diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraAccessKeyTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraAccessKeyTests.cs index aede9744f..5f82e19aa 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraAccessKeyTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraAccessKeyTests.cs @@ -7,6 +7,7 @@ using System.Reflection; using System.Security.Claims; using System.Text; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Azure.Core; @@ -22,8 +23,14 @@ public class MicrosoftEntraAccessKeyTests { private const string DefaultSigningKey = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + private const string DefaultToken = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + + private const string DefaultAudience = "https://localhost"; + private static readonly Uri DefaultEndpoint = new("http://localhost"); + private static readonly FieldInfo? UpdateAtField = typeof(MicrosoftEntraAccessKey).GetField("_updateAt", BindingFlags.Instance | BindingFlags.NonPublic); + public enum TokenType { Local, @@ -46,103 +53,66 @@ public async Task TestUpdateAccessKey() { var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential()); - var audience = "http://localhost/chat"; - var claims = Array.Empty(); - var lifetime = TimeSpan.FromHours(1); - var algorithm = AccessTokenAlgorithm.HS256; - var (kid, accessKey) = ("foo", DefaultSigningKey); key.UpdateAccessKey(kid, accessKey); - var token = await key.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm); + var token = await key.GenerateAccessTokenAsync(DefaultAudience, [], TimeSpan.FromHours(1), AccessTokenAlgorithm.HS256); Assert.NotNull(token); + Assert.True(TokenUtilities.TryParseIssuer(token, out var issuer) && string.Equals(Constants.AsrsTokenIssuer, issuer)); } - [Theory] - [InlineData(false, 1, true, false)] - [InlineData(false, 4, true, false)] - [InlineData(false, 6, false, true)] // > 5, should try update when unauthorized - [InlineData(true, 1, true, false)] - [InlineData(true, 54, true, false)] - [InlineData(true, 56, true, true)] // > 55, should try update and log the exception - [InlineData(true, 119, true, true)] // > 55, should try update and log the exception - [InlineData(true, 121, false, true)] // > 120, should set key unauthorized and log the exception - public async Task TestUpdateAccessKeyAsyncShouldSkip(bool isAuthorized, int timeElapsed, bool skip, bool hasException) + [Fact] + public async Task TestInitializeFailed() { var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential()) { GetAccessKeyRetryInterval = TimeSpan.Zero }; - var isAuthorizedField = typeof(MicrosoftEntraAccessKey).GetField("_isAuthorized", BindingFlags.NonPublic | BindingFlags.Instance); - isAuthorizedField?.SetValue(key, isAuthorized); - Assert.Equal(isAuthorized, Assert.IsType(isAuthorizedField?.GetValue(key))); - var updateAt = DateTime.UtcNow - TimeSpan.FromMinutes(timeElapsed); - var updateAtField = typeof(MicrosoftEntraAccessKey).GetField("_updateAt", BindingFlags.NonPublic | BindingFlags.Instance); - updateAtField?.SetValue(key, updateAt); - - var initializedTcsField = typeof(MicrosoftEntraAccessKey).GetField("_initializedTcs", BindingFlags.NonPublic | BindingFlags.Instance); - var initializedTcs = Assert.IsType>(initializedTcsField?.GetValue(key)); - - await key.UpdateAccessKeyAsync().OrTimeout(TimeSpan.FromSeconds(30)); - var actualUpdateAt = Assert.IsType(updateAtField?.GetValue(key)); - - Assert.Equal(skip && isAuthorized, Assert.IsType(isAuthorizedField.GetValue(key))); - - if (skip) - { - Assert.Equal(updateAt, actualUpdateAt); - Assert.False(initializedTcs.Task.IsCompleted); - } - else - { - Assert.True(updateAt < actualUpdateAt); - Assert.True(initializedTcs.Task.IsCompleted); - } + await key.UpdateAccessKeyAsync(); - if (hasException) - { - Assert.NotNull(key.LastException); - } - else - { - Assert.Null(key.LastException); - } + var task = key.GenerateAccessTokenAsync(DefaultAudience, [], TimeSpan.FromHours(1), AccessTokenAlgorithm.HS256); + var exception = await Assert.ThrowsAsync(async () => await task); + Assert.IsType(exception.InnerException); } [Fact] - public async Task TestInitializeFailed() + public async Task TestNotInitailized() { - var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential()) - { - GetAccessKeyRetryInterval = TimeSpan.Zero - }; + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential(delay: 10000)); + Assert.False(key.Available); - var audience = "http://localhost/chat"; - var claims = Array.Empty(); - var lifetime = TimeSpan.FromHours(1); - var algorithm = AccessTokenAlgorithm.HS256; + var source = new CancellationTokenSource(TimeSpan.FromSeconds(1)); + var task1 = key.GenerateAccessTokenAsync(DefaultAudience, [], TimeSpan.FromHours(1), AccessTokenAlgorithm.HS256, source.Token); - await key.UpdateAccessKeyAsync(); + Assert.Equal(task1, await Task.WhenAny(task1, Task.Delay(5000))); - var exception = await Assert.ThrowsAsync( - async () => await key.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm) - ); - Assert.IsType(exception.InnerException); + var exception = await Assert.ThrowsAsync(async () => await task1); + Assert.Contains("is not available for signing client tokens", exception.Message); + Assert.Contains("has not been initialized.", exception.Message); } [Fact] - public async Task TestNotInitialized() + public async Task TestUnavailable() { - var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential()); + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential(delay: 10000)); + key.UpdateAccessKey("foo", "bar"); + Assert.True(key.Available); + + UpdateAtField?.SetValue(key, DateTime.UtcNow - TimeSpan.FromHours(3)); + Assert.False(key.Available); var source = new CancellationTokenSource(TimeSpan.FromSeconds(1)); - var exception = await Assert.ThrowsAsync( - async () => await key.GenerateAccessTokenAsync("", [], TimeSpan.FromSeconds(1), AccessTokenAlgorithm.HS256, source.Token) - ); - Assert.Contains("initialization timed out", exception.Message); + var task1 = key.GenerateAccessTokenAsync(DefaultAudience, [], TimeSpan.FromHours(1), AccessTokenAlgorithm.HS256, source.Token); + + Assert.Equal(task1, await Task.WhenAny(task1, Task.Delay(5000))); + + var exception = await Assert.ThrowsAsync(async () => await task1); + Assert.Contains("is not available for signing client tokens", exception.Message); + Assert.Contains("has expired.", exception.Message); } + [Theory] [ClassData(typeof(NotAuthorizedTestData))] public async Task TestUpdateAccessKeyFailedThrowsNotAuthorizedException(AzureSignalRException e, string expectedErrorMessage) @@ -152,16 +122,10 @@ public async Task TestUpdateAccessKeyFailedThrowsNotAuthorizedException(AzureSig GetAccessKeyRetryInterval = TimeSpan.Zero, }; - var audience = "http://localhost/chat"; - var claims = Array.Empty(); - var lifetime = TimeSpan.FromHours(1); - var algorithm = AccessTokenAlgorithm.HS256; - await key.UpdateAccessKeyAsync(); - var exception = await Assert.ThrowsAsync( - async () => await key.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm) - ); + var task = key.GenerateAccessTokenAsync(DefaultAudience, [], TimeSpan.FromHours(1), AccessTokenAlgorithm.HS256); + var exception = await Assert.ThrowsAsync(async () => await task); Assert.Same(exception.InnerException, e); Assert.Same(exception.InnerException, key.LastException); Assert.StartsWith($"{nameof(TestTokenCredential)} is not available for signing client tokens", exception.Message); @@ -169,7 +133,7 @@ public async Task TestUpdateAccessKeyFailedThrowsNotAuthorizedException(AzureSig var (kid, accessKey) = ("foo", DefaultSigningKey); key.UpdateAccessKey(kid, accessKey); - Assert.Null(key.LastException); + Assert.Contains("has expired", key.LastException.Message); } [Theory] @@ -197,9 +161,11 @@ public async Task TestUpdateAccessKeySendRequest(string expectedKeyStr) [Fact] public async Task TestLazyLoadAccessKey() { - var expectedKeyStr = DefaultSigningKey; - var expectedKid = "foo"; - var text = "{" + string.Format("\"AccessKey\": \"{0}\", \"KeyId\": \"{1}\"", expectedKeyStr, expectedKid) + "}"; + var text = JsonSerializer.Serialize(new AccessKeyResponse() + { + AccessKey = DefaultSigningKey, + KeyId = "foo" + }); var httpClientFactory = new TestHttpClientFactory(new HttpResponseMessage(HttpStatusCode.OK) { Content = TextHttpContent.From(text), @@ -208,12 +174,8 @@ public async Task TestLazyLoadAccessKey() var credential = new TestTokenCredential(TokenType.MicrosoftEntra); var key = new MicrosoftEntraAccessKey(DefaultEndpoint, credential, httpClientFactory: httpClientFactory); - Assert.False(key.Initialized); - var token = await key.GenerateAccessTokenAsync("https://localhost", [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256); Assert.NotNull(token); - - Assert.True(key.Initialized); } [Fact] @@ -224,15 +186,119 @@ public async Task TestLazyLoadAccessKeyFailed() GetAccessKeyRetryInterval = TimeSpan.FromSeconds(1), }; - Assert.False(key.Initialized); - - var task1 = key.GenerateAccessTokenAsync("https://localhost", [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256); + var task1 = key.GenerateAccessTokenAsync(DefaultAudience, [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256); var task2 = key.UpdateAccessKeyAsync(); - Assert.True(task2.IsCompleted); // another task is in progress. + Assert.False(task2.IsCompleted); await Assert.ThrowsAsync(async () => await task1); + await task2; + Assert.False(key.Available); + Assert.False(key.NeedRefresh); + } - Assert.True(key.Initialized); + [Fact] + public async Task TestRefreshAccessKey() + { + var text = JsonSerializer.Serialize(new AccessKeyResponse() + { + AccessKey = DefaultSigningKey, + KeyId = "foo" + }); + var httpClientFactory = new TestHttpClientFactory(new HttpResponseMessage(HttpStatusCode.OK) + { + Content = TextHttpContent.From(text), + }); + + var credential = new TestTokenCredential(TokenType.MicrosoftEntra); + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, credential, httpClientFactory: httpClientFactory); + Assert.False(key.Available); + Assert.True(key.NeedRefresh); + + var token = await key.GenerateAccessTokenAsync(DefaultAudience, [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256); + Assert.True(TokenUtilities.TryParseIssuer(token, out var issuer)); + Assert.Equal(Constants.AsrsTokenIssuer, issuer); + + Assert.True(key.Available); + Assert.False(key.NeedRefresh); + + UpdateAtField?.SetValue(key, DateTime.UtcNow - TimeSpan.FromMinutes(56)); + Assert.True(key.Available); + Assert.True(key.NeedRefresh); + + Assert.Equal(1, credential.Count); + var task1 = key.GenerateAccessTokenAsync(DefaultAudience, [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256); + var task2 = key.GenerateAccessTokenAsync(DefaultAudience, [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256); + await Task.WhenAll(task1, task2); + Assert.True(TokenUtilities.TryParseIssuer(await task1, out issuer)); + Assert.Equal(Constants.AsrsTokenIssuer, issuer); + + Assert.True(key.Available); + Assert.False(key.NeedRefresh); + Assert.Equal(2, credential.Count); + } + + [Fact] + public async Task TestRefreshAccessKeyUnauthorized() + { + var text = JsonSerializer.Serialize(new AccessKeyResponse() + { + AccessKey = DefaultSigningKey, + KeyId = "foo" + }); + var httpClientFactory = new TestHttpClientFactory(new HttpResponseMessage(HttpStatusCode.OK) + { + Content = TextHttpContent.From(text), + }); + + var credential = new TestTokenCredential(TokenType.MicrosoftEntra) { Exception = new InvalidOperationException() }; + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, credential, httpClientFactory: httpClientFactory) + { + GetAccessKeyRetryInterval = TimeSpan.FromSeconds(1) + }; + Assert.False(key.Available); + Assert.True(key.NeedRefresh); + + await Assert.ThrowsAsync(async () => await key.GenerateAccessTokenAsync(DefaultAudience, [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256)); + + Assert.False(key.Available); + Assert.False(key.NeedRefresh); + + Assert.Equal(9, credential.Count); // GetMicrosoftEntraTokenRetry * GetAccessKeyRetry = 3 * 3 + await Assert.ThrowsAsync(async () => await key.GenerateAccessTokenAsync(DefaultAudience, [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256)); + Assert.Equal(9, credential.Count); // Does not trigger refresh + + // refresh, but still failed. + UpdateAtField?.SetValue(key, DateTime.UtcNow - TimeSpan.FromMinutes(6)); + Assert.False(key.Available); + Assert.True(key.NeedRefresh); + + var task1 = key.GenerateAccessTokenAsync(DefaultAudience, [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256); + var task2 = key.GenerateAccessTokenAsync(DefaultAudience, [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256); + await Assert.ThrowsAsync(async () => await Task.WhenAll(task1, task2)); + + await Assert.ThrowsAsync(async () => await task1); + await Assert.ThrowsAsync(async () => await task2); + + Assert.False(key.Available); + Assert.False(key.NeedRefresh); + Assert.Equal(18, credential.Count); + + // refresh, succeed. + credential.Exception = null; + UpdateAtField?.SetValue(key, DateTime.UtcNow - TimeSpan.FromMinutes(6)); + Assert.False(key.Available); + Assert.True(key.NeedRefresh); + + task1 = key.GenerateAccessTokenAsync(DefaultAudience, [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256); + task2 = key.GenerateAccessTokenAsync(DefaultAudience, [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256); + await Task.WhenAll(task1, task2); + + Assert.True(TokenUtilities.TryParseIssuer(await task1, out var issuer)); + Assert.Equal(Constants.AsrsTokenIssuer, issuer); + + Assert.True(key.Available); + Assert.False(key.NeedRefresh); + Assert.Equal(19, credential.Count); } [Theory] @@ -365,12 +431,20 @@ protected override bool TryComputeLength(out long length) } } - private sealed class TestTokenCredential(TokenType? tokenType = null) : TokenCredential + private sealed class TestTokenCredential(TokenType? tokenType = null, int delay = 0) : TokenCredential { - public Exception? Exception { get; init; } + public Exception? Exception { get; set; } + + private volatile int _count; + + public Exception Error { get; set; } = new InvalidOperationException(); + + public int Count => _count; public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken) { + Interlocked.Increment(ref _count); + if (Exception != null) { throw Exception; @@ -382,14 +456,14 @@ public override AccessToken GetToken(TokenRequestContext requestContext, Cancell TokenType.MicrosoftEntra => "microsoft.com", _ => throw new InvalidOperationException(), }; - var key = new AccessKey(DefaultSigningKey); - var token = AuthUtility.GenerateJwtToken(key.KeyBytes, issuer: issuer); + var token = AuthUtility.GenerateJwtToken(Encoding.UTF8.GetBytes(DefaultSigningKey), issuer: issuer); return new AccessToken(token, DateTimeOffset.UtcNow.Add(TimeSpan.FromHours(1))); } - public override ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) + public override async ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) { - return ValueTask.FromResult(GetToken(requestContext, cancellationToken)); + await Task.Delay(delay, cancellationToken); + return GetToken(requestContext, cancellationToken); } } } diff --git a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestAccessKeySynchronizer.cs b/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestAccessKeySynchronizer.cs deleted file mode 100644 index 32d451185..000000000 --- a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestAccessKeySynchronizer.cs +++ /dev/null @@ -1,16 +0,0 @@ -using System.Collections.Generic; - -namespace Microsoft.Azure.SignalR.Tests.Common; - -internal class TestAccessKeySynchronizer : IAccessKeySynchronizer -{ - public static readonly IAccessKeySynchronizer Instance = new TestAccessKeySynchronizer(); - - public void UpdateServiceEndpoints(IEnumerable endpoints) - { - } - - public void AddServiceEndpoint(ServiceEndpoint endpoint) - { - } -} diff --git a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs index f5582283f..780576817 100644 --- a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs +++ b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs @@ -25,7 +25,6 @@ public DefaultClientInvocationManager() NullLogger.Instance); var loggerFactory = new NullLoggerFactory(); var serviceEndpointManager = new ServiceEndpointManager( - new AccessKeySynchronizer(loggerFactory), new TestOptionsMonitor(), loggerFactory ); diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs index b9f16c9f8..bb5075ba2 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs @@ -9,6 +9,7 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using Azure.Core; using Azure.Identity; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.SignalR; @@ -31,6 +32,10 @@ public class ServiceMessageTests : VerifiableLoggedTest private const string MicrosoftEntraConnectionString = "endpoint=https://localhost;authType=aad;"; + private static readonly Uri DefaultEndpoint = new("http://localhost"); + + private const string DefaultAudience = "https://localhost"; + private const string LocalConnectionString = "endpoint=https://localhost;accessKey=" + SigningKey; public ServiceMessageTests(ITestOutputHelper output) : base(output) @@ -205,10 +210,10 @@ public async Task TestAccessKeyRequestMessage(Type keyType) _ = connection.StartAsync(); await connection.ConnectionInitializedTask.OrTimeout(1000); - if (endpoint.AccessKey is TestAadAccessKey aadKey) + if (endpoint.AccessKey is MicrosoftEntraAccessKey) { var message = await connection.ExpectServiceMessage().OrTimeout(3000); - Assert.Equal(aadKey.Token, message.Token); + Assert.True(Guid.TryParse(message.Token, out _)); } else { @@ -221,6 +226,10 @@ public async Task TestAccessKeyRequestMessage(Type keyType) [InlineData(typeof(MicrosoftEntraAccessKey))] public async Task TestAccessKeyResponseMessage(Type keyType) { + var emptyClaims = Array.Empty(); + var lifetime = TimeSpan.FromHours(1); + var algorithm = AccessTokenAlgorithm.HS256; + var endpoint = MockServiceEndpoint(keyType.Name); Assert.IsAssignableFrom(keyType, endpoint.AccessKey); var hubServiceEndpoint = new HubServiceEndpoint("foo", null, endpoint); @@ -230,6 +239,14 @@ public async Task TestAccessKeyResponseMessage(Type keyType) _ = connection.StartAsync(); await connection.ConnectionInitializedTask.OrTimeout(1000); + switch (endpoint.AccessKey) + { + case MicrosoftEntraAccessKey key: + var source = new CancellationTokenSource(3000); + await Assert.ThrowsAsync(async () => await key.GenerateAccessTokenAsync(DefaultAudience, emptyClaims, lifetime, algorithm, source.Token)); + break; + } + var message = new AccessKeyResponseMessage() { Kid = "foo", @@ -237,13 +254,9 @@ public async Task TestAccessKeyResponseMessage(Type keyType) }; await connection.WriteFromServiceAsync(message); - var audience = "http://localhost/chat"; - var claims = Array.Empty(); - var lifetime = TimeSpan.FromHours(1); - var algorithm = AccessTokenAlgorithm.HS256; - - var clientToken = await endpoint.AccessKey.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm).OrTimeout(TimeSpan.FromSeconds(3)); - Assert.NotNull(clientToken); + var clientToken = await endpoint.AccessKey.GenerateAccessTokenAsync(DefaultAudience, emptyClaims, lifetime, algorithm).OrTimeout(3000); + Assert.True(TokenUtilities.TryParseIssuer(clientToken, out var issuer)); + Assert.Equal(Constants.AsrsTokenIssuer, issuer); await connection.StopAsync(); } @@ -344,33 +357,27 @@ private static TestServiceConnection CreateServiceConnection(ConnectionHandler h private ServiceEndpoint MockServiceEndpoint(string keyTypeName) { - switch (keyTypeName) + return keyTypeName switch { - case nameof(AccessKey): - return new ServiceEndpoint(LocalConnectionString); - - case nameof(MicrosoftEntraAccessKey): - var endpoint = new ServiceEndpoint(MicrosoftEntraConnectionString); - var field = typeof(ServiceEndpoint).GetField("_accessKey", BindingFlags.NonPublic | BindingFlags.Instance); - field.SetValue(endpoint, new TestAadAccessKey()); - return endpoint; - - default: - throw new NotImplementedException(); - } + nameof(AccessKey) => new ServiceEndpoint(LocalConnectionString), + nameof(MicrosoftEntraAccessKey) => new ServiceEndpoint(new Uri("http://localhost"), new TestTokenCredential()), + _ => throw new NotImplementedException(), + }; } - private class TestAadAccessKey : MicrosoftEntraAccessKey + private class TestTokenCredential : TokenCredential { public string Token { get; } = Guid.NewGuid().ToString(); - public TestAadAccessKey() : base(new Uri("http://localhost:80"), new DefaultAzureCredential()) + public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken) { + return new AccessToken(Token, DateTimeOffset.UtcNow.Add(TimeSpan.FromHours(1))); } - public override Task GetMicrosoftEntraTokenAsync(CancellationToken ctoken = default) + public override ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) { - return Task.FromResult(Token); + var token = GetToken(requestContext, cancellationToken); + return new ValueTask(token); } }