Skip to content

Commit

Permalink
Refresh AccessKey passively
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan committed Dec 10, 2024
1 parent 979f193 commit da60b72
Show file tree
Hide file tree
Showing 13 changed files with 257 additions and 368 deletions.
9 changes: 1 addition & 8 deletions src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,7 @@ internal static ServiceHubDispatcher PrepareAndGetDispatcher(IAppBuilder builder
configuration.Resolver.Register(typeof(IServerNameProvider), () => serverNameProvider);
}

var synchronizer = configuration.Resolver.Resolve<IAccessKeySynchronizer>();
if (synchronizer == null)
{
synchronizer = new AccessKeySynchronizer(loggerFactory);
configuration.Resolver.Register(typeof(IAccessKeySynchronizer), () => synchronizer);
}

var endpoint = new ServiceEndpointManager(synchronizer, options, loggerFactory);
var endpoint = new ServiceEndpointManager(options, loggerFactory);
configuration.Resolver.Register(typeof(IServiceEndpointManager), () => endpoint);

var requestIdProvider = configuration.Resolver.Resolve<IConnectionRequestIdProvider>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,12 @@ internal class ServiceEndpointManager : ServiceEndpointManagerBase
{
private readonly ServiceOptions _options;

private readonly IAccessKeySynchronizer _synchronizer;

public ServiceEndpointManager(IAccessKeySynchronizer synchronizer,
ServiceOptions options,
public ServiceEndpointManager(ServiceOptions options,
ILoggerFactory loggerFactory) :
base(options,
loggerFactory?.CreateLogger<ServiceEndpointManager>())
{
_options = options;
_synchronizer = synchronizer;
}

public override IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint endpoint)
Expand All @@ -27,7 +23,6 @@ public override IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint end
{
return null;
}
_synchronizer.AddServiceEndpoint(endpoint);
return new ServiceEndpointProvider(endpoint, _options);
}
}

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ internal class MicrosoftEntraAccessKey : IAccessKey
{
internal static readonly TimeSpan GetAccessKeyTimeout = TimeSpan.FromSeconds(100);

private const int UpdateTaskIdle = 0;

private const int UpdateTaskRunning = 1;

private const int GetAccessKeyMaxRetryTimes = 3;

private const int GetMicrosoftEntraTokenMaxRetryTimes = 3;

private readonly object _lock = new object();

private volatile TaskCompletionSource<bool> _updateTaskSource;

private static readonly TokenRequestContext DefaultRequestContext = new TokenRequestContext(new string[] { Constants.AsrsDefaultScope });

private static readonly TimeSpan GetAccessKeyInterval = TimeSpan.FromMinutes(55);
Expand All @@ -40,12 +40,8 @@ internal class MicrosoftEntraAccessKey : IAccessKey

private static readonly TimeSpan AccessKeyExpireTime = TimeSpan.FromMinutes(120);

private readonly TaskCompletionSource<object?> _initializedTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);

private readonly IHttpClientFactory _httpClientFactory;

private volatile int _updateState = 0;

private volatile bool _isAuthorized = false;

private DateTime _updateAt = DateTime.MinValue;
Expand All @@ -54,8 +50,6 @@ internal class MicrosoftEntraAccessKey : IAccessKey

private volatile byte[]? _keyBytes;

public bool Initialized => _initializedTcs.Task.IsCompleted;

public bool NeedRefresh => DateTime.UtcNow - _updateAt > (Available ? GetAccessKeyInterval : GetAccessKeyIntervalUnavailable);

public bool Available
Expand All @@ -66,11 +60,10 @@ private set
{
if (value)
{
LastException = null;
LastException = new Exception("The access key has expired.");
}
_updateAt = DateTime.UtcNow;
_isAuthorized = value;
_initializedTcs.TrySetResult(null);
}
}

Expand All @@ -80,7 +73,7 @@ private set

public byte[] KeyBytes => _keyBytes ?? throw new ArgumentNullException(nameof(KeyBytes));

internal Exception? LastException { get; private set; }
internal Exception LastException { get; private set; } = new Exception("The access key has not been initialized.");

internal string GetAccessKeyUrl { get; }

Expand All @@ -95,6 +88,9 @@ public MicrosoftEntraAccessKey(Uri serverEndpoint,
TokenCredential = credential;

_httpClientFactory = httpClientFactory ?? HttpClientFactory.Instance;

_updateTaskSource = new(TaskCreationOptions.RunContinuationsAsynchronously);
_updateTaskSource.TrySetResult(false);
}

