Skip to content

Commit

Permalink
Refresh AccessKey passively
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan committed Dec 6, 2024
1 parent a298c6b commit 3df417d
Show file tree
Hide file tree
Showing 14 changed files with 191 additions and 334 deletions.
9 changes: 1 addition & 8 deletions src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,7 @@ internal static ServiceHubDispatcher PrepareAndGetDispatcher(IAppBuilder builder
configuration.Resolver.Register(typeof(IServerNameProvider), () => serverNameProvider);
}

var synchronizer = configuration.Resolver.Resolve<IAccessKeySynchronizer>();
if (synchronizer == null)
{
synchronizer = new AccessKeySynchronizer(loggerFactory);
configuration.Resolver.Register(typeof(IAccessKeySynchronizer), () => synchronizer);
}

var endpoint = new ServiceEndpointManager(synchronizer, options, loggerFactory);
var endpoint = new ServiceEndpointManager(options, loggerFactory);
configuration.Resolver.Register(typeof(IServiceEndpointManager), () => endpoint);

var requestIdProvider = configuration.Resolver.Resolve<IConnectionRequestIdProvider>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,12 @@ internal class ServiceEndpointManager : ServiceEndpointManagerBase
{
private readonly ServiceOptions _options;

private readonly IAccessKeySynchronizer _synchronizer;

public ServiceEndpointManager(IAccessKeySynchronizer synchronizer,
ServiceOptions options,
public ServiceEndpointManager(ServiceOptions options,
ILoggerFactory loggerFactory) :
base(options,
loggerFactory?.CreateLogger<ServiceEndpointManager>())
{
_options = options;
_synchronizer = synchronizer;
}

public override IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint endpoint)
Expand All @@ -27,7 +23,6 @@ public override IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint end
{
return null;
}
_synchronizer.AddServiceEndpoint(endpoint);
return new ServiceEndpointProvider(endpoint, _options);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ public ServiceEndpointProvider(ServiceEndpoint endpoint, ServiceOptions options)
public Task<string> GenerateClientAccessTokenAsync(string hubName = null, IEnumerable<Claim> claims = null, TimeSpan? lifetime = null)
{
var audience = $"{_audienceBaseUrl}{ClientPath}";

return _accessKey.GenerateAccessTokenAsync(audience, claims, lifetime ?? _accessTokenLifetime, _algorithm);
}

Expand Down
11 changes: 5 additions & 6 deletions src/Microsoft.Azure.SignalR.Common/Auth/LocalTokenProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@ internal class LocalTokenProvider : IAccessTokenProvider

private readonly IEnumerable<Claim> _claims;

public LocalTokenProvider(
AccessKey accessKey,
string audience,
IEnumerable<Claim> claims,
AccessTokenAlgorithm algorithm = AccessTokenAlgorithm.HS256,
TimeSpan? tokenLifetime = null)
public LocalTokenProvider(AccessKey accessKey,
string audience,
IEnumerable<Claim> claims,
AccessTokenAlgorithm algorithm = AccessTokenAlgorithm.HS256,
TimeSpan? tokenLifetime = null)
{
_accessKey = accessKey ?? throw new ArgumentNullException(nameof(accessKey));
_algorithm = algorithm;
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ internal class MicrosoftEntraAccessKey : IAccessKey

private static readonly TimeSpan GetAccessKeyInterval = TimeSpan.FromMinutes(55);

private static readonly TimeSpan GetAccessKeyIntervalWhenUnauthorized = TimeSpan.FromMinutes(5);
private static readonly TimeSpan GetAccessKeyIntervalUnavailable = TimeSpan.FromMinutes(5);

private static readonly TimeSpan AccessKeyExpireTime = TimeSpan.FromMinutes(120);

Expand All @@ -56,6 +56,8 @@ internal class MicrosoftEntraAccessKey : IAccessKey

public bool Initialized => _initializedTcs.Task.IsCompleted;

public bool NeedRefresh => DateTime.UtcNow - _updateAt > (Available ? GetAccessKeyInterval : GetAccessKeyIntervalUnavailable);

public bool Available
{
get => _isAuthorized && DateTime.UtcNow - _updateAt < AccessKeyExpireTime;
Expand Down Expand Up @@ -119,17 +121,31 @@ public async Task<string> GenerateAccessTokenAsync(string audience,
AccessTokenAlgorithm algorithm,
CancellationToken ctoken = default)
{
if (!_initializedTcs.Task.IsCompleted)
if (!_initializedTcs.Task.IsCompleted || NeedRefresh)
{
var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout);
_ = UpdateAccessKeyAsync(source.Token);
}

await _initializedTcs.Task.OrCancelAsync(ctoken, "The access key initialization timed out.");

return Available
? AuthUtility.GenerateAccessToken(KeyBytes, Kid, audience, claims, lifetime, algorithm)
: throw new AzureSignalRAccessTokenNotAuthorizedException(TokenCredential, GetExceptionMessage(LastException), LastException);
if (Available)
{
return AuthUtility.GenerateAccessToken(KeyBytes, Kid, audience, claims, lifetime, algorithm);
}
else
{
while (true)
{
if (_updateState == UpdateTaskIdle)
{
return Available
? AuthUtility.GenerateAccessToken(KeyBytes, Kid, audience, claims, lifetime, algorithm)
: throw new AzureSignalRAccessTokenNotAuthorizedException(TokenCredential, GetExceptionMessage(LastException), LastException);
}
await Task.Delay(100, ctoken);
}
}
}

internal void UpdateAccessKey(string kid, string keyStr)
Expand All @@ -141,16 +157,6 @@ internal void UpdateAccessKey(string kid, string keyStr)

internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)
{
var delta = DateTime.UtcNow - _updateAt;
if (Available && delta < GetAccessKeyInterval)
{
return;
}
else if (!Available && delta < GetAccessKeyIntervalWhenUnauthorized)
{
return;
}

if (Interlocked.CompareExchange(ref _updateState, UpdateTaskRunning, UpdateTaskIdle) != UpdateTaskIdle)
{
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ public Task<string> Generate(string audience, TimeSpan? lifetime = null)
{
return key.GetMicrosoftEntraTokenAsync();
}

return _accessKey.GenerateAccessTokenAsync(audience,
_claims,
lifetime ?? Constants.Periods.DefaultAccessTokenLifetime,
DefaultAlgorithm);
var time = lifetime ?? Constants.Periods.DefaultAccessTokenLifetime;
return _accessKey.GenerateAccessTokenAsync(audience, _claims, time, DefaultAlgorithm);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ private static ISignalRServerBuilder AddAzureSignalRCore(this ISignalRServerBuil
.AddSingleton(typeof(AzureSignalRMarkerService))
.AddSingleton<IClientConnectionFactory, ClientConnectionFactory>()
.AddSingleton<IHostedService, HeartBeat>()
.AddSingleton<IAccessKeySynchronizer, AccessKeySynchronizer>()
.AddSingleton(typeof(NegotiateHandler<>));

// If a custom router is added, do not add the default router
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,12 @@ internal class ServiceEndpointManager : ServiceEndpointManagerBase

private readonly TimeSpan _scaleTimeout;

private readonly IAccessKeySynchronizer _synchronizer;

public ServiceEndpointManager(IAccessKeySynchronizer synchronizer,
IOptionsMonitor<ServiceOptions> optionsMonitor,
public ServiceEndpointManager(IOptionsMonitor<ServiceOptions> optionsMonitor,
ILoggerFactory loggerFactory) :
base(optionsMonitor.CurrentValue, loggerFactory.CreateLogger<ServiceEndpointManager>())
{
_options = optionsMonitor.CurrentValue;
_logger = loggerFactory?.CreateLogger<ServiceEndpointManager>() ?? throw new ArgumentNullException(nameof(loggerFactory));
_synchronizer = synchronizer;

optionsMonitor.OnChange(OnChange);
_scaleTimeout = _options.ServiceScaleTimeout;
Expand All @@ -40,8 +36,6 @@ public override IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint end
{
return null;
}

_synchronizer.AddServiceEndpoint(endpoint);
return new ServiceEndpointProvider(endpoint, _options);
}

Expand All @@ -53,7 +47,6 @@ private void OnChange(ServiceOptions options)

private Task ReloadServiceEndpointsAsync(IEnumerable<ServiceEndpoint> serviceEndpoints)
{
_synchronizer.UpdateServiceEndpoints(serviceEndpoints);
return ReloadServiceEndpointsAsync(serviceEndpoints, _scaleTimeout);
}

Expand Down

This file was deleted.

Loading

0 comments on commit 3df417d

Please sign in to comment.