diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/EndpointType.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/EndpointType.cs index d1d7da424..5a00e1d3c 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/EndpointType.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/EndpointType.cs @@ -1,12 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +namespace Microsoft.Azure.SignalR; -namespace Microsoft.Azure.SignalR +public enum EndpointType { - public enum EndpointType - { - Primary, - Secondary - } + Primary, + + Secondary } diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/HubServiceEndpoint.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/HubServiceEndpoint.cs index 0b2b265f0..12a8375cc 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/HubServiceEndpoint.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/HubServiceEndpoint.cs @@ -5,60 +5,60 @@ using System.Threading; using System.Threading.Tasks; -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +internal class HubServiceEndpoint : ServiceEndpoint { - internal class HubServiceEndpoint : ServiceEndpoint + private static long s_currentIndex; + + private readonly ServiceEndpoint _endpoint; + + private readonly long _uniqueIndex; + + private TaskCompletionSource _scaleTcs; + + public string Hub { get; } + + public override string Name => _endpoint.Name; + + public IServiceEndpointProvider Provider { get; } + + public IServiceConnectionContainer ConnectionContainer { get; set; } + + /// + /// Task waiting for HubServiceEndpoint turn ready when live add/remove endpoint + /// + public Task ScaleTask => _scaleTcs?.Task ?? Task.CompletedTask; + + public long UniqueIndex => _uniqueIndex; + + // Value here is not accurate. + internal override bool PendingReload => throw new NotSupportedException(); + + public HubServiceEndpoint(string hub, + IServiceEndpointProvider provider, + ServiceEndpoint endpoint) : base(endpoint) + { + Hub = hub; + Provider = provider; + _endpoint = endpoint; + _scaleTcs = endpoint.PendingReload ? new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously) : null; + _uniqueIndex = Interlocked.Increment(ref s_currentIndex); + } + + public void CompleteScale() + { + _scaleTcs?.TrySetResult(true); + } + + // When remove an existing HubServiceEndpoint. + public void ResetScale() + { + _scaleTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + } + + public override string ToString() { - private readonly ServiceEndpoint _endpoint; - private readonly long _uniqueIndex; - private static long s_currentIndex; - private TaskCompletionSource _scaleTcs; - - public HubServiceEndpoint( - string hub, - IServiceEndpointProvider provider, - ServiceEndpoint endpoint - ) : base(endpoint) - { - Hub = hub; - Provider = provider; - _endpoint = endpoint; - _scaleTcs = endpoint.PendingReload ? new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously) : null; - _uniqueIndex = Interlocked.Increment(ref s_currentIndex); - } - - public string Hub { get; } - - public override string Name => _endpoint.Name; - - public IServiceEndpointProvider Provider { get; } - - public IServiceConnectionContainer ConnectionContainer { get; set; } - - /// - /// Task waiting for HubServiceEndpoint turn ready when live add/remove endpoint - /// - public Task ScaleTask => _scaleTcs?.Task ?? Task.CompletedTask; - - public void CompleteScale() - { - _scaleTcs?.TrySetResult(true); - } - - // When remove an existing HubServiceEndpoint. - public void ResetScale() - { - _scaleTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - } - - public long UniqueIndex => _uniqueIndex; - - public override string ToString() - { - return base.ToString() + $"(hub={Hub})"; - } - - // Value here is not accurate. - internal override bool PendingReload => throw new NotSupportedException(); + return base.ToString() + $"(hub={Hub})"; } } diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/ServerStickyMode.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/ServerStickyMode.cs index 16557ba17..7b8b3c703 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/ServerStickyMode.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/ServerStickyMode.cs @@ -1,30 +1,29 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +/// +/// Specifies the mode for server sticky, when client is always routed to the server which it first /negotiate with, we call it "server sticky mode". +/// +public enum ServerStickyMode { /// - /// Specifies the mode for server sticky, when client is always routed to the server which it first /negotiate with, we call it "server sticky mode". + /// We the server sticky mode is disabled, it picks the server connection by some algorithm + /// In general, local server connection first + /// least client connections routed server connection first /// - public enum ServerStickyMode - { - /// - /// We the server sticky mode is disabled, it picks the server connection by some algorithm - /// In general, local server connection first - /// least client connections routed server connection first - /// - Disabled = 0, + Disabled = 0, - ///// - ///// We will try to find the server it /neogitate with from local, if that server is connected to this runtime instance, we choose that server - ///// Otherwise, we fallback to local existed server - ///// - Preferred = 1, + ///// + ///// We will try to find the server it /neogitate with from local, if that server is connected to this runtime instance, we choose that server + ///// Otherwise, we fallback to local existed server + ///// + Preferred = 1, - /// - /// We will try to find the server it /negotiate with from both local and global route table, it the server is not connected, throw, - /// If it is globally routed, this request will be always globally routed - /// - Required = 2, - } + /// + /// We will try to find the server it /negotiate with from both local and global route table, it the server is not connected, throw, + /// If it is globally routed, this request will be always globally routed + /// + Required = 2, } diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpointManagerBase.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpointManagerBase.cs index 2ece219b5..525264d37 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpointManagerBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpointManagerBase.cs @@ -11,401 +11,401 @@ using Microsoft.Azure.SignalR.Common; using Microsoft.Extensions.Logging; -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +internal abstract class ServiceEndpointManagerBase : IServiceEndpointManager { - internal abstract class ServiceEndpointManagerBase : IServiceEndpointManager + // Endpoints for negotiation + private readonly ConcurrentDictionary> _endpointsPerHub = new ConcurrentDictionary>(); + + private readonly ILogger _logger; + + // Filtered valuable endpoints from ServiceOptions, use dict for fast search + public IReadOnlyDictionary Endpoints { get; private set; } + + // for test purpose + internal ServiceEndpointManagerBase(IEnumerable endpoints, ILogger logger) { - // Endpoints for negotiation - private readonly ConcurrentDictionary> _endpointsPerHub = new ConcurrentDictionary>(); + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); - private readonly ILogger _logger; + Endpoints = GetValuableEndpoints(endpoints); + } - // Filtered valuable endpoints from ServiceOptions, use dict for fast search - public IReadOnlyDictionary Endpoints { get; private set; } + protected ServiceEndpointManagerBase(IServiceEndpointOptions options, ILogger logger) + : this(ServiceEndpointUtility.Merge(options.ConnectionString, options.Endpoints), logger) + { + } - public event EndpointEventHandler OnAdd; - public event EndpointEventHandler OnRemove; - - protected ServiceEndpointManagerBase(IServiceEndpointOptions options, ILogger logger) - : this(ServiceEndpointUtility.Merge(options.ConnectionString, options.Endpoints), logger) - { - } + public event EndpointEventHandler OnAdd; + + public event EndpointEventHandler OnRemove; + + public abstract IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint endpoint); - // for test purpose - internal ServiceEndpointManagerBase(IEnumerable endpoints, ILogger logger) + public IReadOnlyList GetEndpoints(string hub) + { + return _endpointsPerHub.GetOrAdd(hub, s => Endpoints.Select(e => CreateHubServiceEndpoint(hub, e.Key)).ToArray()); + } + + protected Dictionary GetValuableEndpoints(IEnumerable endpoints) + { + // select the most valuable endpoint with the same endpoint address + var groupedEndpoints = endpoints.Distinct().GroupBy(s => s.Endpoint).Select(s => { - _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + var items = s.ToList(); + if (items.Count == 1) + { + return items[0]; + } - Endpoints = GetValuableEndpoints(endpoints); - } + // By default pick up the primary endpoint, otherwise the first one + var item = items.FirstOrDefault(i => i.EndpointType == EndpointType.Primary) ?? items.FirstOrDefault(); + Log.DuplicateEndpointFound(_logger, items.Count, item?.Endpoint, item?.ToString()); + return item; + }).ToDictionary(k => k, v => v, new ServiceEndpointWeakComparer()); - public abstract IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint endpoint); + if (groupedEndpoints.Count == 0) + { + throw new AzureSignalRConfigurationNoEndpointException(); + } - public IReadOnlyList GetEndpoints(string hub) + if (groupedEndpoints.Count > 0 && groupedEndpoints.All(s => s.Key.EndpointType != EndpointType.Primary)) { - return _endpointsPerHub.GetOrAdd(hub, s => Endpoints.Select(e => CreateHubServiceEndpoint(hub, e.Key)).ToArray()); + // Only throws when endpoint count > 0 + throw new AzureSignalRNoPrimaryEndpointException(); } - protected Dictionary GetValuableEndpoints(IEnumerable endpoints) + return groupedEndpoints; + } + + protected virtual async Task ReloadServiceEndpointsAsync(IEnumerable serviceEndpoints, TimeSpan scaleTimeout) + { + try { - // select the most valuable endpoint with the same endpoint address - var groupedEndpoints = endpoints.Distinct().GroupBy(s => s.Endpoint).Select(s => - { - var items = s.ToList(); - if (items.Count == 1) - { - return items[0]; - } + var endpoints = GetValuableEndpoints(serviceEndpoints); + + UpdateEndpoints(endpoints, out var addedEndpoints, out var removedEndpoints); - // By default pick up the primary endpoint, otherwise the first one - var item = items.FirstOrDefault(i => i.EndpointType == EndpointType.Primary) ?? items.FirstOrDefault(); - Log.DuplicateEndpointFound(_logger, items.Count, item?.Endpoint, item?.ToString()); - return item; - }).ToDictionary(k => k, v => v, new ServiceEndpointWeakComparer()); + using var addCts = new CancellationTokenSource(scaleTimeout); - if (groupedEndpoints.Count == 0) + if (!await WaitTaskOrTimeout(AddServiceEndpointsAsync(addedEndpoints, addCts.Token), addCts)) { - throw new AzureSignalRConfigurationNoEndpointException(); + Log.AddEndpointsTimeout(_logger); } - if (groupedEndpoints.Count > 0 && groupedEndpoints.All(s => s.Key.EndpointType != EndpointType.Primary)) + using var removeCts = new CancellationTokenSource(scaleTimeout); + + if (!await WaitTaskOrTimeout(RemoveServiceEndpointsAsync(removedEndpoints, removeCts.Token), removeCts)) { - // Only throws when endpoint count > 0 - throw new AzureSignalRNoPrimaryEndpointException(); + Log.RemoveEndpointsTimeout(_logger); } + } + catch (Exception ex) + { + Log.ReloadEndpointsError(_logger, ex); + return; + } + } - return groupedEndpoints; + private static async Task WaitTaskOrTimeout(Task task, CancellationTokenSource cts) + { + var completed = await Task.WhenAny(task, Task.Delay(Timeout.InfiniteTimeSpan, cts.Token)); + + if (completed == task) + { + return true; } - protected virtual async Task ReloadServiceEndpointsAsync(IEnumerable serviceEndpoints, TimeSpan scaleTimeout) + cts.Cancel(); + return false; + } + + private async Task AddServiceEndpointsAsync(IReadOnlyList endpoints, CancellationToken cancellationToken) + { + if (endpoints.Count > 0) { try { - var endpoints = GetValuableEndpoints(serviceEndpoints); - - UpdateEndpoints(endpoints, out var addedEndpoints, out var removedEndpoints); - - using var addCts = new CancellationTokenSource(scaleTimeout); - - if (!await WaitTaskOrTimeout(AddServiceEndpointsAsync(addedEndpoints, addCts.Token), addCts)) - { - Log.AddEndpointsTimeout(_logger); - } + var hubEndpoints = CreateHubServiceEndpoints(endpoints); - using var removeCts = new CancellationTokenSource(scaleTimeout); + await Task.WhenAll(hubEndpoints.SelectMany(h => h.Value.Select(e => AddHubServiceEndpointAsync(e, cancellationToken)))); - if (!await WaitTaskOrTimeout(RemoveServiceEndpointsAsync(removedEndpoints, removeCts.Token), removeCts)) - { - Log.RemoveEndpointsTimeout(_logger); - } + AddEndpointsToNegotiationStore(hubEndpoints); } catch (Exception ex) { - Log.ReloadEndpointsError(_logger, ex); - return; + Log.FailedAddingEndpoints(_logger, ex); } } + } - private async Task AddServiceEndpointsAsync(IReadOnlyList endpoints, CancellationToken cancellationToken) + private async Task RemoveServiceEndpointsAsync(IReadOnlyList endpoints, CancellationToken cancellationToken) + { + if (endpoints.Count > 0) { - if (endpoints.Count > 0) + try { - try - { - var hubEndpoints = CreateHubServiceEndpoints(endpoints); + var hubEndpoints = UpdateAndGetRemovedHubServiceEndpoints(endpoints); - await Task.WhenAll(hubEndpoints.SelectMany(h => h.Value.Select(e => AddHubServiceEndpointAsync(e, cancellationToken)))); - - AddEndpointsToNegotiationStore(hubEndpoints); - } - catch (Exception ex) - { - Log.FailedAddingEndpoints(_logger, ex); - } + await Task.WhenAll(hubEndpoints.Select(e => RemoveHubServiceEndpointAsync(e, cancellationToken))); } - } - - private async Task RemoveServiceEndpointsAsync(IReadOnlyList endpoints, CancellationToken cancellationToken) - { - if (endpoints.Count > 0) + catch (Exception ex) { - try - { - var hubEndpoints = UpdateAndGetRemovedHubServiceEndpoints(endpoints); - - await Task.WhenAll(hubEndpoints.Select(e => RemoveHubServiceEndpointAsync(e, cancellationToken))); - } - catch (Exception ex) - { - Log.FailedRemovingEndpoints(_logger, ex); - } + Log.FailedRemovingEndpoints(_logger, ex); } } + } - private void AddEndpointsToNegotiationStore(Dictionary> endpoints) + private void AddEndpointsToNegotiationStore(Dictionary> endpoints) + { + foreach (var hubEndpoints in _endpointsPerHub) { - foreach (var hubEndpoints in _endpointsPerHub) + if (!endpoints.TryGetValue(hubEndpoints.Key, out var updatedEndpoints) + || updatedEndpoints.Count == 0) { - if (!endpoints.TryGetValue(hubEndpoints.Key, out var updatedEndpoints) - || updatedEndpoints.Count == 0) - { - continue; - } - var oldEndpoints = _endpointsPerHub[hubEndpoints.Key]; - var newEndpoints = oldEndpoints.ToList(); - newEndpoints.AddRange(updatedEndpoints); - _endpointsPerHub.TryUpdate(hubEndpoints.Key, newEndpoints, oldEndpoints); + continue; } + var oldEndpoints = _endpointsPerHub[hubEndpoints.Key]; + var newEndpoints = oldEndpoints.ToList(); + newEndpoints.AddRange(updatedEndpoints); + _endpointsPerHub.TryUpdate(hubEndpoints.Key, newEndpoints, oldEndpoints); } + } - private IReadOnlyList UpdateAndGetRemovedHubServiceEndpoints(IEnumerable endpoints) + private IReadOnlyList UpdateAndGetRemovedHubServiceEndpoints(IEnumerable endpoints) + { + var removedEndpoints = new List(); + foreach (var hubEndpoints in _endpointsPerHub) { - var removedEndpoints = new List(); - foreach (var hubEndpoints in _endpointsPerHub) + var remainedEndpoints = new List(); + var oldEndpoints = _endpointsPerHub[hubEndpoints.Key]; + foreach (var endpoint in oldEndpoints) { - var remainedEndpoints = new List(); - var oldEndpoints = _endpointsPerHub[hubEndpoints.Key]; - foreach (var endpoint in oldEndpoints) + var remove = endpoints.FirstOrDefault(e => e.Equals(endpoint)); + if (remove != null) { - var remove = endpoints.FirstOrDefault(e => e.Equals(endpoint)); - if (remove != null) - { - // Refer to reload detector to reset scale task. - if (remove.PendingReload) - { - endpoint.ResetScale(); - } - removedEndpoints.Add(endpoint); - } - else + // Refer to reload detector to reset scale task. + if (remove.PendingReload) { - remainedEndpoints.Add(endpoint); + endpoint.ResetScale(); } + removedEndpoints.Add(endpoint); } - _endpointsPerHub.TryUpdate(hubEndpoints.Key, remainedEndpoints, oldEndpoints); - } - return removedEndpoints; - } - - private HubServiceEndpoint CreateHubServiceEndpoint(string hub, ServiceEndpoint endpoint) - { - var provider = GetEndpointProvider(endpoint); - var hubEndpoint = new HubServiceEndpoint(hub, provider, endpoint); - // check if endpoint is an instant update and copy container directly. - if (_endpointsPerHub.TryGetValue(hub, out var hubEndpoints)) - { - var exist = hubEndpoints.FirstOrDefault(e => (e.Endpoint, e.EndpointType, e.ServerEndpoint) == (endpoint.Endpoint, endpoint.EndpointType, endpoint.ServerEndpoint)); - if (exist != null) + else { - hubEndpoint.ConnectionContainer = exist.ConnectionContainer; + remainedEndpoints.Add(endpoint); } } - return hubEndpoint; - } - - private IReadOnlyList CreateHubServiceEndpoints(string hub, IEnumerable endpoints) - { - return endpoints.Select(e => CreateHubServiceEndpoint(hub, e)).ToList(); + _endpointsPerHub.TryUpdate(hubEndpoints.Key, remainedEndpoints, oldEndpoints); } + return removedEndpoints; + } - private Dictionary> CreateHubServiceEndpoints(IEnumerable endpoints) + private HubServiceEndpoint CreateHubServiceEndpoint(string hub, ServiceEndpoint endpoint) + { + var provider = GetEndpointProvider(endpoint); + var hubEndpoint = new HubServiceEndpoint(hub, provider, endpoint); + // check if endpoint is an instant update and copy container directly. + if (_endpointsPerHub.TryGetValue(hub, out var hubEndpoints)) { - var hubEndpoints = new Dictionary>(); - foreach (var item in _endpointsPerHub) + var exist = hubEndpoints.FirstOrDefault(e => (e.Endpoint, e.EndpointType, e.ServerEndpoint) == (endpoint.Endpoint, endpoint.EndpointType, endpoint.ServerEndpoint)); + if (exist != null) { - hubEndpoints.Add(item.Key, CreateHubServiceEndpoints(item.Key, endpoints)); + hubEndpoint.ConnectionContainer = exist.ConnectionContainer; } - return hubEndpoints; } + return hubEndpoint; + } - private async Task AddHubServiceEndpointAsync(HubServiceEndpoint endpoint, CancellationToken cancellationToken) - { - Log.StartAddingEndpoint(_logger, endpoint.Endpoint, endpoint.Name); + private IReadOnlyList CreateHubServiceEndpoints(string hub, IEnumerable endpoints) + { + return endpoints.Select(e => CreateHubServiceEndpoint(hub, e)).ToList(); + } - OnAdd?.Invoke(endpoint); + private Dictionary> CreateHubServiceEndpoints(IEnumerable endpoints) + { + var hubEndpoints = new Dictionary>(); + foreach (var item in _endpointsPerHub) + { + hubEndpoints.Add(item.Key, CreateHubServiceEndpoints(item.Key, endpoints)); + } + return hubEndpoints; + } - // Wait for new endpoint turn Ready or timeout getting cancelled - var task = await Task.WhenAny(endpoint.ScaleTask, cancellationToken.AsTask()); + private async Task AddHubServiceEndpointAsync(HubServiceEndpoint endpoint, CancellationToken cancellationToken) + { + Log.StartAddingEndpoint(_logger, endpoint.Endpoint, endpoint.Name); - if (task == endpoint.ScaleTask) - { - Log.SucceedAddingEndpoint(_logger, endpoint.ToString()); - } + OnAdd?.Invoke(endpoint); - // Set complete - endpoint.CompleteScale(); - } + // Wait for new endpoint turn Ready or timeout getting cancelled + var task = await Task.WhenAny(endpoint.ScaleTask, cancellationToken.AsTask()); - private async Task RemoveHubServiceEndpointAsync(HubServiceEndpoint endpoint, CancellationToken cancellationToken) + if (task == endpoint.ScaleTask) { - Log.StartRemovingEndpoint(_logger, endpoint.Endpoint, endpoint.Name); + Log.SucceedAddingEndpoint(_logger, endpoint.ToString()); + } - OnRemove?.Invoke(endpoint); + // Set complete + endpoint.CompleteScale(); + } - // Wait for endpoint turn offline or timeout getting cancelled - var task = await Task.WhenAny(endpoint.ScaleTask, cancellationToken.AsTask()); + private async Task RemoveHubServiceEndpointAsync(HubServiceEndpoint endpoint, CancellationToken cancellationToken) + { + Log.StartRemovingEndpoint(_logger, endpoint.Endpoint, endpoint.Name); - if (task == endpoint.ScaleTask) - { - Log.SucceedRemovingEndpoint(_logger, endpoint.ToString()); - } + OnRemove?.Invoke(endpoint); - // Set complete - endpoint.CompleteScale(); - } + // Wait for endpoint turn offline or timeout getting cancelled + var task = await Task.WhenAny(endpoint.ScaleTask, cancellationToken.AsTask()); - private void UpdateEndpoints(Dictionary updatedEndpoints, - out IReadOnlyList addedEndpoints, - out IReadOnlyList removedEndpoints) + if (task == endpoint.ScaleTask) { - // Get exactly same endpoints - var endpoints = Endpoints.Intersect(updatedEndpoints).ToDictionary(k => k.Key, v => v.Value); - - // Get staging required endpoints - var removed = Endpoints.Keys.Except(updatedEndpoints.Keys, new ServiceEndpointWeakComparer()).ToList(); - var added = updatedEndpoints.Keys.Except(Endpoints.Keys, new ServiceEndpointWeakComparer()).ToList(); - removed.ForEach(e => e.PendingReload = true); - foreach (var item in added) - { - item.PendingReload = true; - endpoints.Add(item, item); - } - - // Get instant changable endpoints - var commonEndpoints = updatedEndpoints.Keys - .Intersect(Endpoints.Keys, new ServiceEndpointWeakComparer()) - .Except(Endpoints.Keys); - foreach (var endpoint in commonEndpoints) - { - // search exist from old to remove and reset PendingReload in case changed before. - var exist = Endpoints.First(x => x.Key.Endpoint == endpoint.Endpoint).Key; - exist.PendingReload = false; - removed.Add(exist); + Log.SucceedRemovingEndpoint(_logger, endpoint.ToString()); + } - added.Add(endpoint); - endpoints.Add(endpoint, endpoint); - } - removedEndpoints = removed; - addedEndpoints = added; + // Set complete + endpoint.CompleteScale(); + } - Endpoints = endpoints; + private void UpdateEndpoints(Dictionary updatedEndpoints, + out IReadOnlyList addedEndpoints, + out IReadOnlyList removedEndpoints) + { + // Get exactly same endpoints + var endpoints = Endpoints.Intersect(updatedEndpoints).ToDictionary(k => k.Key, v => v.Value); + + // Get staging required endpoints + var removed = Endpoints.Keys.Except(updatedEndpoints.Keys, new ServiceEndpointWeakComparer()).ToList(); + var added = updatedEndpoints.Keys.Except(Endpoints.Keys, new ServiceEndpointWeakComparer()).ToList(); + removed.ForEach(e => e.PendingReload = true); + foreach (var item in added) + { + item.PendingReload = true; + endpoints.Add(item, item); } - private static async Task WaitTaskOrTimeout(Task task, CancellationTokenSource cts) + // Get instant changable endpoints + var commonEndpoints = updatedEndpoints.Keys + .Intersect(Endpoints.Keys, new ServiceEndpointWeakComparer()) + .Except(Endpoints.Keys); + foreach (var endpoint in commonEndpoints) { - var completed = await Task.WhenAny(task, Task.Delay(Timeout.InfiniteTimeSpan, cts.Token)); + // search exist from old to remove and reset PendingReload in case changed before. + var exist = Endpoints.First(x => x.Key.Endpoint == endpoint.Endpoint).Key; + exist.PendingReload = false; + removed.Add(exist); - if (completed == task) - { - return true; - } - - cts.Cancel(); - return false; + added.Add(endpoint); + endpoints.Add(endpoint, endpoint); } + removedEndpoints = removed; + addedEndpoints = added; - private sealed class ServiceEndpointWeakComparer : IEqualityComparer - { - public bool Equals(ServiceEndpoint x, ServiceEndpoint y) - { - return (x.Endpoint, x.EndpointType, x.ServerEndpoint) == (y.Endpoint, y.EndpointType, y.ServerEndpoint); - } + Endpoints = endpoints; + } - public int GetHashCode(ServiceEndpoint obj) - { - return obj.Endpoint.GetHashCode() ^ obj.EndpointType.GetHashCode() ^ obj.ServerEndpoint.GetHashCode(); - } - } + private static class Log + { + private static readonly Action _duplicateEndpointFound = + LoggerMessage.Define(LogLevel.Warning, new EventId(1, "DuplicateEndpointFound"), "{count} endpoint configurations to '{endpoint}' found, use '{name}'."); - private static class Log - { - private static readonly Action _duplicateEndpointFound = - LoggerMessage.Define(LogLevel.Warning, new EventId(1, "DuplicateEndpointFound"), "{count} endpoint configurations to '{endpoint}' found, use '{name}'."); + private static readonly Action _startAddingEndpoint = + LoggerMessage.Define(LogLevel.Debug, new EventId(2, "StartAddingEndpoint"), "Start adding endpoint: '{endpoint}', name: '{name}'."); - private static readonly Action _startAddingEndpoint = - LoggerMessage.Define(LogLevel.Debug, new EventId(2, "StartAddingEndpoint"), "Start adding endpoint: '{endpoint}', name: '{name}'."); - - private static readonly Action _startRemovingEndpoint = - LoggerMessage.Define(LogLevel.Debug, new EventId(3, "StartRemovingEndpoint"), "Start removing endpoint: '{endpoint}', name: '{name}'"); + private static readonly Action _startRemovingEndpoint = + LoggerMessage.Define(LogLevel.Debug, new EventId(3, "StartRemovingEndpoint"), "Start removing endpoint: '{endpoint}', name: '{name}'"); - private static readonly Action _startRenamingEndpoint = - LoggerMessage.Define(LogLevel.Debug, new EventId(4, "StartRenamingEndpoint"), "Start renaming endpoint: '{endpoint}', name: '{name}'"); + private static readonly Action _startRenamingEndpoint = + LoggerMessage.Define(LogLevel.Debug, new EventId(4, "StartRenamingEndpoint"), "Start renaming endpoint: '{endpoint}', name: '{name}'"); - private static readonly Action _reloadEndpointError = - LoggerMessage.Define(LogLevel.Error, new EventId(5, "ReloadEndpointsError"), "No connection string is specified. Skip scale operation."); + private static readonly Action _reloadEndpointError = + LoggerMessage.Define(LogLevel.Error, new EventId(5, "ReloadEndpointsError"), "No connection string is specified. Skip scale operation."); - private static readonly Action _AddEndpointsTimeout = - LoggerMessage.Define(LogLevel.Error, new EventId(6, "AddEndpointsTimeout"), "Timeout waiting for adding endpoints."); + private static readonly Action _AddEndpointsTimeout = + LoggerMessage.Define(LogLevel.Error, new EventId(6, "AddEndpointsTimeout"), "Timeout waiting for adding endpoints."); - private static readonly Action _removeEndpointsTimeout = - LoggerMessage.Define(LogLevel.Error, new EventId(7, "RemoveEndpointsTimeout"), "Timeout waiting for removing endpoints."); + private static readonly Action _removeEndpointsTimeout = + LoggerMessage.Define(LogLevel.Error, new EventId(7, "RemoveEndpointsTimeout"), "Timeout waiting for removing endpoints."); - private static readonly Action _failedAddingEndpoints = - LoggerMessage.Define(LogLevel.Error, new EventId(8, "FailedAddingEndpoints"), "Failed adding endpoints."); + private static readonly Action _failedAddingEndpoints = + LoggerMessage.Define(LogLevel.Error, new EventId(8, "FailedAddingEndpoints"), "Failed adding endpoints."); - private static readonly Action _failedRemovingEndpoints = - LoggerMessage.Define(LogLevel.Error, new EventId(9, "FailedRemovingEndpoints"), "Failed removing endpoints."); + private static readonly Action _failedRemovingEndpoints = + LoggerMessage.Define(LogLevel.Error, new EventId(9, "FailedRemovingEndpoints"), "Failed removing endpoints."); - private static readonly Action _succeedAddingEndpoints = - LoggerMessage.Define(LogLevel.Information, new EventId(10, "SucceedAddingEndpoint"), "Succeed in adding endpoint: '{endpoint}'"); + private static readonly Action _succeedAddingEndpoints = + LoggerMessage.Define(LogLevel.Information, new EventId(10, "SucceedAddingEndpoint"), "Succeed in adding endpoint: '{endpoint}'"); - private static readonly Action _succeedRemovingEndpoints = - LoggerMessage.Define(LogLevel.Information, new EventId(11, "SucceedRemovingEndpoint"), "Succeed in removing endpoint: '{endpoint}'"); + private static readonly Action _succeedRemovingEndpoints = + LoggerMessage.Define(LogLevel.Information, new EventId(11, "SucceedRemovingEndpoint"), "Succeed in removing endpoint: '{endpoint}'"); - public static void DuplicateEndpointFound(ILogger logger, int count, string endpoint, string name) - { - _duplicateEndpointFound(logger, count, endpoint, name, null); - } + public static void DuplicateEndpointFound(ILogger logger, int count, string endpoint, string name) + { + _duplicateEndpointFound(logger, count, endpoint, name, null); + } - public static void StartAddingEndpoint(ILogger logger, string endpoint, string name) - { - _startAddingEndpoint(logger, endpoint, name, null); - } + public static void StartAddingEndpoint(ILogger logger, string endpoint, string name) + { + _startAddingEndpoint(logger, endpoint, name, null); + } - public static void StartRemovingEndpoint(ILogger logger, string endpoint, string name) - { - _startRemovingEndpoint(logger, endpoint, name, null); - } + public static void StartRemovingEndpoint(ILogger logger, string endpoint, string name) + { + _startRemovingEndpoint(logger, endpoint, name, null); + } - public static void StartRenamingEndpoint(ILogger logger, string endpoint, string name) - { - _startRenamingEndpoint(logger, endpoint, name, null); - } + public static void StartRenamingEndpoint(ILogger logger, string endpoint, string name) + { + _startRenamingEndpoint(logger, endpoint, name, null); + } - public static void ReloadEndpointsError(ILogger logger, Exception ex) - { - _reloadEndpointError(logger, ex); - } + public static void ReloadEndpointsError(ILogger logger, Exception ex) + { + _reloadEndpointError(logger, ex); + } - public static void AddEndpointsTimeout(ILogger logger) - { - _AddEndpointsTimeout(logger, null); - } + public static void AddEndpointsTimeout(ILogger logger) + { + _AddEndpointsTimeout(logger, null); + } - public static void RemoveEndpointsTimeout(ILogger logger) - { - _removeEndpointsTimeout(logger, null); - } + public static void RemoveEndpointsTimeout(ILogger logger) + { + _removeEndpointsTimeout(logger, null); + } - public static void FailedAddingEndpoints(ILogger logger, Exception ex) - { - _failedAddingEndpoints(logger, ex); - } + public static void FailedAddingEndpoints(ILogger logger, Exception ex) + { + _failedAddingEndpoints(logger, ex); + } - public static void FailedRemovingEndpoints(ILogger logger, Exception ex) - { - _failedRemovingEndpoints(logger, ex); - } + public static void FailedRemovingEndpoints(ILogger logger, Exception ex) + { + _failedRemovingEndpoints(logger, ex); + } - public static void SucceedAddingEndpoint(ILogger logger, string endpoint) - { - _succeedAddingEndpoints(logger, endpoint, null); - } + public static void SucceedAddingEndpoint(ILogger logger, string endpoint) + { + _succeedAddingEndpoints(logger, endpoint, null); + } - public static void SucceedRemovingEndpoint(ILogger logger, string endpoint) - { - _succeedRemovingEndpoints(logger, endpoint, null); - } + public static void SucceedRemovingEndpoint(ILogger logger, string endpoint) + { + _succeedRemovingEndpoints(logger, endpoint, null); + } + } + + private sealed class ServiceEndpointWeakComparer : IEqualityComparer + { + public bool Equals(ServiceEndpoint x, ServiceEndpoint y) + { + return (x.Endpoint, x.EndpointType, x.ServerEndpoint) == (y.Endpoint, y.EndpointType, y.ServerEndpoint); + } + + public int GetHashCode(ServiceEndpoint obj) + { + return obj.Endpoint.GetHashCode() ^ obj.EndpointType.GetHashCode() ^ obj.ServerEndpoint.GetHashCode(); } } } diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpointUtility.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpointUtility.cs index 0133582d3..e00348ad5 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpointUtility.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpointUtility.cs @@ -3,23 +3,22 @@ using System.Collections.Generic; -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +internal static class ServiceEndpointUtility { - internal static class ServiceEndpointUtility + public static IEnumerable Merge(string connectionString, IEnumerable endpoints) { - public static IEnumerable Merge(string connectionString, IEnumerable endpoints) + if (!string.IsNullOrEmpty(connectionString)) { - if (!string.IsNullOrEmpty(connectionString)) - { - yield return new ServiceEndpoint(connectionString); - } + yield return new ServiceEndpoint(connectionString); + } - if (endpoints != null) + if (endpoints != null) + { + foreach (var endpoint in endpoints) { - foreach (var endpoint in endpoints) - { - yield return endpoint; - } + yield return endpoint; } } } diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/UriExtensions.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/UriExtensions.cs index dd05e87cc..0f4bf6870 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/UriExtensions.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/UriExtensions.cs @@ -4,13 +4,12 @@ using System; using System.Linq; -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +internal static class UriExtensions { - internal static class UriExtensions + public static Uri Append(this Uri uri, params string[] paths) { - public static Uri Append(this Uri uri, params string[] paths) - { - return new Uri(paths.Aggregate(uri.AbsoluteUri, (current, path) => string.Format("{0}/{1}", current.TrimEnd('/'), path.TrimStart('/')))); - } + return new Uri(paths.Aggregate(uri.AbsoluteUri, (current, path) => string.Format("{0}/{1}", current.TrimEnd('/'), path.TrimStart('/')))); } } diff --git a/src/Microsoft.Azure.SignalR.Common/Interfaces/IAccessTokenProvider.cs b/src/Microsoft.Azure.SignalR.Common/Interfaces/IAccessTokenProvider.cs index 695402fc8..367397e64 100644 --- a/src/Microsoft.Azure.SignalR.Common/Interfaces/IAccessTokenProvider.cs +++ b/src/Microsoft.Azure.SignalR.Common/Interfaces/IAccessTokenProvider.cs @@ -3,10 +3,9 @@ using System.Threading.Tasks; -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +internal interface IAccessTokenProvider { - internal interface IAccessTokenProvider - { - Task ProvideAsync(); - } + Task ProvideAsync(); } diff --git a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionContainer.cs b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionContainer.cs index a79e85e24..4c39f78f7 100644 --- a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionContainer.cs +++ b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionContainer.cs @@ -2,7 +2,6 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; -using System.Threading; using System.Threading.Tasks; namespace Microsoft.Azure.SignalR; diff --git a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionContainerFactory.cs b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionContainerFactory.cs index 3e2de8f41..00b2ad7f3 100644 --- a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionContainerFactory.cs +++ b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionContainerFactory.cs @@ -3,10 +3,9 @@ using System; -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +internal interface IServiceConnectionContainerFactory { - internal interface IServiceConnectionContainerFactory - { - IServiceConnectionContainer Create(string hub, TimeSpan? serviceScaleTimeout = null); - } + IServiceConnectionContainer Create(string hub, TimeSpan? serviceScaleTimeout = null); } diff --git a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionFactory.cs b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionFactory.cs index f1d624a45..4ddef03be 100644 --- a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionFactory.cs +++ b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionFactory.cs @@ -5,5 +5,8 @@ namespace Microsoft.Azure.SignalR; internal interface IServiceConnectionFactory { - IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, AckHandler ackHandler, ServiceConnectionType type); + IServiceConnection Create(HubServiceEndpoint endpoint, + IServiceMessageHandler serviceMessageHandler, + AckHandler ackHandler, + ServiceConnectionType type); } diff --git a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceEventHandler.cs b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceEventHandler.cs index c936aa380..cf825d81e 100644 --- a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceEventHandler.cs +++ b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceEventHandler.cs @@ -5,10 +5,9 @@ using Microsoft.Azure.SignalR.Protocol; -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +public interface IServiceEventHandler { - public interface IServiceEventHandler - { - Task HandleAsync(string connectionId, ServiceEventMessage message); - } + Task HandleAsync(string connectionId, ServiceEventMessage message); } diff --git a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceMessageHandler.cs b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceMessageHandler.cs index 6143b72eb..31ff86258 100644 --- a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceMessageHandler.cs +++ b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceMessageHandler.cs @@ -4,12 +4,11 @@ using System.Threading.Tasks; using Microsoft.Azure.SignalR.Protocol; -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +internal interface IServiceMessageHandler { - internal interface IServiceMessageHandler - { - Task HandlePingAsync(PingMessage pingMessage); + Task HandlePingAsync(PingMessage pingMessage); - void HandleAck(AckMessage ackMessage); - } + void HandleAck(AckMessage ackMessage); } diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/AckStatus.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/AckStatus.cs index 4577cc816..437ba2227 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/AckStatus.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/AckStatus.cs @@ -1,14 +1,15 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +namespace Microsoft.Azure.SignalR; -namespace Microsoft.Azure.SignalR +internal enum AckStatus { - internal enum AckStatus - { - Ok = 1, - NotFound = 2, - Timeout = 3, - InternalServerError = 4, - } + Ok = 1, + + NotFound = 2, + + Timeout = 3, + + InternalServerError = 4, } diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ConnectionFactory.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ConnectionFactory.cs index 300be3e06..b35e310bc 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ConnectionFactory.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ConnectionFactory.cs @@ -9,150 +9,149 @@ using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +internal class ConnectionFactory : IConnectionFactory { - internal class ConnectionFactory : IConnectionFactory + private readonly ILoggerFactory _loggerFactory; + + private readonly string _serverId; + + public ConnectionFactory(IServerNameProvider nameProvider, ILoggerFactory loggerFactory) + { + _loggerFactory = loggerFactory != null ? new GracefulLoggerFactory(loggerFactory) : throw new ArgumentNullException(nameof(loggerFactory)); + _serverId = nameProvider?.GetName(); + } + + public async Task ConnectAsync(HubServiceEndpoint hubServiceEndpoint, + TransferFormat transferFormat, + string connectionId, + string target, + CancellationToken cancellationToken = default, + IDictionary headers = null) { - private readonly ILoggerFactory _loggerFactory; + var provider = hubServiceEndpoint.Provider; + var hubName = hubServiceEndpoint.Hub; - private readonly string _serverId; + var accessTokenProvider = provider.GetServerAccessTokenProvider(hubName, _serverId); - public ConnectionFactory(IServerNameProvider nameProvider, ILoggerFactory loggerFactory) + var url = GetServiceUrl(provider, hubName, connectionId, target); + + headers ??= new Dictionary(); + if (!string.IsNullOrEmpty(_serverId) && !headers.ContainsKey(Constants.Headers.AsrsServerId)) { - _loggerFactory = loggerFactory != null ? new GracefulLoggerFactory(loggerFactory) : throw new ArgumentNullException(nameof(loggerFactory)); - _serverId = nameProvider?.GetName(); + headers.Add(Constants.Headers.AsrsServerId, _serverId); } - public async Task ConnectAsync(HubServiceEndpoint hubServiceEndpoint, - TransferFormat transferFormat, - string connectionId, - string target, - CancellationToken cancellationToken = default, - IDictionary headers = null) + var connectionOptions = new WebSocketConnectionOptions + { + Headers = headers, + Proxy = provider.Proxy, + }; + var connection = new WebSocketConnectionContext(connectionOptions, _loggerFactory, accessTokenProvider); + try { - var provider = hubServiceEndpoint.Provider; - var hubName = hubServiceEndpoint.Hub; + await connection.StartAsync(url, cancellationToken); - var accessTokenProvider = provider.GetServerAccessTokenProvider(hubName, _serverId); + return connection; + } + catch + { + await connection.StopAsync(); + throw; + } + } - var url = GetServiceUrl(provider, hubName, connectionId, target); + public Task DisposeAsync(ConnectionContext connection) + { + if (connection == null) + { + return Task.CompletedTask; + } - headers ??= new Dictionary(); - if (!string.IsNullOrEmpty(_serverId) && !headers.ContainsKey(Constants.Headers.AsrsServerId)) - { - headers.Add(Constants.Headers.AsrsServerId, _serverId); - } + return ((WebSocketConnectionContext)connection).StopAsync(); + } - var connectionOptions = new WebSocketConnectionOptions - { - Headers = headers, - Proxy = provider.Proxy, - }; - var connection = new WebSocketConnectionContext(connectionOptions, _loggerFactory, accessTokenProvider); - try - { - await connection.StartAsync(url, cancellationToken); + private Uri GetServiceUrl(IServiceEndpointProvider provider, string hubName, string connectionId, string target) + { + var baseUri = new UriBuilder(provider.GetServerEndpoint(hubName)); + var query = "cid=" + connectionId; + if (target != null) + { + query = $"{query}&target={WebUtility.UrlEncode(target)}"; + } + if (baseUri.Query != null && baseUri.Query.Length > 1) + { + baseUri.Query = baseUri.Query.Substring(1) + "&" + query; + } + else + { + baseUri.Query = query; + } + return baseUri.Uri; + } - return connection; - } - catch - { - await connection.StopAsync(); - throw; - } + private sealed class GracefulLoggerFactory : ILoggerFactory + { + private readonly ILoggerFactory _inner; + + public GracefulLoggerFactory(ILoggerFactory inner) + { + _inner = inner; } - public Task DisposeAsync(ConnectionContext connection) + public void Dispose() { - if (connection == null) - { - return Task.CompletedTask; - } + _inner.Dispose(); + } - return ((WebSocketConnectionContext)connection).StopAsync(); + public ILogger CreateLogger(string categoryName) + { + var innerLogger = _inner.CreateLogger(categoryName); + return new GracefulLogger(innerLogger); } - private Uri GetServiceUrl(IServiceEndpointProvider provider, string hubName, string connectionId, string target) + public void AddProvider(ILoggerProvider provider) { - var baseUri = new UriBuilder(provider.GetServerEndpoint(hubName)); - var query = "cid=" + connectionId; - if (target != null) - { - query = $"{query}&target={WebUtility.UrlEncode(target)}"; - } - if (baseUri.Query != null && baseUri.Query.Length > 1) - { - baseUri.Query = baseUri.Query.Substring(1) + "&" + query; - } - else - { - baseUri.Query = query; - } - return baseUri.Uri; + _inner.AddProvider(provider); } - private sealed class GracefulLoggerFactory : ILoggerFactory + private sealed class GracefulLogger : ILogger { - private readonly ILoggerFactory _inner; + private readonly ILogger _inner; - public GracefulLoggerFactory(ILoggerFactory inner) + public GracefulLogger(ILogger inner) { _inner = inner; } - public void Dispose() - { - _inner.Dispose(); - } - - public ILogger CreateLogger(string categoryName) + /// + /// Downgrade error level logs, and also exclude exception details + /// Exceptions thrown from inside the HttpConnection are supposed to be handled by the caller and logged with more user-friendly message + /// + /// + /// + /// + /// + /// + /// + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func formatter) { - var innerLogger = _inner.CreateLogger(categoryName); - return new GracefulLogger(innerLogger); + if (logLevel >= LogLevel.Error) + { + logLevel = LogLevel.Warning; + } + _inner.Log(logLevel, eventId, state, null, formatter); } - public void AddProvider(ILoggerProvider provider) + public bool IsEnabled(LogLevel logLevel) { - _inner.AddProvider(provider); + return _inner.IsEnabled(logLevel); } - private sealed class GracefulLogger : ILogger + public IDisposable BeginScope(TState state) { - private readonly ILogger _inner; - - public GracefulLogger(ILogger inner) - { - _inner = inner; - } - - /// - /// Downgrade error level logs, and also exclude exception details - /// Exceptions thrown from inside the HttpConnection are supposed to be handled by the caller and logged with more user-friendly message - /// - /// - /// - /// - /// - /// - /// - public void Log(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func formatter) - { - if (logLevel >= LogLevel.Error) - { - logLevel = LogLevel.Warning; - } - _inner.Log(logLevel, eventId, state, null, formatter); - } - - public bool IsEnabled(LogLevel logLevel) - { - return _inner.IsEnabled(logLevel); - } - - public IDisposable BeginScope(TState state) - { - return _inner.BeginScope(state); - } + return _inner.BeginScope(state); } } } diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/GracefulShutdownMode.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/GracefulShutdownMode.cs index d1cee3aae..12cab7cce 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/GracefulShutdownMode.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/GracefulShutdownMode.cs @@ -1,41 +1,40 @@ -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +/// +/// This mode defines the server's behavior after receiving a `Ctrl+C` (SIGINT). +/// +public enum GracefulShutdownMode { /// - /// This mode defines the server's behavior after receiving a `Ctrl+C` (SIGINT). + /// The server will stop immediately, all existing connections will be dropped immediately. /// - public enum GracefulShutdownMode - { - /// - /// The server will stop immediately, all existing connections will be dropped immediately. - /// - Off = 0, + Off = 0, - /// - /// We will immediately remove this server from Azure SignalR, - /// which means no more new connections will be assigned to this server, - /// the existing connections won't be influenced until a default timeout (30s). - /// Once all connections on this server are closed properly, the server stops. - /// - WaitForClientsClose = 1, + /// + /// We will immediately remove this server from Azure SignalR, + /// which means no more new connections will be assigned to this server, + /// the existing connections won't be influenced until a default timeout (30s). + /// Once all connections on this server are closed properly, the server stops. + /// + WaitForClientsClose = 1, - /// - /// Similar to `WaitForClientsClose`, the server will be removed from Azure SignalR. - /// But instead of waiting existing connections to close, we will try to migrate client connections to another valid server, - /// which may save most of your connections during this process. - /// - /// It happens on the message boundaries, considering if each of your message consist of 3 packages. The migration will happen at here: - /// - /// | P1 - P2 - P3 | [HERE] | P4 - P5 - P6 | - /// | Message 1 | | Message 2 | - /// - /// We do this by finding message boundaries on-fly, - /// For JSON protocol, we simply find seperators (,) - /// For MessagePack protocol, we preserve the length header and count body length to determine if the message was finished. - /// - /// This mode always works well with context-free scenarios. - /// Since the `connectionId` will not change before-and-after migration, - /// you may also benifit from this feature by using a distributed storage even if your scenario is not context-free. - /// - MigrateClients = 2, - } + /// + /// Similar to `WaitForClientsClose`, the server will be removed from Azure SignalR. + /// But instead of waiting existing connections to close, we will try to migrate client connections to another valid server, + /// which may save most of your connections during this process. + /// + /// It happens on the message boundaries, considering if each of your message consist of 3 packages. The migration will happen at here: + /// + /// | P1 - P2 - P3 | [HERE] | P4 - P5 - P6 | + /// | Message 1 | | Message 2 | + /// + /// We do this by finding message boundaries on-fly, + /// For JSON protocol, we simply find seperators (,) + /// For MessagePack protocol, we preserve the length header and count body length to determine if the message was finished. + /// + /// This mode always works well with context-free scenarios. + /// Since the `connectionId` will not change before-and-after migration, + /// you may also benifit from this feature by using a distributed storage even if your scenario is not context-free. + /// + MigrateClients = 2, } diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointMessageWriter.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointMessageWriter.cs index 481d7c925..e9cf62fb5 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointMessageWriter.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointMessageWriter.cs @@ -11,205 +11,204 @@ using Microsoft.Azure.SignalR.Protocol; using Microsoft.Extensions.Logging; -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +/// +/// A service connection container which sends message to multiple service endpoints. +/// +internal class MultiEndpointMessageWriter : IServiceMessageWriter { - /// - /// A service connection container which sends message to multiple service endpoints. - /// - internal class MultiEndpointMessageWriter : IServiceMessageWriter - { - private readonly ILogger _logger; + private readonly ILogger _logger; - internal HubServiceEndpoint[] TargetEndpoints { get; } + internal HubServiceEndpoint[] TargetEndpoints { get; } - public MultiEndpointMessageWriter(IReadOnlyCollection targetEndpoints, ILoggerFactory loggerFactory) + public MultiEndpointMessageWriter(IReadOnlyCollection targetEndpoints, ILoggerFactory loggerFactory) + { + _logger = loggerFactory.CreateLogger(); + var normalized = new List(); + if (targetEndpoints != null) { - _logger = loggerFactory.CreateLogger(); - var normalized = new List(); - if (targetEndpoints != null) + foreach (var endpoint in targetEndpoints.Where(s => s != null)) { - foreach (var endpoint in targetEndpoints.Where(s => s != null)) + var hubEndpoint = endpoint as HubServiceEndpoint; + // it is possible that the endpoint is not a valid HubServiceEndpoint since it can be changed by the router + if (hubEndpoint == null || hubEndpoint.ConnectionContainer == null) { - var hubEndpoint = endpoint as HubServiceEndpoint; - // it is possible that the endpoint is not a valid HubServiceEndpoint since it can be changed by the router - if (hubEndpoint == null || hubEndpoint.ConnectionContainer == null) - { - Log.EndpointNotExists(_logger, endpoint.ToString()); - } - else - { - normalized.Add(hubEndpoint); - } + Log.EndpointNotExists(_logger, endpoint.ToString()); + } + else + { + normalized.Add(hubEndpoint); } } - - TargetEndpoints = normalized.ToArray(); } - public Task ConnectionInitializedTask => Task.WhenAll(TargetEndpoints.Select(e => e.ConnectionContainer.ConnectionInitializedTask)); + TargetEndpoints = normalized.ToArray(); + } - public Task WriteAsync(ServiceMessage serviceMessage) + public Task ConnectionInitializedTask => Task.WhenAll(TargetEndpoints.Select(e => e.ConnectionContainer.ConnectionInitializedTask)); + + public Task WriteAsync(ServiceMessage serviceMessage) + { + return WriteMultiEndpointMessageAsync(serviceMessage, connection => connection.WriteAsync(serviceMessage)); + } + + public Task WriteAckableMessageAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) + { + if (serviceMessage is CheckConnectionExistenceWithAckMessage + || serviceMessage is JoinGroupWithAckMessage + || serviceMessage is LeaveGroupWithAckMessage) { - return WriteMultiEndpointMessageAsync(serviceMessage, connection => connection.WriteAsync(serviceMessage)); + return WriteSingleResultAckableMessage(serviceMessage, cancellationToken); } - - public Task WriteAckableMessageAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) + else { - if (serviceMessage is CheckConnectionExistenceWithAckMessage - || serviceMessage is JoinGroupWithAckMessage - || serviceMessage is LeaveGroupWithAckMessage) - { - return WriteSingleResultAckableMessage(serviceMessage, cancellationToken); - } - else - { - return WriteMultiResultAckableMessage(serviceMessage, cancellationToken); - } + return WriteMultiResultAckableMessage(serviceMessage, cancellationToken); } + } - /// - /// For user or group related operations, different endpoints might return different results - /// Strategy: - /// Always wait until all endpoints return or throw - /// * When any endpoint throws, throw - /// * When all endpoints return false, return false - /// * When any endpoint returns true, return true - /// - /// - /// - /// - private async Task WriteMultiResultAckableMessage(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) - { - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - var bag = new ConcurrentBag(); - await WriteMultiEndpointMessageAsync(serviceMessage, async connection => - { - bag.Add(await connection.WriteAckableMessageAsync(serviceMessage.Clone(), cancellationToken)); - }); - - return bag.Any(i => i); - } + /// + /// For user or group related operations, different endpoints might return different results + /// Strategy: + /// Always wait until all endpoints return or throw + /// * When any endpoint throws, throw + /// * When all endpoints return false, return false + /// * When any endpoint returns true, return true + /// + /// + /// + /// + private async Task WriteMultiResultAckableMessage(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - /// - /// For connection related operations, since connectionId is globally unique, only one endpoint can have the connection - /// Strategy: - /// Don't need to wait until all endpoints return or throw - /// * Whenever any endpoint returns true: return true - /// * When any endpoint throws throw - /// * When all endpoints return false, return false - /// - /// - /// - /// - private async Task WriteSingleResultAckableMessage(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) - { - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - var writeMessageTask = WriteMultiEndpointMessageAsync(serviceMessage, async connection => - { - var succeeded = await connection.WriteAckableMessageAsync(serviceMessage.Clone(), cancellationToken); - if (succeeded) - { - tcs.TrySetResult(true); - } - }); + var bag = new ConcurrentBag(); + await WriteMultiEndpointMessageAsync(serviceMessage, async connection => + { + bag.Add(await connection.WriteAckableMessageAsync(serviceMessage.Clone(), cancellationToken)); + }); - // we wait when tcs is set to true or all the tasks return - var task = await Task.WhenAny(tcs.Task, writeMessageTask); + return bag.Any(i => i); + } - // tcs is either already set as true or should be false now - tcs.TrySetResult(false); + /// + /// For connection related operations, since connectionId is globally unique, only one endpoint can have the connection + /// Strategy: + /// Don't need to wait until all endpoints return or throw + /// * Whenever any endpoint returns true: return true + /// * When any endpoint throws throw + /// * When all endpoints return false, return false + /// + /// + /// + /// + private async Task WriteSingleResultAckableMessage(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - if (tcs.Task.Result) + var writeMessageTask = WriteMultiEndpointMessageAsync(serviceMessage, async connection => + { + var succeeded = await connection.WriteAckableMessageAsync(serviceMessage.Clone(), cancellationToken); + if (succeeded) { - return true; + tcs.TrySetResult(true); } + }); - // This will throw exceptions in tasks if exceptions exist - await writeMessageTask; - return false; - } + // we wait when tcs is set to true or all the tasks return + var task = await Task.WhenAny(tcs.Task, writeMessageTask); - private async Task WriteMultiEndpointMessageAsync(ServiceMessage serviceMessage, Func inner) + // tcs is either already set as true or should be false now + tcs.TrySetResult(false); + + if (tcs.Task.Result) { - if (TargetEndpoints.Length == 0) - { - Log.NoEndpointRouted(_logger, serviceMessage.GetType().Name); - return; - } + return true; + } - if (TargetEndpoints.Length == 1) - { - await WriteSingleEndpointMessageAsync(TargetEndpoints[0], serviceMessage, inner); - return; - } + // This will throw exceptions in tasks if exceptions exist + await writeMessageTask; + return false; + } - var task = Task.WhenAll(TargetEndpoints.Select((endpoint) => WriteSingleEndpointMessageAsync(endpoint, serviceMessage, inner))); - try - { - await task; - } - catch (Exception ex) - { - // throw the aggregated exception instead - throw task.Exception ?? ex; - } + private async Task WriteMultiEndpointMessageAsync(ServiceMessage serviceMessage, Func inner) + { + if (TargetEndpoints.Length == 0) + { + Log.NoEndpointRouted(_logger, serviceMessage.GetType().Name); + return; } - private async Task WriteSingleEndpointMessageAsync(HubServiceEndpoint endpoint, ServiceMessage serviceMessage, Func inner) + if (TargetEndpoints.Length == 1) { - try - { - Log.RouteMessageToServiceEndpoint(_logger, serviceMessage, endpoint.ToString()); - await inner(endpoint.ConnectionContainer); - } - catch (ServiceConnectionNotActiveException) - { - // log and don't stop other endpoints - Log.FailedWritingMessageToEndpoint(_logger, serviceMessage.GetType().Name, (serviceMessage as IMessageWithTracingId)?.TracingId, endpoint.ToString()); - throw new FailedWritingMessageToServiceException(endpoint.ServerEndpoint.AbsoluteUri); - } + await WriteSingleEndpointMessageAsync(TargetEndpoints[0], serviceMessage, inner); + return; + } + + var task = Task.WhenAll(TargetEndpoints.Select((endpoint) => WriteSingleEndpointMessageAsync(endpoint, serviceMessage, inner))); + try + { + await task; } + catch (Exception ex) + { + // throw the aggregated exception instead + throw task.Exception ?? ex; + } + } - internal static class Log + private async Task WriteSingleEndpointMessageAsync(HubServiceEndpoint endpoint, ServiceMessage serviceMessage, Func inner) + { + try + { + Log.RouteMessageToServiceEndpoint(_logger, serviceMessage, endpoint.ToString()); + await inner(endpoint.ConnectionContainer); + } + catch (ServiceConnectionNotActiveException) { - public const string FailedWritingMessageToEndpointTemplate = "{0} message {1} is not sent to endpoint {2} because all connections to this endpoint are offline."; + // log and don't stop other endpoints + Log.FailedWritingMessageToEndpoint(_logger, serviceMessage.GetType().Name, (serviceMessage as IMessageWithTracingId)?.TracingId, endpoint.ToString()); + throw new FailedWritingMessageToServiceException(endpoint.ServerEndpoint.AbsoluteUri); + } + } - private static readonly Action _endpointNotExists = - LoggerMessage.Define(LogLevel.Error, new EventId(3, "EndpointNotExists"), "Endpoint {endpoint} from the router does not exists."); + internal static class Log + { + public const string FailedWritingMessageToEndpointTemplate = "{0} message {1} is not sent to endpoint {2} because all connections to this endpoint are offline."; - private static readonly Action _noEndpointRouted = - LoggerMessage.Define(LogLevel.Warning, new EventId(4, "NoEndpointRouted"), "Message {messageType} is not sent because no endpoint is returned from the endpoint router."); + private static readonly Action _endpointNotExists = + LoggerMessage.Define(LogLevel.Error, new EventId(3, "EndpointNotExists"), "Endpoint {endpoint} from the router does not exists."); - private static readonly Action _failedWritingMessageToEndpoint = - LoggerMessage.Define(LogLevel.Warning, new EventId(5, "FailedWritingMessageToEndpoint"), FailedWritingMessageToEndpointTemplate); + private static readonly Action _noEndpointRouted = + LoggerMessage.Define(LogLevel.Warning, new EventId(4, "NoEndpointRouted"), "Message {messageType} is not sent because no endpoint is returned from the endpoint router."); - private static readonly Action _routeMessageToServiceEndpoint = - LoggerMessage.Define(LogLevel.Information, new EventId(11, "RouteMessageToServiceEndpoint"), "Route message {tracingId} to service endpoint {endpoint}."); + private static readonly Action _failedWritingMessageToEndpoint = + LoggerMessage.Define(LogLevel.Warning, new EventId(5, "FailedWritingMessageToEndpoint"), FailedWritingMessageToEndpointTemplate); - public static void RouteMessageToServiceEndpoint(ILogger logger, ServiceMessage message, string endpoint) - { - if (ServiceConnectionContainerScope.EnableMessageLog || ClientConnectionScope.IsDiagnosticClient) - { - _routeMessageToServiceEndpoint(logger, (message as IMessageWithTracingId).TracingId, endpoint, null); - } - } + private static readonly Action _routeMessageToServiceEndpoint = + LoggerMessage.Define(LogLevel.Information, new EventId(11, "RouteMessageToServiceEndpoint"), "Route message {tracingId} to service endpoint {endpoint}."); - public static void EndpointNotExists(ILogger logger, string endpoint) + public static void RouteMessageToServiceEndpoint(ILogger logger, ServiceMessage message, string endpoint) + { + if (ServiceConnectionContainerScope.EnableMessageLog || ClientConnectionScope.IsDiagnosticClient) { - _endpointNotExists(logger, endpoint, null); + _routeMessageToServiceEndpoint(logger, (message as IMessageWithTracingId).TracingId, endpoint, null); } + } - public static void NoEndpointRouted(ILogger logger, string messageType) - { - _noEndpointRouted(logger, messageType, null); - } + public static void EndpointNotExists(ILogger logger, string endpoint) + { + _endpointNotExists(logger, endpoint, null); + } - public static void FailedWritingMessageToEndpoint(ILogger logger, string messageType, ulong? tracingId, string endpoint) - { - _failedWritingMessageToEndpoint(logger, messageType, tracingId, endpoint, null); - } + public static void NoEndpointRouted(ILogger logger, string messageType) + { + _noEndpointRouted(logger, messageType, null); + } + + public static void FailedWritingMessageToEndpoint(ILogger logger, string messageType, ulong? tracingId, string endpoint) + { + _failedWritingMessageToEndpoint(logger, messageType, tracingId, endpoint, null); } } } \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs index 2b7250fa7..f10bf610e 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs @@ -10,496 +10,504 @@ using Microsoft.Azure.SignalR.Protocol; using Microsoft.Extensions.Logging; -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +internal class MultiEndpointServiceConnectionContainer : IServiceConnectionContainer { - internal class MultiEndpointServiceConnectionContainer : IServiceConnectionContainer - { - private readonly string _hubName; - private readonly IMessageRouter _router; - private readonly ILoggerFactory _loggerFactory; - private readonly ILogger _logger; - private readonly IServiceEndpointManager _serviceEndpointManager; - private readonly TimeSpan _scaleTimeout; - private readonly Func _generator; - private readonly object _lock = new object(); - - private (bool needRouter, IReadOnlyList endpoints) _routerEndpoints; - private int _started = 0; - - internal MultiEndpointServiceConnectionContainer( - string hub, - Func generator, - IServiceEndpointManager endpointManager, - IMessageRouter router, - ILoggerFactory loggerFactory, - TimeSpan? scaleTimeout = null) - { - if (generator == null) - { - throw new ArgumentNullException(nameof(generator)); - } + private readonly string _hubName; - _hubName = hub; - _router = router ?? throw new ArgumentNullException(nameof(router)); - _loggerFactory = loggerFactory; - _logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory)); - _serviceEndpointManager = endpointManager; - _scaleTimeout = scaleTimeout ?? Constants.Periods.DefaultScaleTimeout; + private readonly IMessageRouter _router; - // Reserve generator for potential scale use. - _generator = generator; + private readonly ILoggerFactory _loggerFactory; - // provides a copy to the endpoint per container - var endpoints = endpointManager.GetEndpoints(hub); - UpdateRoutedEndpoints(endpoints); + private readonly ILogger _logger; - foreach (var endpoint in endpoints) - { - endpoint.ConnectionContainer = generator(endpoint); - } + private readonly IServiceEndpointManager _serviceEndpointManager; - _serviceEndpointManager.OnAdd += OnAdd; - _serviceEndpointManager.OnRemove += OnRemove; - } + private readonly TimeSpan _scaleTimeout; - public MultiEndpointServiceConnectionContainer( - IServiceConnectionFactory serviceConnectionFactory, - string hub, - int count, - int? maxCount, - IServiceEndpointManager endpointManager, - IMessageRouter router, - ILoggerFactory loggerFactory, - TimeSpan? scaleTimeout = null - ) : this( - hub, - endpoint => CreateContainer(serviceConnectionFactory, endpoint, count, maxCount, loggerFactory), - endpointManager, - router, - loggerFactory, - scaleTimeout) - { - } + private readonly Func _generator; - // for tests - public IEnumerable GetOnlineEndpoints() - { - return _routerEndpoints.endpoints.Where(s => s.Online); - } + private readonly object _lock = new object(); + + private (bool needRouter, IReadOnlyList endpoints) _routerEndpoints; - private static IServiceConnectionContainer CreateContainer(IServiceConnectionFactory serviceConnectionFactory, HubServiceEndpoint endpoint, int count, int? maxCount, ILoggerFactory loggerFactory) + private int _started = 0; + + public ServiceConnectionStatus Status => throw new NotSupportedException(); + + public Task ConnectionInitializedTask + { + get { - if (endpoint.EndpointType == EndpointType.Primary) - { - return new StrongServiceConnectionContainer(serviceConnectionFactory, count, maxCount, endpoint, loggerFactory.CreateLogger()); - } - else - { - return new WeakServiceConnectionContainer(serviceConnectionFactory, count, endpoint, loggerFactory.CreateLogger()); - } + return Task.WhenAll(from connection in _routerEndpoints.endpoints + select connection.ConnectionContainer.ConnectionInitializedTask); } + } - public ServiceConnectionStatus Status => throw new NotSupportedException(); + public string ServersTag => throw new NotSupportedException(); + + public bool HasClients => throw new NotSupportedException(); + + public MultiEndpointServiceConnectionContainer( + IServiceConnectionFactory serviceConnectionFactory, + string hub, + int count, + int? maxCount, + IServiceEndpointManager endpointManager, + IMessageRouter router, + ILoggerFactory loggerFactory, + TimeSpan? scaleTimeout = null + ) : this( + hub, + endpoint => CreateContainer(serviceConnectionFactory, endpoint, count, maxCount, loggerFactory), + endpointManager, + router, + loggerFactory, + scaleTimeout) + { + } - public Task ConnectionInitializedTask + internal MultiEndpointServiceConnectionContainer( + string hub, + Func generator, + IServiceEndpointManager endpointManager, + IMessageRouter router, + ILoggerFactory loggerFactory, + TimeSpan? scaleTimeout = null) + { + if (generator == null) { - get - { - return Task.WhenAll(from connection in _routerEndpoints.endpoints - select connection.ConnectionContainer.ConnectionInitializedTask); - } + throw new ArgumentNullException(nameof(generator)); } - public string ServersTag => throw new NotSupportedException(); + _hubName = hub; + _router = router ?? throw new ArgumentNullException(nameof(router)); + _loggerFactory = loggerFactory; + _logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory)); + _serviceEndpointManager = endpointManager; + _scaleTimeout = scaleTimeout ?? Constants.Periods.DefaultScaleTimeout; + + // Reserve generator for potential scale use. + _generator = generator; - public bool HasClients => throw new NotSupportedException(); + // provides a copy to the endpoint per container + var endpoints = endpointManager.GetEndpoints(hub); + UpdateRoutedEndpoints(endpoints); - public Task StartAsync() + foreach (var endpoint in endpoints) { - //ensure started only once - return _started == 1 || Interlocked.CompareExchange(ref _started, 1, 0) == 1 - ? Task.CompletedTask - : Task.WhenAll(_routerEndpoints.endpoints.Select(s => - { - Log.StartingConnection(_logger, s.Endpoint); - return s.ConnectionContainer.StartAsync(); - })); + endpoint.ConnectionContainer = generator(endpoint); } - public Task StopAsync() - { - return Task.WhenAll(_routerEndpoints.endpoints.Select(s => + _serviceEndpointManager.OnAdd += OnAdd; + _serviceEndpointManager.OnRemove += OnRemove; + } + + // for tests + public IEnumerable GetOnlineEndpoints() + { + return _routerEndpoints.endpoints.Where(s => s.Online); + } + + public Task StartAsync() + { + //ensure started only once + return _started == 1 || Interlocked.CompareExchange(ref _started, 1, 0) == 1 + ? Task.CompletedTask + : Task.WhenAll(_routerEndpoints.endpoints.Select(s => { - Log.StoppingConnection(_logger, s.Endpoint); - return s.ConnectionContainer.StopAsync(); + Log.StartingConnection(_logger, s.Endpoint); + return s.ConnectionContainer.StartAsync(); })); - } + } - public Task OfflineAsync(GracefulShutdownMode mode, CancellationToken token) + public Task StopAsync() + { + return Task.WhenAll(_routerEndpoints.endpoints.Select(s => { - return Task.WhenAll(_routerEndpoints.endpoints.Select(c => c.ConnectionContainer.OfflineAsync(mode, token))); - } + Log.StoppingConnection(_logger, s.Endpoint); + return s.ConnectionContainer.StopAsync(); + })); + } - public Task CloseClientConnections(CancellationToken token) - { - return Task.WhenAll(_routerEndpoints.endpoints.Select(c => c.ConnectionContainer.CloseClientConnections(token))); - } + public Task OfflineAsync(GracefulShutdownMode mode, CancellationToken token) + { + return Task.WhenAll(_routerEndpoints.endpoints.Select(c => c.ConnectionContainer.OfflineAsync(mode, token))); + } - public Task WriteAsync(ServiceMessage serviceMessage) - { - return CreateMessageWriter(serviceMessage).WriteAsync(serviceMessage); - } + public Task CloseClientConnections(CancellationToken token) + { + return Task.WhenAll(_routerEndpoints.endpoints.Select(c => c.ConnectionContainer.CloseClientConnections(token))); + } - public Task WriteAckableMessageAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) - { - return CreateMessageWriter(serviceMessage).WriteAckableMessageAsync(serviceMessage, cancellationToken); - } - public Task StartGetServersPing() - { - return Task.WhenAll(_routerEndpoints.endpoints.Select(c => c.ConnectionContainer.StartGetServersPing())); - } + public Task WriteAsync(ServiceMessage serviceMessage) + { + return CreateMessageWriter(serviceMessage).WriteAsync(serviceMessage); + } + + public Task WriteAckableMessageAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) + { + return CreateMessageWriter(serviceMessage).WriteAckableMessageAsync(serviceMessage, cancellationToken); + } + + public Task StartGetServersPing() + { + return Task.WhenAll(_routerEndpoints.endpoints.Select(c => c.ConnectionContainer.StartGetServersPing())); + } + + public Task StopGetServersPing() + { + return Task.WhenAll(_routerEndpoints.endpoints.Select(c => c.ConnectionContainer.StopGetServersPing())); + } - public Task StopGetServersPing() + public void Dispose() + { + foreach (var container in _routerEndpoints.endpoints) { - return Task.WhenAll(_routerEndpoints.endpoints.Select(c => c.ConnectionContainer.StopGetServersPing())); + container.ConnectionContainer.Dispose(); } + } - public void Dispose() + internal IEnumerable GetRoutedEndpoints(ServiceMessage message) + { + if (!_routerEndpoints.needRouter) { - foreach(var container in _routerEndpoints.endpoints) - { - container.ConnectionContainer.Dispose(); - } + return _routerEndpoints.endpoints; } - - internal IEnumerable GetRoutedEndpoints(ServiceMessage message) + var endpoints = _routerEndpoints.endpoints; + switch (message) { - if (!_routerEndpoints.needRouter) - { - return _routerEndpoints.endpoints; - } - var endpoints = _routerEndpoints.endpoints; - switch (message) - { - case BroadcastDataMessage bdm: - return _router.GetEndpointsForBroadcast(endpoints); + case BroadcastDataMessage bdm: + return _router.GetEndpointsForBroadcast(endpoints); - case GroupBroadcastDataMessage gbdm: - return _router.GetEndpointsForGroup(gbdm.GroupName, endpoints); + case GroupBroadcastDataMessage gbdm: + return _router.GetEndpointsForGroup(gbdm.GroupName, endpoints); - case JoinGroupWithAckMessage jgm: - return _router.GetEndpointsForGroup(jgm.GroupName, endpoints); + case JoinGroupWithAckMessage jgm: + return _router.GetEndpointsForGroup(jgm.GroupName, endpoints); - case LeaveGroupWithAckMessage lgm: - return _router.GetEndpointsForGroup(lgm.GroupName, endpoints); + case LeaveGroupWithAckMessage lgm: + return _router.GetEndpointsForGroup(lgm.GroupName, endpoints); - case MultiGroupBroadcastDataMessage mgbdm: - return mgbdm.GroupList.SelectMany(g => _router.GetEndpointsForGroup(g, endpoints)).Distinct(); + case MultiGroupBroadcastDataMessage mgbdm: + return mgbdm.GroupList.SelectMany(g => _router.GetEndpointsForGroup(g, endpoints)).Distinct(); - case ConnectionDataMessage cdm: - return _router.GetEndpointsForConnection(cdm.ConnectionId, endpoints); + case ConnectionDataMessage cdm: + return _router.GetEndpointsForConnection(cdm.ConnectionId, endpoints); - case MultiConnectionDataMessage mcd: - return mcd.ConnectionList.SelectMany(c => _router.GetEndpointsForConnection(c, endpoints)).Distinct(); + case MultiConnectionDataMessage mcd: + return mcd.ConnectionList.SelectMany(c => _router.GetEndpointsForConnection(c, endpoints)).Distinct(); - case UserDataMessage udm: - return _router.GetEndpointsForUser(udm.UserId, endpoints); + case UserDataMessage udm: + return _router.GetEndpointsForUser(udm.UserId, endpoints); - case MultiUserDataMessage mudm: - return mudm.UserList.SelectMany(g => _router.GetEndpointsForUser(g, endpoints)).Distinct(); + case MultiUserDataMessage mudm: + return mudm.UserList.SelectMany(g => _router.GetEndpointsForUser(g, endpoints)).Distinct(); - case UserJoinGroupMessage ujgm: - return _router.GetEndpointsForGroup(ujgm.GroupName, endpoints).Intersect(_router.GetEndpointsForUser(ujgm.UserId, endpoints)); + case UserJoinGroupMessage ujgm: + return _router.GetEndpointsForGroup(ujgm.GroupName, endpoints).Intersect(_router.GetEndpointsForUser(ujgm.UserId, endpoints)); - case UserJoinGroupWithAckMessage ujgm: - return _router.GetEndpointsForGroup(ujgm.GroupName, endpoints).Intersect(_router.GetEndpointsForUser(ujgm.UserId, endpoints)); + case UserJoinGroupWithAckMessage ujgm: + return _router.GetEndpointsForGroup(ujgm.GroupName, endpoints).Intersect(_router.GetEndpointsForUser(ujgm.UserId, endpoints)); - case UserLeaveGroupMessage ulgm: - return _router.GetEndpointsForGroup(ulgm.GroupName, endpoints).Intersect(_router.GetEndpointsForUser(ulgm.UserId, endpoints)); + case UserLeaveGroupMessage ulgm: + return _router.GetEndpointsForGroup(ulgm.GroupName, endpoints).Intersect(_router.GetEndpointsForUser(ulgm.UserId, endpoints)); - case UserLeaveGroupWithAckMessage ulgm: - return _router.GetEndpointsForGroup(ulgm.GroupName, endpoints).Intersect(_router.GetEndpointsForUser(ulgm.UserId, endpoints)); + case UserLeaveGroupWithAckMessage ulgm: + return _router.GetEndpointsForGroup(ulgm.GroupName, endpoints).Intersect(_router.GetEndpointsForUser(ulgm.UserId, endpoints)); - case CheckConnectionExistenceWithAckMessage checkConnectionMessage: - return _router.GetEndpointsForConnection(checkConnectionMessage.ConnectionId, endpoints); + case CheckConnectionExistenceWithAckMessage checkConnectionMessage: + return _router.GetEndpointsForConnection(checkConnectionMessage.ConnectionId, endpoints); - case CheckUserExistenceWithAckMessage checkUserMessage: - return _router.GetEndpointsForUser(checkUserMessage.UserId, endpoints); + case CheckUserExistenceWithAckMessage checkUserMessage: + return _router.GetEndpointsForUser(checkUserMessage.UserId, endpoints); - case CheckGroupExistenceWithAckMessage checkGroupMessage: - return _router.GetEndpointsForGroup(checkGroupMessage.GroupName, endpoints); + case CheckGroupExistenceWithAckMessage checkGroupMessage: + return _router.GetEndpointsForGroup(checkGroupMessage.GroupName, endpoints); - case CheckUserInGroupWithAckMessage checkUserInGroupMessage: - return _router.GetEndpointsForGroup(checkUserInGroupMessage.GroupName, endpoints).Intersect(_router.GetEndpointsForUser(checkUserInGroupMessage.UserId, endpoints)); + case CheckUserInGroupWithAckMessage checkUserInGroupMessage: + return _router.GetEndpointsForGroup(checkUserInGroupMessage.GroupName, endpoints).Intersect(_router.GetEndpointsForUser(checkUserInGroupMessage.UserId, endpoints)); - case CloseConnectionMessage closeConnectionMessage: - return _router.GetEndpointsForConnection(closeConnectionMessage.ConnectionId, endpoints); + case CloseConnectionMessage closeConnectionMessage: + return _router.GetEndpointsForConnection(closeConnectionMessage.ConnectionId, endpoints); - case ClientInvocationMessage clientInvocationMessage: - return _router.GetEndpointsForConnection(clientInvocationMessage.ConnectionId, endpoints); + case ClientInvocationMessage clientInvocationMessage: + return _router.GetEndpointsForConnection(clientInvocationMessage.ConnectionId, endpoints); - case ServiceCompletionMessage serviceCompletionMessage: - return _router.GetEndpointsForConnection(serviceCompletionMessage.ConnectionId, endpoints); + case ServiceCompletionMessage serviceCompletionMessage: + return _router.GetEndpointsForConnection(serviceCompletionMessage.ConnectionId, endpoints); - // ServiceMappingMessage should never be sent to the service + // ServiceMappingMessage should never be sent to the service - default: - throw new NotSupportedException(message.GetType().Name); - } + default: + throw new NotSupportedException(message.GetType().Name); } + } - private MultiEndpointMessageWriter CreateMessageWriter(ServiceMessage serviceMessage) - { - var targetEndpoints = GetRoutedEndpoints(serviceMessage)?.ToList(); - return new MultiEndpointMessageWriter(targetEndpoints, _loggerFactory); - } + private static IServiceConnectionContainer CreateContainer(IServiceConnectionFactory serviceConnectionFactory, HubServiceEndpoint endpoint, int count, int? maxCount, ILoggerFactory loggerFactory) + { + if (endpoint.EndpointType == EndpointType.Primary) + { + return new StrongServiceConnectionContainer(serviceConnectionFactory, count, maxCount, endpoint, loggerFactory.CreateLogger()); + } + else + { + return new WeakServiceConnectionContainer(serviceConnectionFactory, count, endpoint, loggerFactory.CreateLogger()); + } + } + private MultiEndpointMessageWriter CreateMessageWriter(ServiceMessage serviceMessage) + { + var targetEndpoints = GetRoutedEndpoints(serviceMessage)?.ToList(); + return new MultiEndpointMessageWriter(targetEndpoints, _loggerFactory); + } - private void OnAdd(HubServiceEndpoint endpoint) + private void OnAdd(HubServiceEndpoint endpoint) + { + if (!endpoint.Hub.Equals(_hubName, StringComparison.OrdinalIgnoreCase)) { - if (!endpoint.Hub.Equals(_hubName, StringComparison.OrdinalIgnoreCase)) - { - return; - } - _ = AddHubServiceEndpointAsync(endpoint); + return; } + _ = AddHubServiceEndpointAsync(endpoint); + } - private async Task AddHubServiceEndpointAsync(HubServiceEndpoint endpoint) + private async Task AddHubServiceEndpointAsync(HubServiceEndpoint endpoint) + { + if (endpoint.ScaleTask.IsCompleted) { - if (endpoint.ScaleTask.IsCompleted) - { - UpdateEndpointsStore(endpoint, ScaleOperation.Add); - return; - } + UpdateEndpointsStore(endpoint, ScaleOperation.Add); + return; + } - var container = _generator(endpoint); - endpoint.ConnectionContainer = container; + var container = _generator(endpoint); + endpoint.ConnectionContainer = container; - try - { - _ = container.StartAsync(); + try + { + _ = container.StartAsync(); - await container.ConnectionInitializedTask; + await container.ConnectionInitializedTask; - // Update local store directly after start connection - // to get a uniformed action on trigger servers ping - UpdateEndpointsStore(endpoint, ScaleOperation.Add); + // Update local store directly after start connection + // to get a uniformed action on trigger servers ping + UpdateEndpointsStore(endpoint, ScaleOperation.Add); - await StartGetServersPing(); - await WaitForServerStable(container, endpoint); - } - catch (Exception ex) - { - Log.FailedStartingConnectionForNewEndpoint(_logger, endpoint.ToString(), ex); - } - finally - { - _ = StopGetServersPing(); - endpoint.CompleteScale(); - } + await StartGetServersPing(); + await WaitForServerStable(container, endpoint); } + catch (Exception ex) + { + Log.FailedStartingConnectionForNewEndpoint(_logger, endpoint.ToString(), ex); + } + finally + { + _ = StopGetServersPing(); + endpoint.CompleteScale(); + } + } - private void OnRemove(HubServiceEndpoint endpoint) + private void OnRemove(HubServiceEndpoint endpoint) + { + if (!endpoint.Hub.Equals(_hubName, StringComparison.OrdinalIgnoreCase)) { - if (!endpoint.Hub.Equals(_hubName, StringComparison.OrdinalIgnoreCase)) - { - return; - } - _ = RemoveHubServiceEndpointAsync(endpoint); + return; } + _ = RemoveHubServiceEndpointAsync(endpoint); + } - private async Task RemoveHubServiceEndpointAsync(HubServiceEndpoint endpoint) + private async Task RemoveHubServiceEndpointAsync(HubServiceEndpoint endpoint) + { + if (endpoint.ScaleTask.IsCompleted) + { + UpdateEndpointsStore(endpoint, ScaleOperation.Remove); + return; + } + try { - if (endpoint.ScaleTask.IsCompleted) + var container = _routerEndpoints.endpoints.FirstOrDefault(e => e.Endpoint == endpoint.Endpoint && e.EndpointType == endpoint.EndpointType); + if (container == null) { - UpdateEndpointsStore(endpoint, ScaleOperation.Remove); + Log.EndpointNotExists(_logger, endpoint.ToString()); return; } - try - { - var container = _routerEndpoints.endpoints.FirstOrDefault(e => e.Endpoint == endpoint.Endpoint && e.EndpointType == endpoint.EndpointType); - if (container == null) - { - Log.EndpointNotExists(_logger, endpoint.ToString()); - return; - } - - // TDOO: shall we pass in cancellation token here? - _ = container.ConnectionContainer.OfflineAsync(GracefulShutdownMode.Off, default); - await WaitForClientsDisconnect(container); - - UpdateEndpointsStore(endpoint, ScaleOperation.Remove); - - // Clean up - await container.ConnectionContainer.StopAsync(); - container.ConnectionContainer.Dispose(); - } - catch (Exception ex) - { - Log.FailedRemovingConnectionForEndpoint(_logger, endpoint.ToString(), ex); - } - finally - { - endpoint.CompleteScale(); - } + + // TDOO: shall we pass in cancellation token here? + _ = container.ConnectionContainer.OfflineAsync(GracefulShutdownMode.Off, default); + await WaitForClientsDisconnect(container); + + UpdateEndpointsStore(endpoint, ScaleOperation.Remove); + + // Clean up + await container.ConnectionContainer.StopAsync(); + container.ConnectionContainer.Dispose(); + } + catch (Exception ex) + { + Log.FailedRemovingConnectionForEndpoint(_logger, endpoint.ToString(), ex); } + finally + { + endpoint.CompleteScale(); + } + } - private void UpdateEndpointsStore(HubServiceEndpoint endpoint, ScaleOperation operation) + private void UpdateEndpointsStore(HubServiceEndpoint endpoint, ScaleOperation operation) + { + // Use lock to ensure store update safety as parallel changes triggered in container side. + lock (_lock) { - // Use lock to ensure store update safety as parallel changes triggered in container side. - lock (_lock) + switch (operation) { - switch (operation) - { - case ScaleOperation.Add: - { - var newEndpoints = _routerEndpoints.endpoints.ToList(); - newEndpoints.Add(endpoint); - UpdateRoutedEndpoints(newEndpoints); - break; - } - case ScaleOperation.Remove: - { - var newEndpoints = _routerEndpoints.endpoints.Where(e => !e.Equals(endpoint)).ToList(); - UpdateRoutedEndpoints(newEndpoints); - break; - } - default: + case ScaleOperation.Add: + { + var newEndpoints = _routerEndpoints.endpoints.ToList(); + newEndpoints.Add(endpoint); + UpdateRoutedEndpoints(newEndpoints); + break; + } + case ScaleOperation.Remove: + { + var newEndpoints = _routerEndpoints.endpoints.Where(e => !e.Equals(endpoint)).ToList(); + UpdateRoutedEndpoints(newEndpoints); break; - } + } + default: + break; } } + } - private void UpdateRoutedEndpoints(IReadOnlyList currentEndpoints) - { - // router will be used when there's customized MessageRouter or multiple endpoints - var needRouter = currentEndpoints.Count > 1 || !(_router is DefaultMessageRouter); - _routerEndpoints = (needRouter, currentEndpoints); - } + private void UpdateRoutedEndpoints(IReadOnlyList currentEndpoints) + { + // router will be used when there's customized MessageRouter or multiple endpoints + var needRouter = currentEndpoints.Count > 1 || !(_router is DefaultMessageRouter); + _routerEndpoints = (needRouter, currentEndpoints); + } - private async Task WaitForServerStable(IServiceConnectionContainer container, HubServiceEndpoint endpoint) + private async Task WaitForServerStable(IServiceConnectionContainer container, HubServiceEndpoint endpoint) + { + var startTime = DateTime.UtcNow; + while (DateTime.UtcNow - startTime < _scaleTimeout) { - var startTime = DateTime.UtcNow; - while (DateTime.UtcNow - startTime < _scaleTimeout) + if (IsServerReady(container)) { - if (IsServerReady(container)) - { - return; - } - await Task.Delay(Constants.Periods.DefaultServersPingInterval); + return; } - Log.TimeoutWaitingForAddingEndpoint(_logger, endpoint.ToString(), (int)_scaleTimeout.TotalSeconds); + await Task.Delay(Constants.Periods.DefaultServersPingInterval); } + Log.TimeoutWaitingForAddingEndpoint(_logger, endpoint.ToString(), (int)_scaleTimeout.TotalSeconds); + } - private bool IsServerReady(IServiceConnectionContainer container) + private bool IsServerReady(IServiceConnectionContainer container) + { + var serversOnNew = container.ServersTag; + var allMatch = !string.IsNullOrEmpty(serversOnNew); + if (!allMatch) { - var serversOnNew = container.ServersTag; - var allMatch = !string.IsNullOrEmpty(serversOnNew); + // return directly if local server list is not set yet. + return false; + } + + // ensure strong consistency of server Ids for new endpoint towards exists + foreach (var endpoint in _routerEndpoints.endpoints) + { + allMatch = !string.IsNullOrEmpty(endpoint.ConnectionContainer.ServersTag) + && serversOnNew.Equals(endpoint.ConnectionContainer.ServersTag, StringComparison.OrdinalIgnoreCase) + && allMatch; if (!allMatch) { - // return directly if local server list is not set yet. return false; } - - // ensure strong consistency of server Ids for new endpoint towards exists - foreach (var endpoint in _routerEndpoints.endpoints) - { - allMatch = !string.IsNullOrEmpty(endpoint.ConnectionContainer.ServersTag) - && serversOnNew.Equals(endpoint.ConnectionContainer.ServersTag, StringComparison.OrdinalIgnoreCase) - && allMatch; - if (!allMatch) - { - return false; - } - } - return allMatch; } + return allMatch; + } - private async Task WaitForClientsDisconnect(HubServiceEndpoint endpoint) + private async Task WaitForClientsDisconnect(HubServiceEndpoint endpoint) + { + var startTime = DateTime.UtcNow; + while (DateTime.UtcNow - startTime < _scaleTimeout) { - var startTime = DateTime.UtcNow; - while (DateTime.UtcNow - startTime < _scaleTimeout) + if (!endpoint.ConnectionContainer.HasClients) { - if (!endpoint.ConnectionContainer.HasClients) - { - return; - } - // status ping interval is 10s, quick delay 5s to do next check - await Task.Delay(Constants.Periods.DefaultCloseDelayInterval); + return; } - Log.TimeoutWaitingClientsDisconnect(_logger, endpoint.ToString(), (int)_scaleTimeout.TotalSeconds); + // status ping interval is 10s, quick delay 5s to do next check + await Task.Delay(Constants.Periods.DefaultCloseDelayInterval); } + Log.TimeoutWaitingClientsDisconnect(_logger, endpoint.ToString(), (int)_scaleTimeout.TotalSeconds); + } - private IEnumerable SingleOrNotSupported(IEnumerable endpoints, ServiceMessage message) + private IEnumerable SingleOrNotSupported(IEnumerable endpoints, ServiceMessage message) + { + var endpointCnt = endpoints.ToList().Count; + if (endpointCnt == 1) { - var endpointCnt = endpoints.ToList().Count; - if (endpointCnt == 1) - { - return endpoints; - } - if (endpointCnt == 0) - { - throw new ArgumentException("Client invocation is not sent because no endpoint is returned from the endpoint router."); - } - throw new NotSupportedException("Client invocation to wait for multiple endpoints' results is not supported yet."); + return endpoints; } - - internal static class Log + if (endpointCnt == 0) { - private static readonly Action _startingConnection = - LoggerMessage.Define(LogLevel.Debug, new EventId(1, "StartingConnection"), "Staring connections for endpoint {endpoint}."); + throw new ArgumentException("Client invocation is not sent because no endpoint is returned from the endpoint router."); + } + throw new NotSupportedException("Client invocation to wait for multiple endpoints' results is not supported yet."); + } - private static readonly Action _stoppingConnection = - LoggerMessage.Define(LogLevel.Debug, new EventId(2, "StoppingConnection"), "Stopping connections for endpoint {endpoint}."); + internal static class Log + { + private static readonly Action _startingConnection = + LoggerMessage.Define(LogLevel.Debug, new EventId(1, "StartingConnection"), "Staring connections for endpoint {endpoint}."); - private static readonly Action _endpointNotExists = - LoggerMessage.Define(LogLevel.Error, new EventId(3, "EndpointNotExists"), "Endpoint {endpoint} from the router does not exists."); - private static readonly Action _failedStartingConnectionForNewEndpoint = - LoggerMessage.Define(LogLevel.Error, new EventId(7, "FailedStartingConnectionForNewEndpoint"), "Fail to create and start server connection for new endpoint {endpoint}."); + private static readonly Action _stoppingConnection = + LoggerMessage.Define(LogLevel.Debug, new EventId(2, "StoppingConnection"), "Stopping connections for endpoint {endpoint}."); - private static readonly Action _timeoutWaitingForAddingEndpoint = - LoggerMessage.Define(LogLevel.Error, new EventId(8, "TimeoutWaitingForAddingEndpoint"), "Timeout waiting for add a new endpoint {endpoint} in {timeoutSecond} seconds. Check if app configurations are consistant and restart app server."); + private static readonly Action _endpointNotExists = + LoggerMessage.Define(LogLevel.Error, new EventId(3, "EndpointNotExists"), "Endpoint {endpoint} from the router does not exists."); - private static readonly Action _timeoutWaitingClientsDisconnect = - LoggerMessage.Define(LogLevel.Error, new EventId(9, "TimeoutWaitingClientsDisconnect"), "Timeout waiting for clients disconnect for {endpoint} in {timeoutSecond} seconds."); + private static readonly Action _failedStartingConnectionForNewEndpoint = + LoggerMessage.Define(LogLevel.Error, new EventId(7, "FailedStartingConnectionForNewEndpoint"), "Fail to create and start server connection for new endpoint {endpoint}."); - private static readonly Action _failedRemovingConnectionForEndpoint = - LoggerMessage.Define(LogLevel.Error, new EventId(10, "FailedRemovingConnectionForEndpoint"), "Fail to stop server connections for endpoint {endpoint}."); + private static readonly Action _timeoutWaitingForAddingEndpoint = + LoggerMessage.Define(LogLevel.Error, new EventId(8, "TimeoutWaitingForAddingEndpoint"), "Timeout waiting for add a new endpoint {endpoint} in {timeoutSecond} seconds. Check if app configurations are consistant and restart app server."); - public static void StartingConnection(ILogger logger, string endpoint) - { - _startingConnection(logger, endpoint, null); - } + private static readonly Action _timeoutWaitingClientsDisconnect = + LoggerMessage.Define(LogLevel.Error, new EventId(9, "TimeoutWaitingClientsDisconnect"), "Timeout waiting for clients disconnect for {endpoint} in {timeoutSecond} seconds."); - public static void StoppingConnection(ILogger logger, string endpoint) - { - _stoppingConnection(logger, endpoint, null); - } + private static readonly Action _failedRemovingConnectionForEndpoint = + LoggerMessage.Define(LogLevel.Error, new EventId(10, "FailedRemovingConnectionForEndpoint"), "Fail to stop server connections for endpoint {endpoint}."); - public static void EndpointNotExists(ILogger logger, string endpoint) - { - _endpointNotExists(logger, endpoint, null); - } + public static void StartingConnection(ILogger logger, string endpoint) + { + _startingConnection(logger, endpoint, null); + } - public static void FailedStartingConnectionForNewEndpoint(ILogger logger, string endpoint, Exception ex) - { - _failedStartingConnectionForNewEndpoint(logger, endpoint, ex); - } + public static void StoppingConnection(ILogger logger, string endpoint) + { + _stoppingConnection(logger, endpoint, null); + } - public static void TimeoutWaitingForAddingEndpoint(ILogger logger, string endpoint, int timeoutSecond) - { - _timeoutWaitingForAddingEndpoint(logger, endpoint, timeoutSecond, null); - } + public static void EndpointNotExists(ILogger logger, string endpoint) + { + _endpointNotExists(logger, endpoint, null); + } - public static void TimeoutWaitingClientsDisconnect(ILogger logger, string endpoint, int timeoutSecond) - { - _timeoutWaitingClientsDisconnect(logger, endpoint, timeoutSecond, null); - } + public static void FailedStartingConnectionForNewEndpoint(ILogger logger, string endpoint, Exception ex) + { + _failedStartingConnectionForNewEndpoint(logger, endpoint, ex); + } - public static void FailedRemovingConnectionForEndpoint(ILogger logger, string endpoint, Exception ex) - { - _failedRemovingConnectionForEndpoint(logger, endpoint, ex); - } + public static void TimeoutWaitingForAddingEndpoint(ILogger logger, string endpoint, int timeoutSecond) + { + _timeoutWaitingForAddingEndpoint(logger, endpoint, timeoutSecond, null); + } + + public static void TimeoutWaitingClientsDisconnect(ILogger logger, string endpoint, int timeoutSecond) + { + _timeoutWaitingClientsDisconnect(logger, endpoint, timeoutSecond, null); + } + + public static void FailedRemovingConnectionForEndpoint(ILogger logger, string endpoint, Exception ex) + { + _failedRemovingConnectionForEndpoint(logger, endpoint, ex); } } } \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServerConnectionType.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServerConnectionType.cs index c5ee7be5a..75cb614d7 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServerConnectionType.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServerConnectionType.cs @@ -1,21 +1,20 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +internal enum ServiceConnectionType { - internal enum ServiceConnectionType - { - /// - /// 0, Default, it can carry clients, service runtime should always accept this kind of connection - /// - Default = 0, - /// - /// 1, OnDemand, creating when service requested more connections, it can carry clients, but it may be rejected by service runtime. - /// - OnDemand = 1, - /// - /// 2, Weak, it can not carry clients, but it can send message - /// - Weak = 2, - } + /// + /// 0, Default, it can carry clients, service runtime should always accept this kind of connection + /// + Default = 0, + /// + /// 1, OnDemand, creating when service requested more connections, it can carry clients, but it may be rejected by service runtime. + /// + OnDemand = 1, + /// + /// 2, Weak, it can not carry clients, but it can send message + /// + Weak = 2, } diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerFactory.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerFactory.cs index 3523dab68..1f87e2bd5 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerFactory.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerFactory.cs @@ -2,36 +2,44 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; -using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +internal class ServiceConnectionContainerFactory : IServiceConnectionContainerFactory { - internal class ServiceConnectionContainerFactory : IServiceConnectionContainerFactory + private readonly IServiceEndpointOptions _options; + + private readonly ILoggerFactory _loggerFactory; + + private readonly IServiceEndpointManager _serviceEndpointManager; + + private readonly IMessageRouter _router; + + private readonly IServiceConnectionFactory _serviceConnectionFactory; + + public ServiceConnectionContainerFactory(IServiceConnectionFactory serviceConnectionFactory, + IServiceEndpointManager serviceEndpointManager, + IMessageRouter router, + IServiceEndpointOptions options, + ILoggerFactory loggerFactory) + { + _serviceConnectionFactory = serviceConnectionFactory; + _serviceEndpointManager = serviceEndpointManager ?? throw new ArgumentNullException(nameof(serviceEndpointManager)); + _router = router ?? throw new ArgumentNullException(nameof(router)); + _options = options; + _loggerFactory = loggerFactory; + } + + public IServiceConnectionContainer Create(string hub, TimeSpan? serviceScaleTimeout = null) { - private readonly IServiceEndpointOptions _options; - private readonly ILoggerFactory _loggerFactory; - private readonly IServiceEndpointManager _serviceEndpointManager; - private readonly IMessageRouter _router; - private readonly IServiceConnectionFactory _serviceConnectionFactory; - - public ServiceConnectionContainerFactory( - IServiceConnectionFactory serviceConnectionFactory, - IServiceEndpointManager serviceEndpointManager, - IMessageRouter router, - IServiceEndpointOptions options, - ILoggerFactory loggerFactory) - { - _serviceConnectionFactory = serviceConnectionFactory; - _serviceEndpointManager = serviceEndpointManager ?? throw new ArgumentNullException(nameof(serviceEndpointManager)); - _router = router ?? throw new ArgumentNullException(nameof(router)); - _options = options; - _loggerFactory = loggerFactory; - } - - public IServiceConnectionContainer Create(string hub, TimeSpan? serviceScaleTimeout = null) - { - return new MultiEndpointServiceConnectionContainer(_serviceConnectionFactory, hub, _options.InitialHubServerConnectionCount, _options.MaxHubServerConnectionCount, _serviceEndpointManager, _router, _loggerFactory, serviceScaleTimeout); - } + return new MultiEndpointServiceConnectionContainer(_serviceConnectionFactory, + hub, + _options.InitialHubServerConnectionCount, + _options.MaxHubServerConnectionCount, + _serviceEndpointManager, + _router, + _loggerFactory, + serviceScaleTimeout); } } diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceDiagnosticLogsContext.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceDiagnosticLogsContext.cs index 3fc47c4d1..e7a351eac 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceDiagnosticLogsContext.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceDiagnosticLogsContext.cs @@ -1,10 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +internal class ServiceDiagnosticLogsContext { - internal class ServiceDiagnosticLogsContext - { - public bool EnableMessageLog { get; set; } = false; - } + public bool EnableMessageLog { get; set; } = false; } diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StatusChange.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StatusChange.cs index 982e43db0..912a08cb6 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StatusChange.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StatusChange.cs @@ -1,17 +1,17 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +internal class StatusChange { - internal class StatusChange - { - public StatusChange(ServiceConnectionStatus oldStatus, ServiceConnectionStatus newStatus) - { - OldStatus = oldStatus; - NewStatus = newStatus; - } + public ServiceConnectionStatus OldStatus { get; } - public ServiceConnectionStatus OldStatus { get; } - public ServiceConnectionStatus NewStatus { get; } + public ServiceConnectionStatus NewStatus { get; } + + public StatusChange(ServiceConnectionStatus oldStatus, ServiceConnectionStatus newStatus) + { + OldStatus = oldStatus; + NewStatus = newStatus; } } diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StrongServiceConnectionContainer.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StrongServiceConnectionContainer.cs index c398d6b0c..31cf40784 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StrongServiceConnectionContainer.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StrongServiceConnectionContainer.cs @@ -5,32 +5,31 @@ using Microsoft.Azure.SignalR.Protocol; using Microsoft.Extensions.Logging; -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +internal class StrongServiceConnectionContainer : ServiceConnectionContainerBase { - internal class StrongServiceConnectionContainer : ServiceConnectionContainerBase - { - private readonly int? _maxConnectionCount; + private readonly int? _maxConnectionCount; - public StrongServiceConnectionContainer( - IServiceConnectionFactory serviceConnectionFactory, - int fixedConnectionCount, - int? maxConnectionCount, - HubServiceEndpoint endpoint, - ILogger logger) : base(serviceConnectionFactory, fixedConnectionCount, endpoint, logger: logger) - { - _maxConnectionCount = maxConnectionCount.HasValue ? (maxConnectionCount.Value > fixedConnectionCount ? maxConnectionCount.Value : fixedConnectionCount) : null; - } + public StrongServiceConnectionContainer( + IServiceConnectionFactory serviceConnectionFactory, + int fixedConnectionCount, + int? maxConnectionCount, + HubServiceEndpoint endpoint, + ILogger logger) : base(serviceConnectionFactory, fixedConnectionCount, endpoint, logger: logger) + { + _maxConnectionCount = maxConnectionCount.HasValue ? (maxConnectionCount.Value > fixedConnectionCount ? maxConnectionCount.Value : fixedConnectionCount) : null; + } - public override async Task HandlePingAsync(PingMessage pingMessage) + public override async Task HandlePingAsync(PingMessage pingMessage) + { + await base.HandlePingAsync(pingMessage); + if (RuntimeServicePingMessage.TryGetRebalance(pingMessage, out var target) && !string.IsNullOrEmpty(target) + && (_maxConnectionCount == null || ServiceConnections.Count < _maxConnectionCount)) { - await base.HandlePingAsync(pingMessage); - if (RuntimeServicePingMessage.TryGetRebalance(pingMessage, out var target) && !string.IsNullOrEmpty(target) - && (_maxConnectionCount == null || ServiceConnections.Count < _maxConnectionCount)) - { - var connection = CreateServiceConnectionCore(ServiceConnectionType.OnDemand); - AddOnDemandConnection(connection); - await StartCoreAsync(connection, target); - } + var connection = CreateServiceConnectionCore(ServiceConnectionType.OnDemand); + AddOnDemandConnection(connection); + await StartCoreAsync(connection, target); } } } diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/WeakServiceConnectionContainer.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/WeakServiceConnectionContainer.cs index 20a52e5f9..5701f3016 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/WeakServiceConnectionContainer.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/WeakServiceConnectionContainer.cs @@ -6,32 +6,31 @@ using System.Threading.Tasks; using Microsoft.Extensions.Logging; -namespace Microsoft.Azure.SignalR.Common +namespace Microsoft.Azure.SignalR; + +internal class WeakServiceConnectionContainer : ServiceConnectionContainerBase { - internal class WeakServiceConnectionContainer : ServiceConnectionContainerBase + protected override ServiceConnectionType InitialConnectionType => ServiceConnectionType.Weak; + + public WeakServiceConnectionContainer(IServiceConnectionFactory serviceConnectionFactory, + int fixedConnectionCount, HubServiceEndpoint endpoint, ILogger logger) + : base(serviceConnectionFactory, fixedConnectionCount, endpoint, logger: logger) { - protected override ServiceConnectionType InitialConnectionType => ServiceConnectionType.Weak; + } - public WeakServiceConnectionContainer(IServiceConnectionFactory serviceConnectionFactory, - int fixedConnectionCount, HubServiceEndpoint endpoint, ILogger logger) - : base(serviceConnectionFactory, fixedConnectionCount, endpoint, logger: logger) - { - } + public override Task OfflineAsync(GracefulShutdownMode mode, CancellationToken token) + { + return Task.CompletedTask; + } - public override Task OfflineAsync(GracefulShutdownMode mode, CancellationToken token) - { - return Task.CompletedTask; - } + private static class Log + { + private static readonly Action _ignoreSendingMessageToInactiveEndpoint = + LoggerMessage.Define(LogLevel.Debug, new EventId(1, "IgnoreSendingMessageToInactiveEndpoint"), "Message {type} sending to {endpoint} for hub {hub} is ignored because the endpoint is inactive."); - private static class Log + public static void IgnoreSendingMessageToInactiveEndpoint(ILogger logger, Type messageType, HubServiceEndpoint endpoint) { - private static readonly Action _ignoreSendingMessageToInactiveEndpoint = - LoggerMessage.Define(LogLevel.Debug, new EventId(1, "IgnoreSendingMessageToInactiveEndpoint"), "Message {type} sending to {endpoint} for hub {hub} is ignored because the endpoint is inactive."); - - public static void IgnoreSendingMessageToInactiveEndpoint(ILogger logger, Type messageType, HubServiceEndpoint endpoint) - { - _ignoreSendingMessageToInactiveEndpoint(logger, messageType.Name, endpoint, endpoint.Hub, null); - } + _ignoreSendingMessageToInactiveEndpoint(logger, messageType.Name, endpoint, endpoint.Hub, null); } } } diff --git a/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MockServiceConnectionFactory.cs b/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MockServiceConnectionFactory.cs index ab8f566ac..886106890 100644 --- a/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MockServiceConnectionFactory.cs +++ b/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MockServiceConnectionFactory.cs @@ -7,41 +7,41 @@ using Microsoft.Azure.SignalR.Protocol; using Microsoft.Extensions.Logging; -namespace Microsoft.Azure.SignalR.IntegrationTests.Infrastructure +namespace Microsoft.Azure.SignalR.IntegrationTests.Infrastructure; + +internal class MockServiceConnectionFactory : ServiceConnectionFactory { - internal class MockServiceConnectionFactory : ServiceConnectionFactory + private IMockService _mockService; + + public MockServiceConnectionFactory( + IMockService mockService, + IServiceProtocol serviceProtocol, + IClientConnectionManager clientConnectionManager, + IConnectionFactory connectionFactory, + ILoggerFactory loggerFactory, + ConnectionDelegate connectionDelegate, + IClientConnectionFactory clientConnectionFactory, + IClientInvocationManager clientInvocationManager, + IServerNameProvider nameProvider, + IHubProtocolResolver hubProtocolResolver) + : base( + serviceProtocol, + clientConnectionManager, + connectionFactory, + loggerFactory, + connectionDelegate, + clientConnectionFactory, + nameProvider, + null, + clientInvocationManager, + hubProtocolResolver) { - private IMockService _mockService; - public MockServiceConnectionFactory( - IMockService mockService, - IServiceProtocol serviceProtocol, - IClientConnectionManager clientConnectionManager, - IConnectionFactory connectionFactory, - ILoggerFactory loggerFactory, - ConnectionDelegate connectionDelegate, - IClientConnectionFactory clientConnectionFactory, - IClientInvocationManager clientInvocationManager, - IServerNameProvider nameProvider, - IHubProtocolResolver hubProtocolResolver) - : base( - serviceProtocol, - clientConnectionManager, - connectionFactory, - loggerFactory, - connectionDelegate, - clientConnectionFactory, - nameProvider, - null, - clientInvocationManager, - hubProtocolResolver) - { - _mockService = mockService; - } + _mockService = mockService; + } - public override IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, AckHandler ackHandler, ServiceConnectionType type) - { - var serviceConnection = base.Create(endpoint, serviceMessageHandler, ackHandler, type); - return new MockServiceConnection(_mockService, serviceConnection); - } + public override IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, AckHandler ackHandler, ServiceConnectionType type) + { + var serviceConnection = base.Create(endpoint, serviceMessageHandler, ackHandler, type); + return new MockServiceConnection(_mockService, serviceConnection); } }