Skip to content

Commit

Permalink
remove ctoken from UpdateAccessKeyAsync (#2116)
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan authored Dec 10, 2024
1 parent 6bfabbf commit 979f193
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ private async Task UpdateAllAccessKeyAsync()
{
foreach (var key in InitializedKeyList)
{
var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout);
_ = key.UpdateAccessKeyAsync(source.Token);
_ = key.UpdateAccessKeyAsync();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ internal class MicrosoftEntraAccessKey : IAccessKey

private static readonly TimeSpan GetAccessKeyInterval = TimeSpan.FromMinutes(55);

private static readonly TimeSpan GetAccessKeyIntervalWhenUnauthorized = TimeSpan.FromMinutes(5);
private static readonly TimeSpan GetAccessKeyIntervalUnavailable = TimeSpan.FromMinutes(5);

private static readonly TimeSpan AccessKeyExpireTime = TimeSpan.FromMinutes(120);

Expand All @@ -56,6 +56,8 @@ internal class MicrosoftEntraAccessKey : IAccessKey

public bool Initialized => _initializedTcs.Task.IsCompleted;

public bool NeedRefresh => DateTime.UtcNow - _updateAt > (Available ? GetAccessKeyInterval : GetAccessKeyIntervalUnavailable);

public bool Available
{
get => _isAuthorized && DateTime.UtcNow - _updateAt < AccessKeyExpireTime;
Expand Down Expand Up @@ -121,8 +123,7 @@ public async Task<string> GenerateAccessTokenAsync(string audience,
{
if (!_initializedTcs.Task.IsCompleted)
{
var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout);
_ = UpdateAccessKeyAsync(source.Token);
_ = UpdateAccessKeyAsync();
}

await _initializedTcs.Task.OrCancelAsync(ctoken, "The access key initialization timed out.");
Expand All @@ -139,14 +140,9 @@ internal void UpdateAccessKey(string kid, string keyStr)
Available = true;
}

internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)
internal async Task UpdateAccessKeyAsync()
{
var delta = DateTime.UtcNow - _updateAt;
if (Available && delta < GetAccessKeyInterval)
{
return;
}
else if (!Available && delta < GetAccessKeyIntervalWhenUnauthorized)
if (!NeedRefresh)
{
return;
}
Expand All @@ -158,15 +154,10 @@ internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)

for (var i = 0; i < GetAccessKeyMaxRetryTimes; i++)
{
if (ctoken.IsCancellationRequested)
{
break;
}

var source = new CancellationTokenSource(GetAccessKeyTimeout);
try
{
await UpdateAccessKeyInternalAsync(source.Token).OrCancelAsync(ctoken);
await UpdateAccessKeyInternalAsync(source.Token);
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
return;
}
Expand All @@ -179,7 +170,7 @@ internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)
LastException = e;
try
{
await Task.Delay(GetAccessKeyRetryInterval, ctoken); // retry after interval.
await Task.Delay(GetAccessKeyRetryInterval); // retry after interval.
}
catch (OperationCanceledException)
{
Expand Down
2 changes: 0 additions & 2 deletions src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ public static class Periods

public const int MaxCustomHandshakeTimeout = 30;

public static readonly TimeSpan DefaultUpdateAccessKeyTimeout = TimeSpan.FromMinutes(2);

public static readonly TimeSpan DefaultAccessTokenLifetime = TimeSpan.FromHours(1);

public static readonly TimeSpan DefaultScaleTimeout = TimeSpan.FromMinutes(5);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,26 @@
using System.Threading.Tasks;
using Azure.Core;
using Azure.Identity;
using Moq;
using Xunit;

namespace Microsoft.Azure.SignalR.Common.Tests.Auth;

#nullable enable

[Collection("Auth")]
public class MicrosoftEntraAccessKeyTests
{
private const string DefaultSigningKey = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";

private static readonly Uri DefaultEndpoint = new("http://localhost");

public enum TokenType
{
Local,

MicrosoftEntra,
}

[Theory]
[InlineData("https://a.bc", "https://a.bc/api/v1/auth/accessKey")]
[InlineData("https://a.bc:80", "https://a.bc:80/api/v1/auth/accessKey")]
Expand All @@ -36,12 +44,7 @@ public void TestExpectedGetAccessKeyUrl(string endpoint, string expectedGetAcces
[Fact]
public async Task TestUpdateAccessKey()
{
var mockCredential = new Mock<TokenCredential>();
mockCredential.Setup(credential => credential.GetTokenAsync(
It.IsAny<TokenRequestContext>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception"));
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object);
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential());

var audience = "http://localhost/chat";
var claims = Array.Empty<Claim>();
Expand All @@ -66,28 +69,23 @@ public async Task TestUpdateAccessKey()
[InlineData(true, 121, false, true)] // > 120, should set key unauthorized and log the exception
public async Task TestUpdateAccessKeyAsyncShouldSkip(bool isAuthorized, int timeElapsed, bool skip, bool hasException)
{
var mockCredential = new Mock<TokenCredential>();
mockCredential.Setup(credential => credential.GetTokenAsync(
It.IsAny<TokenRequestContext>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception"));
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object)
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential())
{
GetAccessKeyRetryInterval = TimeSpan.Zero
};
var isAuthorizedField = typeof(MicrosoftEntraAccessKey).GetField("_isAuthorized", BindingFlags.NonPublic | BindingFlags.Instance);
isAuthorizedField.SetValue(key, isAuthorized);
Assert.Equal(isAuthorized, (bool)isAuthorizedField.GetValue(key));
isAuthorizedField?.SetValue(key, isAuthorized);
Assert.Equal(isAuthorized, Assert.IsType<bool>(isAuthorizedField?.GetValue(key)));

var updateAt = DateTime.UtcNow - TimeSpan.FromMinutes(timeElapsed);
var updateAtField = typeof(MicrosoftEntraAccessKey).GetField("_updateAt", BindingFlags.NonPublic | BindingFlags.Instance);
updateAtField.SetValue(key, updateAt);
updateAtField?.SetValue(key, updateAt);

var initializedTcsField = typeof(MicrosoftEntraAccessKey).GetField("_initializedTcs", BindingFlags.NonPublic | BindingFlags.Instance);
var initializedTcs = (TaskCompletionSource<object>)initializedTcsField.GetValue(key);
var initializedTcs = Assert.IsType<TaskCompletionSource<object>>(initializedTcsField?.GetValue(key));

await key.UpdateAccessKeyAsync().OrTimeout(TimeSpan.FromSeconds(30));
var actualUpdateAt = Assert.IsType<DateTime>(updateAtField.GetValue(key));
var actualUpdateAt = Assert.IsType<DateTime>(updateAtField?.GetValue(key));

Assert.Equal(skip && isAuthorized, Assert.IsType<bool>(isAuthorizedField.GetValue(key)));

Expand Down Expand Up @@ -115,12 +113,7 @@ public async Task TestUpdateAccessKeyAsyncShouldSkip(bool isAuthorized, int time
[Fact]
public async Task TestInitializeFailed()
{
var mockCredential = new Mock<TokenCredential>();
mockCredential.Setup(credential => credential.GetTokenAsync(
It.IsAny<TokenRequestContext>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception"));
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object)
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential())
{
GetAccessKeyRetryInterval = TimeSpan.Zero
};
Expand All @@ -141,8 +134,7 @@ public async Task TestInitializeFailed()
[Fact]
public async Task TestNotInitialized()
{
var mockCredential = new Mock<TokenCredential>();
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object);
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential());

var source = new CancellationTokenSource(TimeSpan.FromSeconds(1));
var exception = await Assert.ThrowsAsync<TaskCanceledException>(
Expand All @@ -155,12 +147,7 @@ public async Task TestNotInitialized()
[ClassData(typeof(NotAuthorizedTestData))]
public async Task TestUpdateAccessKeyFailedThrowsNotAuthorizedException(AzureSignalRException e, string expectedErrorMessage)
{
var mockCredential = new Mock<TokenCredential>();
mockCredential.Setup(credential => credential.GetTokenAsync(
It.IsAny<TokenRequestContext>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(e);
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object)
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential() { Exception = e })
{
GetAccessKeyRetryInterval = TimeSpan.Zero,
};
Expand All @@ -177,7 +164,7 @@ public async Task TestUpdateAccessKeyFailedThrowsNotAuthorizedException(AzureSig
);
Assert.Same(exception.InnerException, e);
Assert.Same(exception.InnerException, key.LastException);
Assert.StartsWith($"TokenCredentialProxy is not available for signing client tokens", exception.Message);
Assert.StartsWith($"{nameof(TestTokenCredential)} is not available for signing client tokens", exception.Message);
Assert.Contains(expectedErrorMessage, exception.Message);

var (kid, accessKey) = ("foo", DefaultSigningKey);
Expand Down Expand Up @@ -232,12 +219,7 @@ public async Task TestLazyLoadAccessKey()
[Fact]
public async Task TestLazyLoadAccessKeyFailed()
{
var mockCredential = new Mock<TokenCredential>();
mockCredential.Setup(credential => credential.GetTokenAsync(
It.IsAny<TokenRequestContext>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new Exception());
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object)
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential())
{
GetAccessKeyRetryInterval = TimeSpan.FromSeconds(1),
};
Expand Down Expand Up @@ -355,7 +337,7 @@ private sealed class TestHttpClient(HttpResponseMessage message) : HttpClient
{
public override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
Assert.Equal("Bearer", request.Headers.Authorization.Scheme);
Assert.Equal("Bearer", request.Headers?.Authorization?.Scheme);
return Task.FromResult(message);
}
}
Expand All @@ -364,11 +346,14 @@ private sealed class TextHttpContent : HttpContent
{
private readonly string _content;

private TextHttpContent(string content) => _content = content;
private TextHttpContent(string content)
{
_content = content;
}

internal static HttpContent From(string content) => new TextHttpContent(content);

protected override Task SerializeToStreamAsync(Stream stream, TransportContext context)
protected override Task SerializeToStreamAsync(Stream stream, TransportContext? context)
{
return stream.WriteAsync(Encoding.UTF8.GetBytes(_content)).AsTask();
}
Expand All @@ -380,21 +365,22 @@ protected override bool TryComputeLength(out long length)
}
}

public enum TokenType
private sealed class TestTokenCredential(TokenType? tokenType = null) : TokenCredential
{
Local,
MicrosoftEntra,
}
public Exception? Exception { get; init; }

private sealed class TestTokenCredential(TokenType tokenType) : TokenCredential
{
public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken)
{
if (Exception != null)
{
throw Exception;
}

var issuer = tokenType switch
{
TokenType.Local => Constants.AsrsTokenIssuer,
TokenType.MicrosoftEntra => "microsoft.com",
_ => throw new NotImplementedException(),
_ => throw new InvalidOperationException(),
};
var key = new AccessKey(DefaultSigningKey);
var token = AuthUtility.GenerateJwtToken(key.KeyBytes, issuer: issuer);
Expand Down

0 comments on commit 979f193

Please sign in to comment.