public virtual async Task<string> GetMicrosoftEntraTokenAsync(CancellationToken ctoken = default)
Expand All @@ -121,13 +117,22 @@ public async Task<string> GenerateAccessTokenAsync(string audience,
AccessTokenAlgorithm algorithm,
CancellationToken ctoken = default)
{
if (!_initializedTcs.Task.IsCompleted)
var updateTask = Task.CompletedTask;
if (NeedRefresh)
{
_ = UpdateAccessKeyAsync();
updateTask = UpdateAccessKeyAsync();
}

await _initializedTcs.Task.OrCancelAsync(ctoken, "The access key initialization timed out.");

if (!Available)
{
try
{
await updateTask.OrCancelAsync(ctoken);
}
catch (OperationCanceledException)
{
}
}
return Available
? AuthUtility.GenerateAccessToken(KeyBytes, Kid, audience, claims, lifetime, algorithm)
: throw new AzureSignalRAccessTokenNotAuthorizedException(TokenCredential, GetExceptionMessage(LastException), LastException);
Expand All @@ -138,27 +143,41 @@ internal void UpdateAccessKey(string kid, string keyStr)
_keyBytes = Encoding.UTF8.GetBytes(keyStr);
_kid = kid;
Available = true;
}

internal async Task UpdateAccessKeyAsync()
{
if (!NeedRefresh)
lock (_lock)
{
return;
_updateTaskSource.TrySetResult(true);
}
}

if (Interlocked.CompareExchange(ref _updateState, UpdateTaskRunning, UpdateTaskIdle) != UpdateTaskIdle)
internal async Task UpdateAccessKeyAsync()
{
TaskCompletionSource<bool> tcs;
lock (_lock)
{
return;
if (!_updateTaskSource.Task.IsCompleted)
{
tcs = _updateTaskSource;
}
else
{
_updateTaskSource = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
_ = UpdateAccessKeyInternalAsync(_updateTaskSource);
tcs = _updateTaskSource;
}
}
await tcs.Task;
}

private async Task UpdateAccessKeyInternalAsync(TaskCompletionSource<bool> tcs)
{
for (var i = 0; i < GetAccessKeyMaxRetryTimes; i++)
{
var source = new CancellationTokenSource(GetAccessKeyTimeout);
using var source = new CancellationTokenSource(GetAccessKeyTimeout);
try
{
await UpdateAccessKeyInternalAsync(source.Token);
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
tcs.TrySetResult(true);
return;
}
catch (OperationCanceledException e)
Expand All @@ -168,14 +187,7 @@ internal async Task UpdateAccessKeyAsync()
catch (Exception e)
{
LastException = e;
try
{
await Task.Delay(GetAccessKeyRetryInterval); // retry after interval.
}
catch (OperationCanceledException)
{
break;
}
await Task.Delay(GetAccessKeyRetryInterval); // retry after interval.
}
}

Expand All @@ -184,15 +196,15 @@ internal async Task UpdateAccessKeyAsync()
// Update the status only when it becomes "not available" due to expiration to refresh updateAt.
Available = false;
}
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
tcs.TrySetResult(false);
}

private static string GetExceptionMessage(Exception? exception)
private static string GetExceptionMessage(Exception exception)
{
return exception switch
{
AzureSignalRUnauthorizedException => AzureSignalRUnauthorizedException.ErrorMessageMicrosoftEntra,
_ => exception?.Message ?? "The access key has expired.",
_ => exception.Message,
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ public Task<string> Generate(string audience, TimeSpan? lifetime = null)
{
return key.GetMicrosoftEntraTokenAsync();
}

return _accessKey.GenerateAccessTokenAsync(audience,
_claims,
lifetime ?? Constants.Periods.DefaultAccessTokenLifetime,
DefaultAlgorithm);
var time = lifetime ?? Constants.Periods.DefaultAccessTokenLifetime;
return _accessKey.GenerateAccessTokenAsync(audience, _claims, time, DefaultAlgorithm);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ private static ISignalRServerBuilder AddAzureSignalRCore(this ISignalRServerBuil
.AddSingleton(typeof(AzureSignalRMarkerService))
.AddSingleton<IClientConnectionFactory, ClientConnectionFactory>()
.AddSingleton<IHostedService, HeartBeat>()
.AddSingleton<IAccessKeySynchronizer, AccessKeySynchronizer>()
.AddSingleton(typeof(NegotiateHandler<>));

// If a custom router is added, do not add the default router
Expand Down
Loading

0 comments on commit da60b72

Please sign in to comment.