diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/AccessKeySynchronizer.cs b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/AccessKeySynchronizer.cs index 679d71b1b..c1cfcce41 100644 --- a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/AccessKeySynchronizer.cs +++ b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/AccessKeySynchronizer.cs @@ -81,8 +81,7 @@ private async Task UpdateAllAccessKeyAsync() { foreach (var key in InitializedKeyList) { - var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout); - _ = key.UpdateAccessKeyAsync(source.Token); + _ = key.UpdateAccessKeyAsync(); } } } diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs index 21a56a03b..37592c15b 100644 --- a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs +++ b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs @@ -36,7 +36,7 @@ internal class MicrosoftEntraAccessKey : IAccessKey private static readonly TimeSpan GetAccessKeyInterval = TimeSpan.FromMinutes(55); - private static readonly TimeSpan GetAccessKeyIntervalWhenUnauthorized = TimeSpan.FromMinutes(5); + private static readonly TimeSpan GetAccessKeyIntervalUnavailable = TimeSpan.FromMinutes(5); private static readonly TimeSpan AccessKeyExpireTime = TimeSpan.FromMinutes(120); @@ -56,6 +56,8 @@ internal class MicrosoftEntraAccessKey : IAccessKey public bool Initialized => _initializedTcs.Task.IsCompleted; + public bool NeedRefresh => DateTime.UtcNow - _updateAt > (Available ? GetAccessKeyInterval : GetAccessKeyIntervalUnavailable); + public bool Available { get => _isAuthorized && DateTime.UtcNow - _updateAt < AccessKeyExpireTime; @@ -121,8 +123,7 @@ public async Task GenerateAccessTokenAsync(string audience, { if (!_initializedTcs.Task.IsCompleted) { - var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout); - _ = UpdateAccessKeyAsync(source.Token); + _ = UpdateAccessKeyAsync(); } await _initializedTcs.Task.OrCancelAsync(ctoken, "The access key initialization timed out."); @@ -139,14 +140,9 @@ internal void UpdateAccessKey(string kid, string keyStr) Available = true; } - internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default) + internal async Task UpdateAccessKeyAsync() { - var delta = DateTime.UtcNow - _updateAt; - if (Available && delta < GetAccessKeyInterval) - { - return; - } - else if (!Available && delta < GetAccessKeyIntervalWhenUnauthorized) + if (!NeedRefresh) { return; } @@ -158,15 +154,10 @@ internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default) for (var i = 0; i < GetAccessKeyMaxRetryTimes; i++) { - if (ctoken.IsCancellationRequested) - { - break; - } - var source = new CancellationTokenSource(GetAccessKeyTimeout); try { - await UpdateAccessKeyInternalAsync(source.Token).OrCancelAsync(ctoken); + await UpdateAccessKeyInternalAsync(source.Token); Interlocked.Exchange(ref _updateState, UpdateTaskIdle); return; } @@ -179,7 +170,7 @@ internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default) LastException = e; try { - await Task.Delay(GetAccessKeyRetryInterval, ctoken); // retry after interval. + await Task.Delay(GetAccessKeyRetryInterval); // retry after interval. } catch (OperationCanceledException) { diff --git a/src/Microsoft.Azure.SignalR.Common/Constants.cs b/src/Microsoft.Azure.SignalR.Common/Constants.cs index da4798daa..663c3fbb8 100644 --- a/src/Microsoft.Azure.SignalR.Common/Constants.cs +++ b/src/Microsoft.Azure.SignalR.Common/Constants.cs @@ -44,8 +44,6 @@ public static class Periods public const int MaxCustomHandshakeTimeout = 30; - public static readonly TimeSpan DefaultUpdateAccessKeyTimeout = TimeSpan.FromMinutes(2); - public static readonly TimeSpan DefaultAccessTokenLifetime = TimeSpan.FromHours(1); public static readonly TimeSpan DefaultScaleTimeout = TimeSpan.FromMinutes(5); diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraAccessKeyTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraAccessKeyTests.cs index 06a165077..aede9744f 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraAccessKeyTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraAccessKeyTests.cs @@ -11,11 +11,12 @@ using System.Threading.Tasks; using Azure.Core; using Azure.Identity; -using Moq; using Xunit; namespace Microsoft.Azure.SignalR.Common.Tests.Auth; +#nullable enable + [Collection("Auth")] public class MicrosoftEntraAccessKeyTests { @@ -23,6 +24,13 @@ public class MicrosoftEntraAccessKeyTests private static readonly Uri DefaultEndpoint = new("http://localhost"); + public enum TokenType + { + Local, + + MicrosoftEntra, + } + [Theory] [InlineData("https://a.bc", "https://a.bc/api/v1/auth/accessKey")] [InlineData("https://a.bc:80", "https://a.bc:80/api/v1/auth/accessKey")] @@ -36,12 +44,7 @@ public void TestExpectedGetAccessKeyUrl(string endpoint, string expectedGetAcces [Fact] public async Task TestUpdateAccessKey() { - var mockCredential = new Mock(); - mockCredential.Setup(credential => credential.GetTokenAsync( - It.IsAny(), - It.IsAny())) - .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); - var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object); + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential()); var audience = "http://localhost/chat"; var claims = Array.Empty(); @@ -66,28 +69,23 @@ public async Task TestUpdateAccessKey() [InlineData(true, 121, false, true)] // > 120, should set key unauthorized and log the exception public async Task TestUpdateAccessKeyAsyncShouldSkip(bool isAuthorized, int timeElapsed, bool skip, bool hasException) { - var mockCredential = new Mock(); - mockCredential.Setup(credential => credential.GetTokenAsync( - It.IsAny(), - It.IsAny())) - .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); - var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object) + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential()) { GetAccessKeyRetryInterval = TimeSpan.Zero }; var isAuthorizedField = typeof(MicrosoftEntraAccessKey).GetField("_isAuthorized", BindingFlags.NonPublic | BindingFlags.Instance); - isAuthorizedField.SetValue(key, isAuthorized); - Assert.Equal(isAuthorized, (bool)isAuthorizedField.GetValue(key)); + isAuthorizedField?.SetValue(key, isAuthorized); + Assert.Equal(isAuthorized, Assert.IsType(isAuthorizedField?.GetValue(key))); var updateAt = DateTime.UtcNow - TimeSpan.FromMinutes(timeElapsed); var updateAtField = typeof(MicrosoftEntraAccessKey).GetField("_updateAt", BindingFlags.NonPublic | BindingFlags.Instance); - updateAtField.SetValue(key, updateAt); + updateAtField?.SetValue(key, updateAt); var initializedTcsField = typeof(MicrosoftEntraAccessKey).GetField("_initializedTcs", BindingFlags.NonPublic | BindingFlags.Instance); - var initializedTcs = (TaskCompletionSource)initializedTcsField.GetValue(key); + var initializedTcs = Assert.IsType>(initializedTcsField?.GetValue(key)); await key.UpdateAccessKeyAsync().OrTimeout(TimeSpan.FromSeconds(30)); - var actualUpdateAt = Assert.IsType(updateAtField.GetValue(key)); + var actualUpdateAt = Assert.IsType(updateAtField?.GetValue(key)); Assert.Equal(skip && isAuthorized, Assert.IsType(isAuthorizedField.GetValue(key))); @@ -115,12 +113,7 @@ public async Task TestUpdateAccessKeyAsyncShouldSkip(bool isAuthorized, int time [Fact] public async Task TestInitializeFailed() { - var mockCredential = new Mock(); - mockCredential.Setup(credential => credential.GetTokenAsync( - It.IsAny(), - It.IsAny())) - .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); - var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object) + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential()) { GetAccessKeyRetryInterval = TimeSpan.Zero }; @@ -141,8 +134,7 @@ public async Task TestInitializeFailed() [Fact] public async Task TestNotInitialized() { - var mockCredential = new Mock(); - var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object); + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential()); var source = new CancellationTokenSource(TimeSpan.FromSeconds(1)); var exception = await Assert.ThrowsAsync( @@ -155,12 +147,7 @@ public async Task TestNotInitialized() [ClassData(typeof(NotAuthorizedTestData))] public async Task TestUpdateAccessKeyFailedThrowsNotAuthorizedException(AzureSignalRException e, string expectedErrorMessage) { - var mockCredential = new Mock(); - mockCredential.Setup(credential => credential.GetTokenAsync( - It.IsAny(), - It.IsAny())) - .ThrowsAsync(e); - var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object) + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential() { Exception = e }) { GetAccessKeyRetryInterval = TimeSpan.Zero, }; @@ -177,7 +164,7 @@ public async Task TestUpdateAccessKeyFailedThrowsNotAuthorizedException(AzureSig ); Assert.Same(exception.InnerException, e); Assert.Same(exception.InnerException, key.LastException); - Assert.StartsWith($"TokenCredentialProxy is not available for signing client tokens", exception.Message); + Assert.StartsWith($"{nameof(TestTokenCredential)} is not available for signing client tokens", exception.Message); Assert.Contains(expectedErrorMessage, exception.Message); var (kid, accessKey) = ("foo", DefaultSigningKey); @@ -232,12 +219,7 @@ public async Task TestLazyLoadAccessKey() [Fact] public async Task TestLazyLoadAccessKeyFailed() { - var mockCredential = new Mock(); - mockCredential.Setup(credential => credential.GetTokenAsync( - It.IsAny(), - It.IsAny())) - .ThrowsAsync(new Exception()); - var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object) + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential()) { GetAccessKeyRetryInterval = TimeSpan.FromSeconds(1), }; @@ -355,7 +337,7 @@ private sealed class TestHttpClient(HttpResponseMessage message) : HttpClient { public override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { - Assert.Equal("Bearer", request.Headers.Authorization.Scheme); + Assert.Equal("Bearer", request.Headers?.Authorization?.Scheme); return Task.FromResult(message); } } @@ -364,11 +346,14 @@ private sealed class TextHttpContent : HttpContent { private readonly string _content; - private TextHttpContent(string content) => _content = content; + private TextHttpContent(string content) + { + _content = content; + } internal static HttpContent From(string content) => new TextHttpContent(content); - protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) + protected override Task SerializeToStreamAsync(Stream stream, TransportContext? context) { return stream.WriteAsync(Encoding.UTF8.GetBytes(_content)).AsTask(); } @@ -380,21 +365,22 @@ protected override bool TryComputeLength(out long length) } } - public enum TokenType + private sealed class TestTokenCredential(TokenType? tokenType = null) : TokenCredential { - Local, - MicrosoftEntra, - } + public Exception? Exception { get; init; } - private sealed class TestTokenCredential(TokenType tokenType) : TokenCredential - { public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken) { + if (Exception != null) + { + throw Exception; + } + var issuer = tokenType switch { TokenType.Local => Constants.AsrsTokenIssuer, TokenType.MicrosoftEntra => "microsoft.com", - _ => throw new NotImplementedException(), + _ => throw new InvalidOperationException(), }; var key = new AccessKey(DefaultSigningKey); var token = AuthUtility.GenerateJwtToken(key.KeyBytes, issuer: issuer);