From 64b9805ad97f3738d0d234944a7fbdbd2615b7af Mon Sep 17 00:00:00 2001 From: Terence Fan Date: Tue, 3 Dec 2024 13:04:05 +0800 Subject: [PATCH] Add #nullable for Endpoint related classes (#2101) * Add #nullable for Endpoint classes * update * update * update * update --- .../ServiceEndpointProvider.cs | 149 +- .../Endpoints/HubServiceEndpoint.cs | 12 +- .../Endpoints/ServiceEndpoint.cs | 97 +- .../Utilities/ConnectionStringParser.cs | 51 +- .../Utilities/ParsedConnectionString.cs | 16 +- .../ServiceProtocol.cs | 2487 ++++++++--------- .../DefaultServiceEndpointGenerator.cs | 92 +- .../IServiceEndpointGenerator.cs | 18 +- .../Auth/ConnectionStringParserTests.cs | 14 +- .../ServiceEndpointFacts.cs | 698 +++-- .../MockServiceMessageOrderTestParams.cs | 31 +- .../TestClasses/TestServiceEndpoint.cs | 6 +- .../ClientConnectionContextFacts.cs | 2 +- .../ServiceMessageTests.cs | 4 +- 14 files changed, 1889 insertions(+), 1788 deletions(-) diff --git a/src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointProvider.cs b/src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointProvider.cs index 615a92fa3..757aff914 100644 --- a/src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointProvider.cs +++ b/src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointProvider.cs @@ -8,112 +8,111 @@ using System.Text; using System.Threading.Tasks; -namespace Microsoft.Azure.SignalR.AspNet +namespace Microsoft.Azure.SignalR.AspNet; + +internal class ServiceEndpointProvider : IServiceEndpointProvider { - internal class ServiceEndpointProvider : IServiceEndpointProvider - { - public static readonly string ConnectionStringNotFound = - "No connection string was specified. " + - $"Please specify a configuration entry for {Constants.Keys.ConnectionStringDefaultKey}, " + - "or explicitly pass one using IAppBuilder.RunAzureSignalR(connectionString) in Startup.ConfigureServices."; + public static readonly string ConnectionStringNotFound = + "No connection string was specified. " + + $"Please specify a configuration entry for {Constants.Keys.ConnectionStringDefaultKey}, " + + "or explicitly pass one using IAppBuilder.RunAzureSignalR(connectionString) in Startup.ConfigureServices."; - private const string ClientPath = "aspnetclient"; + private const string ClientPath = "aspnetclient"; - private const string ServerPath = "aspnetserver"; + private const string ServerPath = "aspnetserver"; - private readonly string _audienceBaseUrl; + private readonly string _audienceBaseUrl; - private readonly string _clientEndpoint; + private readonly string _clientEndpoint; - private readonly string _serverEndpoint; + private readonly string _serverEndpoint; - private readonly IAccessKey _accessKey; + private readonly IAccessKey _accessKey; - private readonly string _appName; + private readonly string _appName; - private readonly TimeSpan _accessTokenLifetime; + private readonly TimeSpan _accessTokenLifetime; - private readonly AccessTokenAlgorithm _algorithm; + private readonly AccessTokenAlgorithm _algorithm; - public IWebProxy Proxy { get; } + public IWebProxy Proxy { get; } - public ServiceEndpointProvider(ServiceEndpoint endpoint, ServiceOptions options) - { - _accessTokenLifetime = options.AccessTokenLifetime; + public ServiceEndpointProvider(ServiceEndpoint endpoint, ServiceOptions options) + { + _accessTokenLifetime = options.AccessTokenLifetime; - // Version is ignored for aspnet signalr case - _audienceBaseUrl = endpoint.AudienceBaseUrl; - _clientEndpoint = endpoint.ClientEndpoint.AbsoluteUri; - _serverEndpoint = endpoint.ServerEndpoint.AbsoluteUri; - _accessKey = endpoint.AccessKey; - _appName = options.ApplicationName; - _algorithm = options.AccessTokenAlgorithm; + // Version is ignored for aspnet signalr case + _audienceBaseUrl = endpoint.AudienceBaseUrl; + _clientEndpoint = endpoint.ClientEndpoint.AbsoluteUri; + _serverEndpoint = endpoint.ServerEndpoint.AbsoluteUri; + _accessKey = endpoint.AccessKey; + _appName = options.ApplicationName; + _algorithm = options.AccessTokenAlgorithm; - Proxy = options.Proxy; - } + Proxy = options.Proxy; + } - public Task GenerateClientAccessTokenAsync(string hubName = null, IEnumerable claims = null, TimeSpan? lifetime = null) - { - var audience = $"{_audienceBaseUrl}{ClientPath}"; + public Task GenerateClientAccessTokenAsync(string hubName = null, IEnumerable claims = null, TimeSpan? lifetime = null) + { + var audience = $"{_audienceBaseUrl}{ClientPath}"; - return _accessKey.GenerateAccessTokenAsync(audience, claims, lifetime ?? _accessTokenLifetime, _algorithm); - } + return _accessKey.GenerateAccessTokenAsync(audience, claims, lifetime ?? _accessTokenLifetime, _algorithm); + } + + public string GetClientEndpoint(string hubName = null, string originalPath = null, string queryString = null) + { + var queryBuilder = new StringBuilder(); - public string GetClientEndpoint(string hubName = null, string originalPath = null, string queryString = null) + if (!string.IsNullOrEmpty(queryString)) { - var queryBuilder = new StringBuilder(); + queryBuilder.Append(queryString); + } - if (!string.IsNullOrEmpty(queryString)) + if (!string.IsNullOrEmpty(originalPath)) + { + if (queryBuilder.Length == 0) { - queryBuilder.Append(queryString); + queryBuilder.Append("?"); } - - if (!string.IsNullOrEmpty(originalPath)) + else { - if (queryBuilder.Length == 0) - { - queryBuilder.Append("?"); - } - else - { - queryBuilder.Append("&"); - } - - queryBuilder - .Append(Constants.QueryParameter.OriginalPath) - .Append("=") - .Append(WebUtility.UrlEncode(originalPath)); + queryBuilder.Append("&"); } - return $"{_clientEndpoint}{ClientPath}{queryBuilder}"; + queryBuilder + .Append(Constants.QueryParameter.OriginalPath) + .Append("=") + .Append(WebUtility.UrlEncode(originalPath)); } - public string GetServerEndpoint(string hubName) + return $"{_clientEndpoint}{ClientPath}{queryBuilder}"; + } + + public string GetServerEndpoint(string hubName) + { + return $"{_serverEndpoint}{ServerPath}/?hub={GetPrefixedHubName(_appName, hubName)}"; + } + + public IAccessTokenProvider GetServerAccessTokenProvider(string hubName, string serverId) + { + if (_accessKey is MicrosoftEntraAccessKey key) { - return $"{_serverEndpoint}{ServerPath}/?hub={GetPrefixedHubName(_appName, hubName)}"; + return new MicrosoftEntraTokenProvider(key); } - - public IAccessTokenProvider GetServerAccessTokenProvider(string hubName, string serverId) + else if (_accessKey is AccessKey key2) { - if (_accessKey is MicrosoftEntraAccessKey key) - { - return new MicrosoftEntraTokenProvider(key); - } - else if (_accessKey is AccessKey key2) - { - var audience = $"{_audienceBaseUrl}{ServerPath}/?hub={GetPrefixedHubName(_appName, hubName)}"; - var claims = serverId != null ? new[] { new Claim(ClaimTypes.NameIdentifier, serverId) } : null; - return new LocalTokenProvider(key2, audience, claims, _algorithm, _accessTokenLifetime); - } - else - { - throw new ArgumentNullException(nameof(AccessKey)); - } + var audience = $"{_audienceBaseUrl}{ServerPath}/?hub={GetPrefixedHubName(_appName, hubName)}"; + var claims = serverId != null ? new[] { new Claim(ClaimTypes.NameIdentifier, serverId) } : null; + return new LocalTokenProvider(key2, audience, claims, _algorithm, _accessTokenLifetime); } - - private string GetPrefixedHubName(string applicationName, string hubName) + else { - return string.IsNullOrEmpty(applicationName) ? hubName.ToLower() : $"{applicationName.ToLower()}_{hubName.ToLower()}"; + throw new ArgumentNullException(nameof(AccessKey)); } } + + private string GetPrefixedHubName(string applicationName, string hubName) + { + return string.IsNullOrEmpty(applicationName) ? hubName.ToLower() : $"{applicationName.ToLower()}_{hubName.ToLower()}"; + } } diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/HubServiceEndpoint.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/HubServiceEndpoint.cs index 12a8375cc..b886bbd1a 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/HubServiceEndpoint.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/HubServiceEndpoint.cs @@ -7,15 +7,15 @@ namespace Microsoft.Azure.SignalR; +#nullable enable + internal class HubServiceEndpoint : ServiceEndpoint { private static long s_currentIndex; private readonly ServiceEndpoint _endpoint; - private readonly long _uniqueIndex; - - private TaskCompletionSource _scaleTcs; + private TaskCompletionSource? _scaleTcs; public string Hub { get; } @@ -23,14 +23,14 @@ internal class HubServiceEndpoint : ServiceEndpoint public IServiceEndpointProvider Provider { get; } - public IServiceConnectionContainer ConnectionContainer { get; set; } + public IServiceConnectionContainer? ConnectionContainer { get; set; } /// /// Task waiting for HubServiceEndpoint turn ready when live add/remove endpoint /// public Task ScaleTask => _scaleTcs?.Task ?? Task.CompletedTask; - public long UniqueIndex => _uniqueIndex; + public long UniqueIndex { get; } // Value here is not accurate. internal override bool PendingReload => throw new NotSupportedException(); @@ -43,7 +43,7 @@ public HubServiceEndpoint(string hub, Provider = provider; _endpoint = endpoint; _scaleTcs = endpoint.PendingReload ? new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously) : null; - _uniqueIndex = Interlocked.Increment(ref s_currentIndex); + UniqueIndex = Interlocked.Increment(ref s_currentIndex); } public void CompleteScale() diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs index 3e98f829f..6292b42f2 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs @@ -7,21 +7,23 @@ namespace Microsoft.Azure.SignalR; +#nullable enable + public class ServiceEndpoint { private readonly Uri _serviceEndpoint; - private readonly Uri _serverEndpoint; + private readonly Uri? _serverEndpoint; - private readonly Uri _clientEndpoint; + private readonly Uri? _clientEndpoint; - private readonly TokenCredential _tokenCredential; + private readonly TokenCredential? _tokenCredential; private readonly object _lock = new object(); - private volatile IAccessKey _accessKey; + private volatile IAccessKey? _accessKey; - public string ConnectionString { get; } + public string? ConnectionString { get; } public EndpointType EndpointType { get; } = EndpointType.Primary; @@ -32,11 +34,7 @@ public class ServiceEndpoint /// public Uri ServerEndpoint { - get => _serverEndpoint ?? _serviceEndpoint; init - { - CheckScheme(value); - _serverEndpoint = value; - } + get => _serverEndpoint ?? _serviceEndpoint; init => _serverEndpoint = CheckScheme(value); } /// @@ -44,11 +42,7 @@ public Uri ServerEndpoint /// public Uri ClientEndpoint { - get => _clientEndpoint ?? _serviceEndpoint; init - { - CheckScheme(value); - _clientEndpoint = value; - } + get => _clientEndpoint ?? _serviceEndpoint; init => _clientEndpoint = CheckScheme(value); } /// @@ -76,7 +70,7 @@ public Uri ClientEndpoint internal string AudienceBaseUrl { get; } - internal string Version { get; } + internal string? Version { get; } internal IAccessKey AccessKey { @@ -84,6 +78,10 @@ internal IAccessKey AccessKey { if (_accessKey is null) { + if (_tokenCredential is null) + { + throw new ArgumentNullException(nameof(_tokenCredential)); + } lock (_lock) { _accessKey ??= new MicrosoftEntraAccessKey(_serviceEndpoint, _tokenCredential, ServerEndpoint); @@ -91,7 +89,6 @@ internal IAccessKey AccessKey } return _accessKey; } - private init => _accessKey = value; } // Flag to indicate an updaing endpoint needs staging @@ -123,13 +120,14 @@ public ServiceEndpoint(string connectionString, EndpointType type = EndpointType ConnectionString = connectionString; var result = ConnectionStringParser.Parse(connectionString); - AccessKey = result.AccessKey; - _serviceEndpoint = result.Endpoint; - ClientEndpoint = result.ClientEndpoint; - ServerEndpoint = result.ServerEndpoint; EndpointType = type; Name = name; + _accessKey = result.AccessKey; + _serviceEndpoint = result.Endpoint; + _clientEndpoint = result.ClientEndpoint; + _serverEndpoint = result.ServerEndpoint; + Endpoint = BuildEndpointString(_serviceEndpoint); AudienceBaseUrl = BuildAudienceBaseUrlEndWithSlash(_serviceEndpoint); } @@ -159,21 +157,20 @@ public ServiceEndpoint(Uri endpoint, TokenCredential credential, EndpointType endpointType = EndpointType.Primary, string name = "", - Uri serverEndpoint = null, - Uri clientEndpoint = null) + Uri? serverEndpoint = null, + Uri? clientEndpoint = null) { - _serviceEndpoint = endpoint ?? throw new ArgumentNullException(nameof(endpoint)); - CheckScheme(endpoint); - _tokenCredential = credential ?? throw new ArgumentNullException(nameof(credential)); EndpointType = endpointType; Name = name; + _serviceEndpoint = CheckScheme(endpoint); + _serverEndpoint = serverEndpoint == null ? serverEndpoint : CheckScheme(serverEndpoint); + _clientEndpoint = clientEndpoint == null ? clientEndpoint : CheckScheme(clientEndpoint); + AudienceBaseUrl = BuildAudienceBaseUrlEndWithSlash(_serviceEndpoint); Endpoint = BuildEndpointString(_serviceEndpoint); - ServerEndpoint = serverEndpoint; - ClientEndpoint = clientEndpoint; } /// @@ -182,19 +179,18 @@ public ServiceEndpoint(Uri endpoint, /// public ServiceEndpoint(ServiceEndpoint other) { - if (other != null) - { - ConnectionString = other.ConnectionString; - EndpointType = other.EndpointType; - Name = other.Name; - Version = other.Version; - AccessKey = other.AccessKey; - Endpoint = other.Endpoint; - ClientEndpoint = other.ClientEndpoint; - ServerEndpoint = other.ServerEndpoint; - AudienceBaseUrl = other.AudienceBaseUrl; - _serviceEndpoint = other._serviceEndpoint; - } + ConnectionString = other.ConnectionString; + EndpointType = other.EndpointType; + Name = other.Name; + Version = other.Version; + Endpoint = other.Endpoint; + AudienceBaseUrl = other.AudienceBaseUrl; + + _accessKey = other._accessKey; + _tokenCredential = other._tokenCredential; + _serviceEndpoint = other._serviceEndpoint; + _clientEndpoint = other._clientEndpoint; + _serverEndpoint = other._serverEndpoint; } public override string ToString() @@ -211,7 +207,7 @@ public override int GetHashCode() return (Endpoint, EndpointType, Name, ClientEndpoint, ServerEndpoint).GetHashCode(); } - public override bool Equals(object obj) + public override bool Equals(object? obj) { if (obj == null) { @@ -231,14 +227,14 @@ public override bool Equals(object obj) return (Name, Endpoint, EndpointType, ClientEndpoint, ServerEndpoint) == (that.Name, that.Endpoint, that.EndpointType, that.ClientEndpoint, that.ServerEndpoint); } - private static string BuildAudienceBaseUrlEndWithSlash(Uri uri) + internal static string BuildEndpointString(Uri uri) { - return $"{uri.Scheme}://{uri.Host}/"; + return new Uri($"{uri.Scheme}://{uri.Host}:{uri.Port}").AbsoluteUri.TrimEnd('/'); } - private static string BuildEndpointString(Uri uri) + private static string BuildAudienceBaseUrlEndWithSlash(Uri uri) { - return new Uri($"{uri.Scheme}://{uri.Host}:{uri.Port}").AbsoluteUri.TrimEnd('/'); + return $"{uri.Scheme}://{uri.Host}/"; } private static (string, EndpointType) Parse(string nameWithEndpointType) @@ -263,11 +259,10 @@ private static (string, EndpointType) Parse(string nameWithEndpointType) } } - private static void CheckScheme(Uri uri) + private static Uri CheckScheme(Uri uri) { - if (uri != null && uri.Scheme != Uri.UriSchemeHttp && uri.Scheme != Uri.UriSchemeHttps) - { - throw new ArgumentException("Endpoint scheme must be 'http://' or 'https://'"); - } + return uri.Scheme != Uri.UriSchemeHttp && uri.Scheme != Uri.UriSchemeHttps + ? throw new ArgumentException("Endpoint scheme must be 'http://' or 'https://'") + : uri; } } diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs index d01274fcb..f4668c479 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs @@ -8,6 +8,8 @@ namespace Microsoft.Azure.SignalR; +#nullable enable + internal static class ConnectionStringParser { private const string AccessKeyProperty = "accesskey"; @@ -16,15 +18,15 @@ internal static class ConnectionStringParser private const string ClientCertProperty = "clientCert"; - private const string ClientEndpointProperty = "clientEndpoint"; - private const string ClientIdProperty = "clientId"; private const string ClientSecretProperty = "clientSecret"; private const string EndpointProperty = "endpoint"; - private const string ServerEndpointProperty = "ServerEndpoint"; + private const string ClientEndpointProperty = "clientEndpoint"; + + private const string ServerEndpointProperty = "serverEndpoint"; private const string InvalidVersionValueFormat = "Version {0} is not supported."; @@ -47,11 +49,13 @@ internal static class ConnectionStringParser private const string VersionProperty = "version"; + private static readonly string InvalidEndpointProperty = $"Invalid value for {EndpointProperty} property, it must be a valid URI."; + private static readonly string InvalidClientEndpointProperty = $"Invalid value for {ClientEndpointProperty} property, it must be a valid URI."; - private static readonly string InvalidEndpointProperty = $"Invalid value for {EndpointProperty} property, it must be a valid URI."; + private static readonly string InvalidServerEndpointProperty = $"Invalid value for {ServerEndpointProperty} property, it must be a valid URI."; - private static readonly string InvalidPortValue = $"Invalid value for {PortProperty} property, it must be an positive integer between (0, 65536)"; + private static readonly string InvalidPortValue = $"Invalid value for {PortProperty} property, it must be an positive integer between (0, 65536)."; private static readonly char[] KeyValueSeparator = { '=' }; @@ -65,10 +69,10 @@ internal static class ConnectionStringParser $"Connection string missing required properties {ClientSecretProperty} or {ClientCertProperty}."; private static readonly string MissingEndpointProperty = - $"Connection string missing required properties {EndpointProperty}."; + $"Connection string missing required properties {EndpointProperty}."; private static readonly string MissingTenantIdProperty = - $"Connection string missing required properties {TenantIdProperty}."; + $"Connection string missing required properties {TenantIdProperty}."; private static readonly char[] PropertySeparator = { ';' }; @@ -83,14 +87,14 @@ internal static ParsedConnectionString Parse(string connectionString) } endpoint = endpoint.TrimEnd('/'); - if (!TryGetEndpointUri(endpoint, out var endpointUri)) + if (!TryCreateEndpointUri(endpoint, out var endpointUri)) { throw new ArgumentException(InvalidEndpointProperty, nameof(endpoint)); } - var builder = new UriBuilder(endpointUri); + var builder = new UriBuilder(endpointUri!); // parse and validate version. - string version = null; + string? version = null; if (dict.TryGetValue(VersionProperty, out var v)) { if (!Regex.IsMatch(v, ValidVersionRegex)) @@ -113,13 +117,13 @@ internal static ParsedConnectionString Parse(string connectionString) } } - Uri clientEndpointUri = null; - Uri serverEndpointUri = null; + Uri? clientEndpointUri = null; + Uri? serverEndpointUri = null; // parse and validate clientEndpoint. if (dict.TryGetValue(ClientEndpointProperty, out var clientEndpoint)) { - if (!TryGetEndpointUri(clientEndpoint, out clientEndpointUri)) + if (!TryCreateEndpointUri(clientEndpoint, out clientEndpointUri)) { throw new ArgumentException(InvalidClientEndpointProperty, nameof(clientEndpoint)); } @@ -128,9 +132,9 @@ internal static ParsedConnectionString Parse(string connectionString) // parse and validate clientEndpoint. if (dict.TryGetValue(ServerEndpointProperty, out var serverEndpoint)) { - if (!TryGetEndpointUri(serverEndpoint, out serverEndpointUri)) + if (!TryCreateEndpointUri(serverEndpoint, out serverEndpointUri)) { - throw new ArgumentException($"{ServerEndpointProperty} property in connection string is not a valid URI: {serverEndpoint}."); + throw new ArgumentException(InvalidServerEndpointProperty, nameof(serverEndpoint)); } } @@ -145,9 +149,8 @@ internal static ParsedConnectionString Parse(string connectionString) _ => BuildAccessKey(builder.Uri, dict), }; - return new ParsedConnectionString() + return new ParsedConnectionString(builder.Uri) { - Endpoint = builder.Uri, ClientEndpoint = clientEndpointUri, AccessKey = accessKey, Version = version, @@ -155,13 +158,13 @@ internal static ParsedConnectionString Parse(string connectionString) }; } - internal static bool TryGetEndpointUri(string endpoint, out Uri uriResult) + private static bool TryCreateEndpointUri(string endpoint, out Uri? uriResult) { - return Uri.TryCreate(endpoint, UriKind.Absolute, out uriResult) && - (uriResult.Scheme == Uri.UriSchemeHttp || uriResult.Scheme == Uri.UriSchemeHttps); + return Uri.TryCreate(endpoint, UriKind.Absolute, out uriResult) + && (uriResult.Scheme == Uri.UriSchemeHttp || uriResult.Scheme == Uri.UriSchemeHttps); } - private static IAccessKey BuildAzureADAccessKey(Uri uri, Uri serverEndpointUri, Dictionary dict) + private static IAccessKey BuildAzureADAccessKey(Uri uri, Uri? serverEndpointUri, Dictionary dict) { if (dict.TryGetValue(ClientIdProperty, out var clientId)) { @@ -198,12 +201,12 @@ private static IAccessKey BuildAccessKey(Uri uri, Dictionary dic : throw new ArgumentException(MissingAccessKeyProperty, AccessKeyProperty); } - private static IAccessKey BuildAzureAccessKey(Uri uri, Uri serverEndpointUri, Dictionary dict) + private static IAccessKey BuildAzureAccessKey(Uri uri, Uri? serverEndpointUri, Dictionary dict) { return new MicrosoftEntraAccessKey(uri, new DefaultAzureCredential(), serverEndpointUri); } - private static IAccessKey BuildAzureAppAccessKey(Uri uri, Uri serverEndpointUri, Dictionary dict) + private static IAccessKey BuildAzureAppAccessKey(Uri uri, Uri? serverEndpointUri, Dictionary dict) { if (!dict.TryGetValue(ClientIdProperty, out var clientId)) { @@ -226,7 +229,7 @@ private static IAccessKey BuildAzureAppAccessKey(Uri uri, Uri serverEndpointUri, throw new ArgumentException(MissingClientSecretProperty, ClientSecretProperty); } - private static IAccessKey BuildAzureMsiAccessKey(Uri uri, Uri serverEndpointUri, Dictionary dict) + private static IAccessKey BuildAzureMsiAccessKey(Uri uri, Uri? serverEndpointUri, Dictionary dict) { return dict.TryGetValue(ClientIdProperty, out var clientId) ? new MicrosoftEntraAccessKey(uri, new ManagedIdentityCredential(clientId), serverEndpointUri) diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/ParsedConnectionString.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/ParsedConnectionString.cs index a027076b0..b84f6685c 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/ParsedConnectionString.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/ParsedConnectionString.cs @@ -5,16 +5,22 @@ namespace Microsoft.Azure.SignalR; +#nullable enable + internal class ParsedConnectionString { - internal IAccessKey AccessKey { get; set; } + internal Uri Endpoint { get; } - internal Uri Endpoint { get; set; } + internal IAccessKey? AccessKey { get; set; } - internal Uri ClientEndpoint { get; set; } + internal Uri? ClientEndpoint { get; set; } - internal Uri ServerEndpoint { get; set; } + internal Uri? ServerEndpoint { get; set; } - internal string Version { get; set; } + internal string? Version { get; set; } + public ParsedConnectionString(Uri endpoint) + { + Endpoint = endpoint; + } } diff --git a/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs b/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs index 0cbdf0427..133a770a3 100644 --- a/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs +++ b/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs @@ -1,1533 +1,1532 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -#nullable enable using System; using System.Buffers; using System.Collections.Generic; -using System.Collections.ObjectModel; using System.Diagnostics; using System.IO; using System.Security.Claims; using MessagePack; using Microsoft.Extensions.Primitives; -namespace Microsoft.Azure.SignalR.Protocol -{ - /// - /// Implements the Azure SignalR Service Protocol. - /// - public class ServiceProtocol : IServiceProtocol - { - private static readonly IDictionary> EmptyReadOnlyMemoryDictionary = new Dictionary>(); +namespace Microsoft.Azure.SignalR.Protocol; - private static readonly IDictionary EmptyStringValuesDictionaryIgnoreCase = new Dictionary(StringComparer.OrdinalIgnoreCase); +#nullable enable - private static readonly int ProtocolVersion = 1; +/// +/// Implements the Azure SignalR Service Protocol. +/// +public class ServiceProtocol : IServiceProtocol +{ + private static readonly IDictionary> EmptyReadOnlyMemoryDictionary = new Dictionary>(); - /// - public int Version => ProtocolVersion; + private static readonly IDictionary EmptyStringValuesDictionaryIgnoreCase = new Dictionary(StringComparer.OrdinalIgnoreCase); - /// - public bool TryParseMessage(ref ReadOnlySequence input, out ServiceMessage? message) - { - if (!BinaryMessageParser.TryParseMessage(ref input, out var payload)) - { - message = null; - return false; - } + private static readonly int ProtocolVersion = 1; - var reader = new MessagePackReader(payload); + /// + public int Version => ProtocolVersion; - message = ParseMessage(ref reader); - return true; + /// + public bool TryParseMessage(ref ReadOnlySequence input, out ServiceMessage? message) + { + if (!BinaryMessageParser.TryParseMessage(ref input, out var payload)) + { + message = null; + return false; } - private static ServiceMessage? ParseMessage(ref MessagePackReader reader) - { - var arrayLength = reader.ReadArrayHeader(); + var reader = new MessagePackReader(payload); - var messageType = reader.ReadInt32(); + message = ParseMessage(ref reader); + return true; + } - switch (messageType) - { - case ServiceProtocolConstants.HandshakeRequestType: - return CreateHandshakeRequestMessage(ref reader, arrayLength); - case ServiceProtocolConstants.HandshakeResponseType: - return CreateHandshakeResponseMessage(ref reader, arrayLength); - case ServiceProtocolConstants.AccessKeyRequestType: - return CreateAccessKeyRequestMessage(ref reader, arrayLength); - case ServiceProtocolConstants.AccessKeyResponseType: - return CreateAccessKeyResponseMessage(ref reader, arrayLength); - case ServiceProtocolConstants.PingMessageType: - return CreatePingMessage(ref reader, arrayLength); - case ServiceProtocolConstants.OpenConnectionMessageType: - return CreateOpenConnectionMessage(ref reader, arrayLength); - case ServiceProtocolConstants.CloseConnectionMessageType: - return CreateCloseConnectionMessage(ref reader, arrayLength); - case ServiceProtocolConstants.ConnectionDataMessageType: - return CreateConnectionDataMessage(ref reader, arrayLength); - case ServiceProtocolConstants.ConnectionReconnectMessageType: - return CreateConnectionReconnectMessage(ref reader, arrayLength); - case ServiceProtocolConstants.MultiConnectionDataMessageType: - return CreateMultiConnectionDataMessage(ref reader, arrayLength); - case ServiceProtocolConstants.UserDataMessageType: - return CreateUserDataMessage(ref reader, arrayLength); - case ServiceProtocolConstants.MultiUserDataMessageType: - return CreateMultiUserDataMessage(ref reader, arrayLength); - case ServiceProtocolConstants.BroadcastDataMessageType: - return CreateBroadcastDataMessage(ref reader, arrayLength); - case ServiceProtocolConstants.JoinGroupMessageType: - return CreateJoinGroupMessage(ref reader, arrayLength); - case ServiceProtocolConstants.LeaveGroupMessageType: - return CreateLeaveGroupMessage(ref reader, arrayLength); - case ServiceProtocolConstants.UserJoinGroupMessageType: - return CreateUserJoinGroupMessage(ref reader, arrayLength); - case ServiceProtocolConstants.UserLeaveGroupMessageType: - return CreateUserLeaveGroupMessage(ref reader, arrayLength); - case ServiceProtocolConstants.UserJoinGroupWithAckMessageType: - return CreateUserJoinGroupWithAckMessage(ref reader, arrayLength); - case ServiceProtocolConstants.UserLeaveGroupWithAckMessageType: - return CreateUserLeaveGroupWithAckMessage(ref reader, arrayLength); - case ServiceProtocolConstants.GroupBroadcastDataMessageType: - return CreateGroupBroadcastDataMessage(ref reader, arrayLength); - case ServiceProtocolConstants.MultiGroupBroadcastDataMessageType: - return CreateMultiGroupBroadcastDataMessage(ref reader, arrayLength); - case ServiceProtocolConstants.ServiceErrorMessageType: - return CreateServiceErrorMessage(ref reader); - case ServiceProtocolConstants.ServiceEventMessageType: - return CreateServiceEventMessage(ref reader); - case ServiceProtocolConstants.JoinGroupWithAckMessageType: - return CreateJoinGroupWithAckMessage(ref reader, arrayLength); - case ServiceProtocolConstants.LeaveGroupWithAckMessageType: - return CreateLeaveGroupWithAckMessage(ref reader, arrayLength); - case ServiceProtocolConstants.CheckUserInGroupWithAckMessageType: - return CreateCheckUserInGroupWithAckMessage(ref reader, arrayLength); - case ServiceProtocolConstants.CheckGroupExistenceWithAckMessageType: - return CreateGroupExistenceWithAckMessage(ref reader, arrayLength); - case ServiceProtocolConstants.CheckConnectionExistenceWithAckMessageType: - return CreateCheckConnectionExistenceWithAckMessage(ref reader, arrayLength); - case ServiceProtocolConstants.CheckUserExistenceWithAckMessageType: - return CreateCheckUserExistenceWithAckMessage(ref reader, arrayLength); - case ServiceProtocolConstants.CloseConnectionsWithAckMessageType: + private static ServiceMessage? ParseMessage(ref MessagePackReader reader) + { + var arrayLength = reader.ReadArrayHeader(); + + var messageType = reader.ReadInt32(); + + switch (messageType) + { + case ServiceProtocolConstants.HandshakeRequestType: + return CreateHandshakeRequestMessage(ref reader, arrayLength); + case ServiceProtocolConstants.HandshakeResponseType: + return CreateHandshakeResponseMessage(ref reader, arrayLength); + case ServiceProtocolConstants.AccessKeyRequestType: + return CreateAccessKeyRequestMessage(ref reader, arrayLength); + case ServiceProtocolConstants.AccessKeyResponseType: + return CreateAccessKeyResponseMessage(ref reader, arrayLength); + case ServiceProtocolConstants.PingMessageType: + return CreatePingMessage(ref reader, arrayLength); + case ServiceProtocolConstants.OpenConnectionMessageType: + return CreateOpenConnectionMessage(ref reader, arrayLength); + case ServiceProtocolConstants.CloseConnectionMessageType: + return CreateCloseConnectionMessage(ref reader, arrayLength); + case ServiceProtocolConstants.ConnectionDataMessageType: + return CreateConnectionDataMessage(ref reader, arrayLength); + case ServiceProtocolConstants.ConnectionReconnectMessageType: + return CreateConnectionReconnectMessage(ref reader, arrayLength); + case ServiceProtocolConstants.MultiConnectionDataMessageType: + return CreateMultiConnectionDataMessage(ref reader, arrayLength); + case ServiceProtocolConstants.UserDataMessageType: + return CreateUserDataMessage(ref reader, arrayLength); + case ServiceProtocolConstants.MultiUserDataMessageType: + return CreateMultiUserDataMessage(ref reader, arrayLength); + case ServiceProtocolConstants.BroadcastDataMessageType: + return CreateBroadcastDataMessage(ref reader, arrayLength); + case ServiceProtocolConstants.JoinGroupMessageType: + return CreateJoinGroupMessage(ref reader, arrayLength); + case ServiceProtocolConstants.LeaveGroupMessageType: + return CreateLeaveGroupMessage(ref reader, arrayLength); + case ServiceProtocolConstants.UserJoinGroupMessageType: + return CreateUserJoinGroupMessage(ref reader, arrayLength); + case ServiceProtocolConstants.UserLeaveGroupMessageType: + return CreateUserLeaveGroupMessage(ref reader, arrayLength); + case ServiceProtocolConstants.UserJoinGroupWithAckMessageType: + return CreateUserJoinGroupWithAckMessage(ref reader, arrayLength); + case ServiceProtocolConstants.UserLeaveGroupWithAckMessageType: + return CreateUserLeaveGroupWithAckMessage(ref reader, arrayLength); + case ServiceProtocolConstants.GroupBroadcastDataMessageType: + return CreateGroupBroadcastDataMessage(ref reader, arrayLength); + case ServiceProtocolConstants.MultiGroupBroadcastDataMessageType: + return CreateMultiGroupBroadcastDataMessage(ref reader, arrayLength); + case ServiceProtocolConstants.ServiceErrorMessageType: + return CreateServiceErrorMessage(ref reader); + case ServiceProtocolConstants.ServiceEventMessageType: + return CreateServiceEventMessage(ref reader); + case ServiceProtocolConstants.JoinGroupWithAckMessageType: + return CreateJoinGroupWithAckMessage(ref reader, arrayLength); + case ServiceProtocolConstants.LeaveGroupWithAckMessageType: + return CreateLeaveGroupWithAckMessage(ref reader, arrayLength); + case ServiceProtocolConstants.CheckUserInGroupWithAckMessageType: + return CreateCheckUserInGroupWithAckMessage(ref reader, arrayLength); + case ServiceProtocolConstants.CheckGroupExistenceWithAckMessageType: + return CreateGroupExistenceWithAckMessage(ref reader, arrayLength); + case ServiceProtocolConstants.CheckConnectionExistenceWithAckMessageType: + return CreateCheckConnectionExistenceWithAckMessage(ref reader, arrayLength); + case ServiceProtocolConstants.CheckUserExistenceWithAckMessageType: + return CreateCheckUserExistenceWithAckMessage(ref reader, arrayLength); + case ServiceProtocolConstants.CloseConnectionsWithAckMessageType: #pragma warning disable CS0612 // Type or member is obsolete - return CreateCloseConnectionsWithAckMessage(ref reader, arrayLength); - case ServiceProtocolConstants.CloseConnectionWithAckMessageType: - return CreateCloseConnectionWithAckMessage(ref reader, arrayLength); + return CreateCloseConnectionsWithAckMessage(ref reader, arrayLength); + case ServiceProtocolConstants.CloseConnectionWithAckMessageType: + return CreateCloseConnectionWithAckMessage(ref reader, arrayLength); #pragma warning restore CS0612 // Type or member is obsolete - case ServiceProtocolConstants.CloseUserConnectionsWithAckMessageType: - return CreateCloseUserConnectionsWithAckMessage(ref reader, arrayLength); - case ServiceProtocolConstants.CloseGroupConnectionsWithAckMessageType: - return CreateCloseGroupConnectionsWithAckMessage(ref reader, arrayLength); - case ServiceProtocolConstants.AckMessageType: - return CreateAckMessage(ref reader, arrayLength); - case ServiceProtocolConstants.ClientInvocationMessageType: - return CreateClientInvocationMessage(ref reader, arrayLength); - case ServiceProtocolConstants.ClientCompletionMessageType: - return CreateClientCompletionMessage(ref reader, arrayLength); - case ServiceProtocolConstants.ErrorCompletionMessageType: - return CreateErrorCompletionMessage(ref reader, arrayLength); - case ServiceProtocolConstants.ServiceMappingMessageType: - return CreateServiceMappingMessage(ref reader, arrayLength); - case ServiceProtocolConstants.ConnectionFlowControlMessageType: - return CreateConnectionFlowControlMessage(ref reader, arrayLength); - default: - // Future protocol changes can add message types, old clients can ignore them - return null; - } + case ServiceProtocolConstants.CloseUserConnectionsWithAckMessageType: + return CreateCloseUserConnectionsWithAckMessage(ref reader, arrayLength); + case ServiceProtocolConstants.CloseGroupConnectionsWithAckMessageType: + return CreateCloseGroupConnectionsWithAckMessage(ref reader, arrayLength); + case ServiceProtocolConstants.AckMessageType: + return CreateAckMessage(ref reader, arrayLength); + case ServiceProtocolConstants.ClientInvocationMessageType: + return CreateClientInvocationMessage(ref reader, arrayLength); + case ServiceProtocolConstants.ClientCompletionMessageType: + return CreateClientCompletionMessage(ref reader, arrayLength); + case ServiceProtocolConstants.ErrorCompletionMessageType: + return CreateErrorCompletionMessage(ref reader, arrayLength); + case ServiceProtocolConstants.ServiceMappingMessageType: + return CreateServiceMappingMessage(ref reader, arrayLength); + case ServiceProtocolConstants.ConnectionFlowControlMessageType: + return CreateConnectionFlowControlMessage(ref reader, arrayLength); + default: + // Future protocol changes can add message types, old clients can ignore them + return null; } + } - /// - public void WriteMessage(ServiceMessage message, IBufferWriter output) - { - var memoryBufferWriter = MemoryBufferWriter.Get(); + /// + public void WriteMessage(ServiceMessage message, IBufferWriter output) + { + var memoryBufferWriter = MemoryBufferWriter.Get(); - try - { - var writer = new MessagePackWriter(memoryBufferWriter); + try + { + var writer = new MessagePackWriter(memoryBufferWriter); - // Write message to a buffer so we can get its length - WriteMessageCore(ref writer, message); + // Write message to a buffer so we can get its length + WriteMessageCore(ref writer, message); - // Write length then message to output - BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, output); - memoryBufferWriter.CopyTo(output); - } - finally - { - MemoryBufferWriter.Return(memoryBufferWriter); - } + // Write length then message to output + BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, output); + memoryBufferWriter.CopyTo(output); } - - /// - public ReadOnlyMemory GetMessageBytes(ServiceMessage message) + finally { - var memoryBufferWriter = MemoryBufferWriter.Get(); + MemoryBufferWriter.Return(memoryBufferWriter); + } + } - try - { - var writer = new MessagePackWriter(memoryBufferWriter); + /// + public ReadOnlyMemory GetMessageBytes(ServiceMessage message) + { + var memoryBufferWriter = MemoryBufferWriter.Get(); - // Write message to a buffer so we can get its length - WriteMessageCore(ref writer, message); + try + { + var writer = new MessagePackWriter(memoryBufferWriter); - var dataLength = memoryBufferWriter.Length; - var prefixLength = BinaryMessageFormatter.LengthPrefixLength(memoryBufferWriter.Length); + // Write message to a buffer so we can get its length + WriteMessageCore(ref writer, message); - var array = new byte[dataLength + prefixLength]; - var span = array.AsSpan(); + var dataLength = memoryBufferWriter.Length; + var prefixLength = BinaryMessageFormatter.LengthPrefixLength(memoryBufferWriter.Length); - // Write length then message to output - var written = BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, span); - Debug.Assert(written == prefixLength); - memoryBufferWriter.CopyTo(span.Slice(prefixLength)); + var array = new byte[dataLength + prefixLength]; + var span = array.AsSpan(); - return array; - } - finally - { - MemoryBufferWriter.Return(memoryBufferWriter); - } - } + // Write length then message to output + var written = BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, span); + Debug.Assert(written == prefixLength); + memoryBufferWriter.CopyTo(span.Slice(prefixLength)); - private static void WriteMessageCore(ref MessagePackWriter writer, ServiceMessage message) + return array; + } + finally { + MemoryBufferWriter.Return(memoryBufferWriter); + } + } + + private static void WriteMessageCore(ref MessagePackWriter writer, ServiceMessage message) + { #pragma warning disable CS0618 // Type or member is obsolete - switch (message) - { - case HandshakeRequestMessage handshakeRequestMessage: - WriteHandshakeRequestMessage(ref writer, handshakeRequestMessage); - break; - case HandshakeResponseMessage handshakeResponseMessage: - WriteHandshakeResponseMessage(ref writer, handshakeResponseMessage); - break; - case AccessKeyRequestMessage accessKeyRequestMessage: - WriteAccessKeyRequestMessage(ref writer, accessKeyRequestMessage); - break; - case AccessKeyResponseMessage accessKeyResponseMessage: - WriteAccessKeyResponseMessage(ref writer, accessKeyResponseMessage); - break; - case PingMessage pingMessage: - WritePingMessage(ref writer, pingMessage); - break; - case OpenConnectionMessage openConnectionMessage: - WriteOpenConnectionMessage(ref writer, openConnectionMessage); - break; - case CloseConnectionMessage closeConnectionMessage: - WriteCloseConnectionMessage(ref writer, closeConnectionMessage); - break; - case ConnectionDataMessage connectionDataMessage: - WriteConnectionDataMessage(ref writer, connectionDataMessage); - break; - case ConnectionReconnectMessage connectionReconnectMessage: - WriteConnectionReconnectMessage(ref writer, connectionReconnectMessage); - break; - case MultiConnectionDataMessage multiConnectionDataMessage: - WriteMultiConnectionDataMessage(ref writer, multiConnectionDataMessage); - break; - case UserDataMessage userDataMessage: - WriteUserDataMessage(ref writer, userDataMessage); - break; - case MultiUserDataMessage multiUserDataMessage: - WriteMultiUserDataMessage(ref writer, multiUserDataMessage); - break; - case BroadcastDataMessage broadcastDataMessage: - WriteBroadcastDataMessage(ref writer, broadcastDataMessage); - break; - case JoinGroupMessage joinGroupMessage: - WriteJoinGroupMessage(ref writer, joinGroupMessage); - break; - case JoinGroupWithAckMessage joinGroupWithAckMessage: - WriteJoinGroupWithAckMessage(ref writer, joinGroupWithAckMessage); - break; - case LeaveGroupMessage leaveGroupMessage: - WriteLeaveGroupMessage(ref writer, leaveGroupMessage); - break; - case LeaveGroupWithAckMessage leaveGroupWithAckMessage: - WriteLeaveGroupWithAckMessage(ref writer, leaveGroupWithAckMessage); - break; - case CheckUserInGroupWithAckMessage checkUserInGroupWithAckMessage: - WriteCheckUserInGroupWithAckMessage(ref writer, checkUserInGroupWithAckMessage); - break; - case CheckGroupExistenceWithAckMessage checkAnyConnectionInGroupWithAckMessage: - WriteCheckGroupExistenceWithAckMessage(ref writer, checkAnyConnectionInGroupWithAckMessage); - break; - case CheckConnectionExistenceWithAckMessage checkConnectionExistenceWithAckMessage: - WriteCheckConnectionExistenceWithAckMessage(ref writer, checkConnectionExistenceWithAckMessage); - break; - case CheckUserExistenceWithAckMessage checkConnectionExistenceAsUserWithAckMessage: - WriteCheckUserExistenceWithAckMessage(ref writer, checkConnectionExistenceAsUserWithAckMessage); - break; - case UserJoinGroupMessage userJoinGroupMessage: - WriteUserJoinGroupMessage(ref writer, userJoinGroupMessage); - break; - case UserLeaveGroupMessage userLeaveGroupMessage: - WriteUserLeaveGroupMessage(ref writer, userLeaveGroupMessage); - break; - case UserJoinGroupWithAckMessage userJoinGroupWithAckMessage: - WriteUserJoinGroupWithAckMessage(ref writer, userJoinGroupWithAckMessage); - break; - case UserLeaveGroupWithAckMessage userLeaveGroupWithAckMessage: - WriteUserLeaveGroupWithAckMessage(ref writer, userLeaveGroupWithAckMessage); - break; - case GroupBroadcastDataMessage groupBroadcastDataMessage: - WriteGroupBroadcastDataMessage(ref writer, groupBroadcastDataMessage); - break; - case MultiGroupBroadcastDataMessage multiGroupBroadcastDataMessage: - WriteMultiGroupBroadcastDataMessage(ref writer, multiGroupBroadcastDataMessage); - break; - case ServiceErrorMessage serviceErrorMessage: - WriteServiceErrorMessage(ref writer, serviceErrorMessage); - break; - case ServiceEventMessage serviceWarningMessage: - WriteServiceEventMessage(ref writer, serviceWarningMessage); - break; - case CloseConnectionWithAckMessage closeConnectionWithAckMessage: + switch (message) + { + case HandshakeRequestMessage handshakeRequestMessage: + WriteHandshakeRequestMessage(ref writer, handshakeRequestMessage); + break; + case HandshakeResponseMessage handshakeResponseMessage: + WriteHandshakeResponseMessage(ref writer, handshakeResponseMessage); + break; + case AccessKeyRequestMessage accessKeyRequestMessage: + WriteAccessKeyRequestMessage(ref writer, accessKeyRequestMessage); + break; + case AccessKeyResponseMessage accessKeyResponseMessage: + WriteAccessKeyResponseMessage(ref writer, accessKeyResponseMessage); + break; + case PingMessage pingMessage: + WritePingMessage(ref writer, pingMessage); + break; + case OpenConnectionMessage openConnectionMessage: + WriteOpenConnectionMessage(ref writer, openConnectionMessage); + break; + case CloseConnectionMessage closeConnectionMessage: + WriteCloseConnectionMessage(ref writer, closeConnectionMessage); + break; + case ConnectionDataMessage connectionDataMessage: + WriteConnectionDataMessage(ref writer, connectionDataMessage); + break; + case ConnectionReconnectMessage connectionReconnectMessage: + WriteConnectionReconnectMessage(ref writer, connectionReconnectMessage); + break; + case MultiConnectionDataMessage multiConnectionDataMessage: + WriteMultiConnectionDataMessage(ref writer, multiConnectionDataMessage); + break; + case UserDataMessage userDataMessage: + WriteUserDataMessage(ref writer, userDataMessage); + break; + case MultiUserDataMessage multiUserDataMessage: + WriteMultiUserDataMessage(ref writer, multiUserDataMessage); + break; + case BroadcastDataMessage broadcastDataMessage: + WriteBroadcastDataMessage(ref writer, broadcastDataMessage); + break; + case JoinGroupMessage joinGroupMessage: + WriteJoinGroupMessage(ref writer, joinGroupMessage); + break; + case JoinGroupWithAckMessage joinGroupWithAckMessage: + WriteJoinGroupWithAckMessage(ref writer, joinGroupWithAckMessage); + break; + case LeaveGroupMessage leaveGroupMessage: + WriteLeaveGroupMessage(ref writer, leaveGroupMessage); + break; + case LeaveGroupWithAckMessage leaveGroupWithAckMessage: + WriteLeaveGroupWithAckMessage(ref writer, leaveGroupWithAckMessage); + break; + case CheckUserInGroupWithAckMessage checkUserInGroupWithAckMessage: + WriteCheckUserInGroupWithAckMessage(ref writer, checkUserInGroupWithAckMessage); + break; + case CheckGroupExistenceWithAckMessage checkAnyConnectionInGroupWithAckMessage: + WriteCheckGroupExistenceWithAckMessage(ref writer, checkAnyConnectionInGroupWithAckMessage); + break; + case CheckConnectionExistenceWithAckMessage checkConnectionExistenceWithAckMessage: + WriteCheckConnectionExistenceWithAckMessage(ref writer, checkConnectionExistenceWithAckMessage); + break; + case CheckUserExistenceWithAckMessage checkConnectionExistenceAsUserWithAckMessage: + WriteCheckUserExistenceWithAckMessage(ref writer, checkConnectionExistenceAsUserWithAckMessage); + break; + case UserJoinGroupMessage userJoinGroupMessage: + WriteUserJoinGroupMessage(ref writer, userJoinGroupMessage); + break; + case UserLeaveGroupMessage userLeaveGroupMessage: + WriteUserLeaveGroupMessage(ref writer, userLeaveGroupMessage); + break; + case UserJoinGroupWithAckMessage userJoinGroupWithAckMessage: + WriteUserJoinGroupWithAckMessage(ref writer, userJoinGroupWithAckMessage); + break; + case UserLeaveGroupWithAckMessage userLeaveGroupWithAckMessage: + WriteUserLeaveGroupWithAckMessage(ref writer, userLeaveGroupWithAckMessage); + break; + case GroupBroadcastDataMessage groupBroadcastDataMessage: + WriteGroupBroadcastDataMessage(ref writer, groupBroadcastDataMessage); + break; + case MultiGroupBroadcastDataMessage multiGroupBroadcastDataMessage: + WriteMultiGroupBroadcastDataMessage(ref writer, multiGroupBroadcastDataMessage); + break; + case ServiceErrorMessage serviceErrorMessage: + WriteServiceErrorMessage(ref writer, serviceErrorMessage); + break; + case ServiceEventMessage serviceWarningMessage: + WriteServiceEventMessage(ref writer, serviceWarningMessage); + break; + case CloseConnectionWithAckMessage closeConnectionWithAckMessage: #pragma warning disable CS0612 // Type or member is obsolete - WriteCloseConnectionWithAckMessage(ref writer, closeConnectionWithAckMessage); - break; - case CloseConnectionsWithAckMessage closeConnectionsWithAckMessage: - WriteCloseConnectionsWithAckMessage(ref writer, closeConnectionsWithAckMessage); + WriteCloseConnectionWithAckMessage(ref writer, closeConnectionWithAckMessage); + break; + case CloseConnectionsWithAckMessage closeConnectionsWithAckMessage: + WriteCloseConnectionsWithAckMessage(ref writer, closeConnectionsWithAckMessage); #pragma warning restore CS0612 // Type or member is obsolete - break; - case CloseUserConnectionsWithAckMessage closeUserConnectionsWithAckMessage: - WriteCloseUserConnectionsWithAckMessage(ref writer, closeUserConnectionsWithAckMessage); - break; - case CloseGroupConnectionsWithAckMessage closeGroupConnectionsWithAckMessage: - WriteCloseGroupConnectionsWithAckMessage(ref writer, closeGroupConnectionsWithAckMessage); - break; - case AckMessage ackMessage: - WriteAckMessage(ref writer, ackMessage); - break; - case ClientInvocationMessage clientInvocationMessage: - WriteClientInvocationMessage(ref writer, clientInvocationMessage); - break; - case ClientCompletionMessage clientCompletionMesssage: - WriteClientCompletionMessage(ref writer, clientCompletionMesssage); - break; - case ErrorCompletionMessage errorCompletionMesssage: - WriteErrorCompletionMessage(ref writer, errorCompletionMesssage); - break; - case ServiceMappingMessage serviceMappingMessage: - WriteServiceMappingMessage(ref writer, serviceMappingMessage); - break; - case ConnectionFlowControlMessage connectionFlowControlMessage: - WriteConnectionFlowControlMessage(ref writer, connectionFlowControlMessage); - break; - default: - throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}"); - } + break; + case CloseUserConnectionsWithAckMessage closeUserConnectionsWithAckMessage: + WriteCloseUserConnectionsWithAckMessage(ref writer, closeUserConnectionsWithAckMessage); + break; + case CloseGroupConnectionsWithAckMessage closeGroupConnectionsWithAckMessage: + WriteCloseGroupConnectionsWithAckMessage(ref writer, closeGroupConnectionsWithAckMessage); + break; + case AckMessage ackMessage: + WriteAckMessage(ref writer, ackMessage); + break; + case ClientInvocationMessage clientInvocationMessage: + WriteClientInvocationMessage(ref writer, clientInvocationMessage); + break; + case ClientCompletionMessage clientCompletionMesssage: + WriteClientCompletionMessage(ref writer, clientCompletionMesssage); + break; + case ErrorCompletionMessage errorCompletionMesssage: + WriteErrorCompletionMessage(ref writer, errorCompletionMesssage); + break; + case ServiceMappingMessage serviceMappingMessage: + WriteServiceMappingMessage(ref writer, serviceMappingMessage); + break; + case ConnectionFlowControlMessage connectionFlowControlMessage: + WriteConnectionFlowControlMessage(ref writer, connectionFlowControlMessage); + break; + default: + throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}"); + } #pragma warning restore CS0618 // Type or member is obsolete - writer.Flush(); - } + writer.Flush(); + } - private static void WriteHandshakeRequestMessage(ref MessagePackWriter writer, HandshakeRequestMessage message) - { - writer.WriteArrayHeader(7); - writer.Write(ServiceProtocolConstants.HandshakeRequestType); - writer.Write(message.Version); - writer.Write(message.ConnectionType); - writer.Write(message.ConnectionType == 0 ? "" : message.Target ?? string.Empty); - writer.Write(message.MigrationLevel); - message.WriteExtensionMembers(ref writer); - writer.Write(message.AllowStatefulReconnects); - } + private static void WriteHandshakeRequestMessage(ref MessagePackWriter writer, HandshakeRequestMessage message) + { + writer.WriteArrayHeader(7); + writer.Write(ServiceProtocolConstants.HandshakeRequestType); + writer.Write(message.Version); + writer.Write(message.ConnectionType); + writer.Write(message.ConnectionType == 0 ? "" : message.Target ?? string.Empty); + writer.Write(message.MigrationLevel); + message.WriteExtensionMembers(ref writer); + writer.Write(message.AllowStatefulReconnects); + } - private static void WriteHandshakeResponseMessage(ref MessagePackWriter writer, HandshakeResponseMessage message) - { - writer.WriteArrayHeader(4); - writer.Write(ServiceProtocolConstants.HandshakeResponseType); - writer.Write(message.ErrorMessage); - message.WriteExtensionMembers(ref writer); - writer.Write(message.ConnectionId); - } + private static void WriteHandshakeResponseMessage(ref MessagePackWriter writer, HandshakeResponseMessage message) + { + writer.WriteArrayHeader(4); + writer.Write(ServiceProtocolConstants.HandshakeResponseType); + writer.Write(message.ErrorMessage); + message.WriteExtensionMembers(ref writer); + writer.Write(message.ConnectionId); + } - private static void WriteAccessKeyRequestMessage(ref MessagePackWriter writer, AccessKeyRequestMessage message) - { - writer.WriteArrayHeader(4); - writer.Write(ServiceProtocolConstants.AccessKeyRequestType); - writer.Write(message.Token); - writer.Write(message.Kid); - message.WriteExtensionMembers(ref writer); - } + private static void WriteAccessKeyRequestMessage(ref MessagePackWriter writer, AccessKeyRequestMessage message) + { + writer.WriteArrayHeader(4); + writer.Write(ServiceProtocolConstants.AccessKeyRequestType); + writer.Write(message.Token); + writer.Write(message.Kid); + message.WriteExtensionMembers(ref writer); + } - private static void WriteAccessKeyResponseMessage(ref MessagePackWriter writer, AccessKeyResponseMessage message) - { - writer.WriteArrayHeader(6); - writer.Write(ServiceProtocolConstants.AccessKeyResponseType); - writer.Write(message.Kid); - writer.Write(message.AccessKey); - writer.Write(message.ErrorType); - writer.Write(message.ErrorMessage); - message.WriteExtensionMembers(ref writer); - } + private static void WriteAccessKeyResponseMessage(ref MessagePackWriter writer, AccessKeyResponseMessage message) + { + writer.WriteArrayHeader(6); + writer.Write(ServiceProtocolConstants.AccessKeyResponseType); + writer.Write(message.Kid); + writer.Write(message.AccessKey); + writer.Write(message.ErrorType); + writer.Write(message.ErrorMessage); + message.WriteExtensionMembers(ref writer); + } - private static void WritePingMessage(ref MessagePackWriter writer, PingMessage message) + private static void WritePingMessage(ref MessagePackWriter writer, PingMessage message) + { + writer.WriteArrayHeader(message.Messages.Length + 1); + writer.Write(ServiceProtocolConstants.PingMessageType); + foreach (var item in message.Messages) { - writer.WriteArrayHeader(message.Messages.Length + 1); - writer.Write(ServiceProtocolConstants.PingMessageType); - foreach (var item in message.Messages) - { - writer.Write(item); - } + writer.Write(item); } + } - private static void WriteOpenConnectionMessage(ref MessagePackWriter writer, OpenConnectionMessage message) - { - writer.WriteArrayHeader(6); - writer.Write(ServiceProtocolConstants.OpenConnectionMessageType); - writer.Write(message.ConnectionId); + private static void WriteOpenConnectionMessage(ref MessagePackWriter writer, OpenConnectionMessage message) + { + writer.WriteArrayHeader(6); + writer.Write(ServiceProtocolConstants.OpenConnectionMessageType); + writer.Write(message.ConnectionId); - if (message.Claims?.Length > 0) - { - writer.WriteMapHeader(message.Claims.Length); - foreach (var claim in message.Claims) - { - writer.Write(claim.Type); - writer.Write(claim.Value); - } - } - else + if (message.Claims?.Length > 0) + { + writer.WriteMapHeader(message.Claims.Length); + foreach (var claim in message.Claims) { - writer.WriteMapHeader(0); + writer.Write(claim.Type); + writer.Write(claim.Value); } - WriteHeaders(ref writer, message.Headers); - - writer.Write(message.QueryString); - message.WriteExtensionMembers(ref writer); } - - private static void WriteCloseConnectionMessage(ref MessagePackWriter writer, CloseConnectionMessage message) + else { - writer.WriteArrayHeader(5); - writer.Write(ServiceProtocolConstants.CloseConnectionMessageType); - writer.Write(message.ConnectionId); - writer.Write(message.ErrorMessage); - WriteHeaders(ref writer, message.Headers); - message.WriteExtensionMembers(ref writer); + writer.WriteMapHeader(0); } + WriteHeaders(ref writer, message.Headers); - [Obsolete] - private static void WriteCloseConnectionWithAckMessage(ref MessagePackWriter writer, CloseConnectionWithAckMessage message) - { - writer.WriteArrayHeader(5); - writer.Write(ServiceProtocolConstants.CloseConnectionWithAckMessageType); - writer.Write(message.ConnectionId); - writer.Write(message.Reason); - writer.Write(message.AckId); - message.WriteExtensionMembers(ref writer); - } + writer.Write(message.QueryString); + message.WriteExtensionMembers(ref writer); + } - [Obsolete] - private static void WriteCloseConnectionsWithAckMessage(ref MessagePackWriter writer, CloseConnectionsWithAckMessage message) - { - writer.WriteArrayHeader(5); - writer.Write(ServiceProtocolConstants.CloseConnectionsWithAckMessageType); - writer.Write(message.Reason); - writer.Write(message.AckId); - WriteStringArray(ref writer, message.ExcludedList); - message.WriteExtensionMembers(ref writer); - } + private static void WriteCloseConnectionMessage(ref MessagePackWriter writer, CloseConnectionMessage message) + { + writer.WriteArrayHeader(5); + writer.Write(ServiceProtocolConstants.CloseConnectionMessageType); + writer.Write(message.ConnectionId); + writer.Write(message.ErrorMessage); + WriteHeaders(ref writer, message.Headers); + message.WriteExtensionMembers(ref writer); + } - private static void WriteCloseUserConnectionsWithAckMessage(ref MessagePackWriter writer, CloseUserConnectionsWithAckMessage message) - { - writer.WriteArrayHeader(6); - writer.Write(ServiceProtocolConstants.CloseUserConnectionsWithAckMessageType); - writer.Write(message.UserId); - writer.Write(message.Reason); - writer.Write(message.AckId); - WriteStringArray(ref writer, message.ExcludedList); - message.WriteExtensionMembers(ref writer); - } + [Obsolete] + private static void WriteCloseConnectionWithAckMessage(ref MessagePackWriter writer, CloseConnectionWithAckMessage message) + { + writer.WriteArrayHeader(5); + writer.Write(ServiceProtocolConstants.CloseConnectionWithAckMessageType); + writer.Write(message.ConnectionId); + writer.Write(message.Reason); + writer.Write(message.AckId); + message.WriteExtensionMembers(ref writer); + } - private static void WriteCloseGroupConnectionsWithAckMessage(ref MessagePackWriter writer, CloseGroupConnectionsWithAckMessage message) - { - writer.WriteArrayHeader(6); - writer.Write(ServiceProtocolConstants.CloseGroupConnectionsWithAckMessageType); - writer.Write(message.GroupName); - writer.Write(message.Reason); - writer.Write(message.AckId); - WriteStringArray(ref writer, message.ExcludedList); - message.WriteExtensionMembers(ref writer); - } + [Obsolete] + private static void WriteCloseConnectionsWithAckMessage(ref MessagePackWriter writer, CloseConnectionsWithAckMessage message) + { + writer.WriteArrayHeader(5); + writer.Write(ServiceProtocolConstants.CloseConnectionsWithAckMessageType); + writer.Write(message.Reason); + writer.Write(message.AckId); + WriteStringArray(ref writer, message.ExcludedList); + message.WriteExtensionMembers(ref writer); + } - private static void WriteConnectionDataMessage(ref MessagePackWriter writer, ConnectionDataMessage message) - { - writer.WriteArrayHeader(4); - writer.Write(ServiceProtocolConstants.ConnectionDataMessageType); - writer.Write(message.ConnectionId); - writer.Write(message.Payload); - message.WriteExtensionMembers(ref writer); - } + private static void WriteCloseUserConnectionsWithAckMessage(ref MessagePackWriter writer, CloseUserConnectionsWithAckMessage message) + { + writer.WriteArrayHeader(6); + writer.Write(ServiceProtocolConstants.CloseUserConnectionsWithAckMessageType); + writer.Write(message.UserId); + writer.Write(message.Reason); + writer.Write(message.AckId); + WriteStringArray(ref writer, message.ExcludedList); + message.WriteExtensionMembers(ref writer); + } - private static void WriteConnectionReconnectMessage(ref MessagePackWriter writer, ConnectionReconnectMessage message) - { - writer.WriteArrayHeader(3); - writer.Write(ServiceProtocolConstants.ConnectionReconnectMessageType); - writer.Write(message.ConnectionId); - message.WriteExtensionMembers(ref writer); - } + private static void WriteCloseGroupConnectionsWithAckMessage(ref MessagePackWriter writer, CloseGroupConnectionsWithAckMessage message) + { + writer.WriteArrayHeader(6); + writer.Write(ServiceProtocolConstants.CloseGroupConnectionsWithAckMessageType); + writer.Write(message.GroupName); + writer.Write(message.Reason); + writer.Write(message.AckId); + WriteStringArray(ref writer, message.ExcludedList); + message.WriteExtensionMembers(ref writer); + } - private static void WriteMultiConnectionDataMessage(ref MessagePackWriter writer, MultiConnectionDataMessage message) - { - writer.WriteArrayHeader(4); - writer.Write(ServiceProtocolConstants.MultiConnectionDataMessageType); - WriteStringArray(ref writer, message.ConnectionList); - WritePayloads(ref writer, message.Payloads); - message.WriteExtensionMembers(ref writer); - } + private static void WriteConnectionDataMessage(ref MessagePackWriter writer, ConnectionDataMessage message) + { + writer.WriteArrayHeader(4); + writer.Write(ServiceProtocolConstants.ConnectionDataMessageType); + writer.Write(message.ConnectionId); + writer.Write(message.Payload); + message.WriteExtensionMembers(ref writer); + } - private static void WriteUserDataMessage(ref MessagePackWriter writer, UserDataMessage message) - { - writer.WriteArrayHeader(4); - writer.Write(ServiceProtocolConstants.UserDataMessageType); - writer.Write(message.UserId); - WritePayloads(ref writer, message.Payloads); - message.WriteExtensionMembers(ref writer); - } + private static void WriteConnectionReconnectMessage(ref MessagePackWriter writer, ConnectionReconnectMessage message) + { + writer.WriteArrayHeader(3); + writer.Write(ServiceProtocolConstants.ConnectionReconnectMessageType); + writer.Write(message.ConnectionId); + message.WriteExtensionMembers(ref writer); + } - private static void WriteMultiUserDataMessage(ref MessagePackWriter writer, MultiUserDataMessage message) - { - writer.WriteArrayHeader(4); - writer.Write(ServiceProtocolConstants.MultiUserDataMessageType); - WriteStringArray(ref writer, message.UserList); - WritePayloads(ref writer, message.Payloads); - message.WriteExtensionMembers(ref writer); - } + private static void WriteMultiConnectionDataMessage(ref MessagePackWriter writer, MultiConnectionDataMessage message) + { + writer.WriteArrayHeader(4); + writer.Write(ServiceProtocolConstants.MultiConnectionDataMessageType); + WriteStringArray(ref writer, message.ConnectionList); + WritePayloads(ref writer, message.Payloads); + message.WriteExtensionMembers(ref writer); + } - private static void WriteBroadcastDataMessage(ref MessagePackWriter writer, BroadcastDataMessage message) - { - writer.WriteArrayHeader(4); - writer.Write(ServiceProtocolConstants.BroadcastDataMessageType); - WriteStringArray(ref writer, message.ExcludedList); - WritePayloads(ref writer, message.Payloads); - message.WriteExtensionMembers(ref writer); - } + private static void WriteUserDataMessage(ref MessagePackWriter writer, UserDataMessage message) + { + writer.WriteArrayHeader(4); + writer.Write(ServiceProtocolConstants.UserDataMessageType); + writer.Write(message.UserId); + WritePayloads(ref writer, message.Payloads); + message.WriteExtensionMembers(ref writer); + } - private static void WriteJoinGroupMessage(ref MessagePackWriter writer, JoinGroupMessage message) - { - writer.WriteArrayHeader(4); - writer.Write(ServiceProtocolConstants.JoinGroupMessageType); - writer.Write(message.ConnectionId); - writer.Write(message.GroupName); - message.WriteExtensionMembers(ref writer); - } + private static void WriteMultiUserDataMessage(ref MessagePackWriter writer, MultiUserDataMessage message) + { + writer.WriteArrayHeader(4); + writer.Write(ServiceProtocolConstants.MultiUserDataMessageType); + WriteStringArray(ref writer, message.UserList); + WritePayloads(ref writer, message.Payloads); + message.WriteExtensionMembers(ref writer); + } - private static void WriteLeaveGroupMessage(ref MessagePackWriter writer, LeaveGroupMessage message) - { - writer.WriteArrayHeader(4); - writer.Write(ServiceProtocolConstants.LeaveGroupMessageType); - writer.Write(message.ConnectionId); - writer.Write(message.GroupName); - message.WriteExtensionMembers(ref writer); - } + private static void WriteBroadcastDataMessage(ref MessagePackWriter writer, BroadcastDataMessage message) + { + writer.WriteArrayHeader(4); + writer.Write(ServiceProtocolConstants.BroadcastDataMessageType); + WriteStringArray(ref writer, message.ExcludedList); + WritePayloads(ref writer, message.Payloads); + message.WriteExtensionMembers(ref writer); + } - private static void WriteUserJoinGroupMessage(ref MessagePackWriter writer, UserJoinGroupMessage message) - { - writer.WriteArrayHeader(4); - writer.Write(ServiceProtocolConstants.UserJoinGroupMessageType); - writer.Write(message.UserId); - writer.Write(message.GroupName); - message.WriteExtensionMembers(ref writer); - } + private static void WriteJoinGroupMessage(ref MessagePackWriter writer, JoinGroupMessage message) + { + writer.WriteArrayHeader(4); + writer.Write(ServiceProtocolConstants.JoinGroupMessageType); + writer.Write(message.ConnectionId); + writer.Write(message.GroupName); + message.WriteExtensionMembers(ref writer); + } - private static void WriteUserLeaveGroupMessage(ref MessagePackWriter writer, UserLeaveGroupMessage message) - { - writer.WriteArrayHeader(4); - writer.Write(ServiceProtocolConstants.UserLeaveGroupMessageType); - writer.Write(message.UserId); - writer.Write(message.GroupName); - message.WriteExtensionMembers(ref writer); - } + private static void WriteLeaveGroupMessage(ref MessagePackWriter writer, LeaveGroupMessage message) + { + writer.WriteArrayHeader(4); + writer.Write(ServiceProtocolConstants.LeaveGroupMessageType); + writer.Write(message.ConnectionId); + writer.Write(message.GroupName); + message.WriteExtensionMembers(ref writer); + } - private static void WriteUserJoinGroupWithAckMessage(ref MessagePackWriter writer, UserJoinGroupWithAckMessage message) - { - writer.WriteArrayHeader(5); - writer.Write(ServiceProtocolConstants.UserJoinGroupWithAckMessageType); - writer.Write(message.UserId); - writer.Write(message.GroupName); - writer.Write(message.AckId); - message.WriteExtensionMembers(ref writer); - } + private static void WriteUserJoinGroupMessage(ref MessagePackWriter writer, UserJoinGroupMessage message) + { + writer.WriteArrayHeader(4); + writer.Write(ServiceProtocolConstants.UserJoinGroupMessageType); + writer.Write(message.UserId); + writer.Write(message.GroupName); + message.WriteExtensionMembers(ref writer); + } - private static void WriteUserLeaveGroupWithAckMessage(ref MessagePackWriter writer, UserLeaveGroupWithAckMessage message) - { - writer.WriteArrayHeader(5); - writer.Write(ServiceProtocolConstants.UserLeaveGroupWithAckMessageType); - writer.Write(message.UserId); - writer.Write(message.GroupName); - writer.Write(message.AckId); - message.WriteExtensionMembers(ref writer); - } + private static void WriteUserLeaveGroupMessage(ref MessagePackWriter writer, UserLeaveGroupMessage message) + { + writer.WriteArrayHeader(4); + writer.Write(ServiceProtocolConstants.UserLeaveGroupMessageType); + writer.Write(message.UserId); + writer.Write(message.GroupName); + message.WriteExtensionMembers(ref writer); + } - private static void WriteGroupBroadcastDataMessage(ref MessagePackWriter writer, GroupBroadcastDataMessage message) - { - writer.WriteArrayHeader(7); - writer.Write(ServiceProtocolConstants.GroupBroadcastDataMessageType); - writer.Write(message.GroupName); - WriteStringArray(ref writer, message.ExcludedList); - WritePayloads(ref writer, message.Payloads); - message.WriteExtensionMembers(ref writer); - WriteStringArray(ref writer, message.ExcludedUserList); - writer.Write(message.CallerUserId); - } + private static void WriteUserJoinGroupWithAckMessage(ref MessagePackWriter writer, UserJoinGroupWithAckMessage message) + { + writer.WriteArrayHeader(5); + writer.Write(ServiceProtocolConstants.UserJoinGroupWithAckMessageType); + writer.Write(message.UserId); + writer.Write(message.GroupName); + writer.Write(message.AckId); + message.WriteExtensionMembers(ref writer); + } - private static void WriteMultiGroupBroadcastDataMessage(ref MessagePackWriter writer, MultiGroupBroadcastDataMessage message) - { - writer.WriteArrayHeader(4); - writer.Write(ServiceProtocolConstants.MultiGroupBroadcastDataMessageType); - WriteStringArray(ref writer, message.GroupList); - WritePayloads(ref writer, message.Payloads); - message.WriteExtensionMembers(ref writer); - } + private static void WriteUserLeaveGroupWithAckMessage(ref MessagePackWriter writer, UserLeaveGroupWithAckMessage message) + { + writer.WriteArrayHeader(5); + writer.Write(ServiceProtocolConstants.UserLeaveGroupWithAckMessageType); + writer.Write(message.UserId); + writer.Write(message.GroupName); + writer.Write(message.AckId); + message.WriteExtensionMembers(ref writer); + } - private static void WriteServiceErrorMessage(ref MessagePackWriter writer, ServiceErrorMessage message) - { - writer.WriteArrayHeader(2); - writer.Write(ServiceProtocolConstants.ServiceErrorMessageType); - writer.Write(message.ErrorMessage); - } + private static void WriteGroupBroadcastDataMessage(ref MessagePackWriter writer, GroupBroadcastDataMessage message) + { + writer.WriteArrayHeader(7); + writer.Write(ServiceProtocolConstants.GroupBroadcastDataMessageType); + writer.Write(message.GroupName); + WriteStringArray(ref writer, message.ExcludedList); + WritePayloads(ref writer, message.Payloads); + message.WriteExtensionMembers(ref writer); + WriteStringArray(ref writer, message.ExcludedUserList); + writer.Write(message.CallerUserId); + } - private static void WriteServiceEventMessage(ref MessagePackWriter writer, ServiceEventMessage message) - { - writer.WriteArrayHeader(6); - writer.Write(ServiceProtocolConstants.ServiceEventMessageType); - writer.Write((int)message.Type); - writer.Write(message.Id); - writer.Write((int)message.Kind); - writer.Write(message.Message); - message.WriteExtensionMembers(ref writer); - } + private static void WriteMultiGroupBroadcastDataMessage(ref MessagePackWriter writer, MultiGroupBroadcastDataMessage message) + { + writer.WriteArrayHeader(4); + writer.Write(ServiceProtocolConstants.MultiGroupBroadcastDataMessageType); + WriteStringArray(ref writer, message.GroupList); + WritePayloads(ref writer, message.Payloads); + message.WriteExtensionMembers(ref writer); + } - private static void WriteJoinGroupWithAckMessage(ref MessagePackWriter writer, JoinGroupWithAckMessage message) - { - writer.WriteArrayHeader(5); - writer.Write(ServiceProtocolConstants.JoinGroupWithAckMessageType); - writer.Write(message.ConnectionId); - writer.Write(message.GroupName); - writer.Write(message.AckId); - message.WriteExtensionMembers(ref writer); - } + private static void WriteServiceErrorMessage(ref MessagePackWriter writer, ServiceErrorMessage message) + { + writer.WriteArrayHeader(2); + writer.Write(ServiceProtocolConstants.ServiceErrorMessageType); + writer.Write(message.ErrorMessage); + } - private static void WriteLeaveGroupWithAckMessage(ref MessagePackWriter writer, LeaveGroupWithAckMessage message) - { - writer.WriteArrayHeader(5); - writer.Write(ServiceProtocolConstants.LeaveGroupWithAckMessageType); - writer.Write(message.ConnectionId); - writer.Write(message.GroupName); - writer.Write(message.AckId); - message.WriteExtensionMembers(ref writer); - } + private static void WriteServiceEventMessage(ref MessagePackWriter writer, ServiceEventMessage message) + { + writer.WriteArrayHeader(6); + writer.Write(ServiceProtocolConstants.ServiceEventMessageType); + writer.Write((int)message.Type); + writer.Write(message.Id); + writer.Write((int)message.Kind); + writer.Write(message.Message); + message.WriteExtensionMembers(ref writer); + } - private static void WriteCheckUserInGroupWithAckMessage(ref MessagePackWriter writer, CheckUserInGroupWithAckMessage message) - { - writer.WriteArrayHeader(5); - writer.Write(ServiceProtocolConstants.CheckUserInGroupWithAckMessageType); - writer.Write(message.UserId); - writer.Write(message.GroupName); - writer.Write(message.AckId); - message.WriteExtensionMembers(ref writer); - } + private static void WriteJoinGroupWithAckMessage(ref MessagePackWriter writer, JoinGroupWithAckMessage message) + { + writer.WriteArrayHeader(5); + writer.Write(ServiceProtocolConstants.JoinGroupWithAckMessageType); + writer.Write(message.ConnectionId); + writer.Write(message.GroupName); + writer.Write(message.AckId); + message.WriteExtensionMembers(ref writer); + } - private static void WriteCheckGroupExistenceWithAckMessage(ref MessagePackWriter writer, CheckGroupExistenceWithAckMessage message) - { - writer.WriteArrayHeader(4); - writer.Write(ServiceProtocolConstants.CheckGroupExistenceWithAckMessageType); - writer.Write(message.GroupName); - writer.Write(message.AckId); - message.WriteExtensionMembers(ref writer); - } + private static void WriteLeaveGroupWithAckMessage(ref MessagePackWriter writer, LeaveGroupWithAckMessage message) + { + writer.WriteArrayHeader(5); + writer.Write(ServiceProtocolConstants.LeaveGroupWithAckMessageType); + writer.Write(message.ConnectionId); + writer.Write(message.GroupName); + writer.Write(message.AckId); + message.WriteExtensionMembers(ref writer); + } - private static void WriteCheckConnectionExistenceWithAckMessage(ref MessagePackWriter writer, CheckConnectionExistenceWithAckMessage message) - { - writer.WriteArrayHeader(4); - writer.Write(ServiceProtocolConstants.CheckConnectionExistenceWithAckMessageType); - writer.Write(message.ConnectionId); - writer.Write(message.AckId); - message.WriteExtensionMembers(ref writer); - } + private static void WriteCheckUserInGroupWithAckMessage(ref MessagePackWriter writer, CheckUserInGroupWithAckMessage message) + { + writer.WriteArrayHeader(5); + writer.Write(ServiceProtocolConstants.CheckUserInGroupWithAckMessageType); + writer.Write(message.UserId); + writer.Write(message.GroupName); + writer.Write(message.AckId); + message.WriteExtensionMembers(ref writer); + } - private static void WriteCheckUserExistenceWithAckMessage(ref MessagePackWriter writer, CheckUserExistenceWithAckMessage message) - { - writer.WriteArrayHeader(4); - writer.Write(ServiceProtocolConstants.CheckUserExistenceWithAckMessageType); - writer.Write(message.UserId); - writer.Write(message.AckId); - message.WriteExtensionMembers(ref writer); - } + private static void WriteCheckGroupExistenceWithAckMessage(ref MessagePackWriter writer, CheckGroupExistenceWithAckMessage message) + { + writer.WriteArrayHeader(4); + writer.Write(ServiceProtocolConstants.CheckGroupExistenceWithAckMessageType); + writer.Write(message.GroupName); + writer.Write(message.AckId); + message.WriteExtensionMembers(ref writer); + } - private static void WriteAckMessage(ref MessagePackWriter writer, AckMessage message) - { - writer.WriteArrayHeader(5); - writer.Write(ServiceProtocolConstants.AckMessageType); - writer.Write(message.AckId); - writer.Write(message.Status); - writer.Write(message.Message); - message.WriteExtensionMembers(ref writer); - } + private static void WriteCheckConnectionExistenceWithAckMessage(ref MessagePackWriter writer, CheckConnectionExistenceWithAckMessage message) + { + writer.WriteArrayHeader(4); + writer.Write(ServiceProtocolConstants.CheckConnectionExistenceWithAckMessageType); + writer.Write(message.ConnectionId); + writer.Write(message.AckId); + message.WriteExtensionMembers(ref writer); + } - private static void WriteClientInvocationMessage(ref MessagePackWriter writer, ClientInvocationMessage message) - { - writer.WriteArrayHeader(6); - writer.Write(ServiceProtocolConstants.ClientInvocationMessageType); - writer.Write(message.InvocationId); - writer.Write(message.ConnectionId); - writer.Write(message.CallerServerId); - WritePayloads(ref writer, message.Payloads); - message.WriteExtensionMembers(ref writer); - } + private static void WriteCheckUserExistenceWithAckMessage(ref MessagePackWriter writer, CheckUserExistenceWithAckMessage message) + { + writer.WriteArrayHeader(4); + writer.Write(ServiceProtocolConstants.CheckUserExistenceWithAckMessageType); + writer.Write(message.UserId); + writer.Write(message.AckId); + message.WriteExtensionMembers(ref writer); + } - private static void WriteClientCompletionMessage(ref MessagePackWriter writer, ClientCompletionMessage message) - { - writer.WriteArrayHeader(7); - writer.Write(ServiceProtocolConstants.ClientCompletionMessageType); - writer.Write(message.InvocationId); - writer.Write(message.ConnectionId); - writer.Write(message.CallerServerId); - writer.Write(message.Protocol); - writer.Write(message.Payload); - message.WriteExtensionMembers(ref writer); - } + private static void WriteAckMessage(ref MessagePackWriter writer, AckMessage message) + { + writer.WriteArrayHeader(5); + writer.Write(ServiceProtocolConstants.AckMessageType); + writer.Write(message.AckId); + writer.Write(message.Status); + writer.Write(message.Message); + message.WriteExtensionMembers(ref writer); + } - private static void WriteErrorCompletionMessage(ref MessagePackWriter writer, ErrorCompletionMessage message) - { - writer.WriteArrayHeader(6); - writer.Write(ServiceProtocolConstants.ErrorCompletionMessageType); - writer.Write(message.InvocationId); - writer.Write(message.ConnectionId); - writer.Write(message.CallerServerId); - writer.Write(message.Error); - message.WriteExtensionMembers(ref writer); - } + private static void WriteClientInvocationMessage(ref MessagePackWriter writer, ClientInvocationMessage message) + { + writer.WriteArrayHeader(6); + writer.Write(ServiceProtocolConstants.ClientInvocationMessageType); + writer.Write(message.InvocationId); + writer.Write(message.ConnectionId); + writer.Write(message.CallerServerId); + WritePayloads(ref writer, message.Payloads); + message.WriteExtensionMembers(ref writer); + } - private static void WriteServiceMappingMessage(ref MessagePackWriter writer, ServiceMappingMessage message) - { - writer.WriteArrayHeader(5); - writer.Write(ServiceProtocolConstants.ServiceMappingMessageType); - writer.Write(message.InvocationId); - writer.Write(message.ConnectionId); - writer.Write(message.InstanceId); - message.WriteExtensionMembers(ref writer); - } + private static void WriteClientCompletionMessage(ref MessagePackWriter writer, ClientCompletionMessage message) + { + writer.WriteArrayHeader(7); + writer.Write(ServiceProtocolConstants.ClientCompletionMessageType); + writer.Write(message.InvocationId); + writer.Write(message.ConnectionId); + writer.Write(message.CallerServerId); + writer.Write(message.Protocol); + writer.Write(message.Payload); + message.WriteExtensionMembers(ref writer); + } - private static void WriteConnectionFlowControlMessage(ref MessagePackWriter writer, ConnectionFlowControlMessage message) - { - writer.WriteArrayHeader(5); - writer.Write(ServiceProtocolConstants.ConnectionFlowControlMessageType); - writer.Write(message.ConnectionId); - writer.WriteInt32((int)message.ConnectionType); - writer.WriteInt32((int)message.Operation); - message.WriteExtensionMembers(ref writer); - } + private static void WriteErrorCompletionMessage(ref MessagePackWriter writer, ErrorCompletionMessage message) + { + writer.WriteArrayHeader(6); + writer.Write(ServiceProtocolConstants.ErrorCompletionMessageType); + writer.Write(message.InvocationId); + writer.Write(message.ConnectionId); + writer.Write(message.CallerServerId); + writer.Write(message.Error); + message.WriteExtensionMembers(ref writer); + } + + private static void WriteServiceMappingMessage(ref MessagePackWriter writer, ServiceMappingMessage message) + { + writer.WriteArrayHeader(5); + writer.Write(ServiceProtocolConstants.ServiceMappingMessageType); + writer.Write(message.InvocationId); + writer.Write(message.ConnectionId); + writer.Write(message.InstanceId); + message.WriteExtensionMembers(ref writer); + } - private static void WriteStringArray(ref MessagePackWriter writer, IReadOnlyList? array) + private static void WriteConnectionFlowControlMessage(ref MessagePackWriter writer, ConnectionFlowControlMessage message) + { + writer.WriteArrayHeader(5); + writer.Write(ServiceProtocolConstants.ConnectionFlowControlMessageType); + writer.Write(message.ConnectionId); + writer.WriteInt32((int)message.ConnectionType); + writer.WriteInt32((int)message.Operation); + message.WriteExtensionMembers(ref writer); + } + + private static void WriteStringArray(ref MessagePackWriter writer, IReadOnlyList? array) + { + if (array?.Count > 0) { - if (array?.Count > 0) - { - writer.WriteArrayHeader(array.Count); - foreach (var value in array) - { - writer.Write(value); - } - } - else + writer.WriteArrayHeader(array.Count); + foreach (var value in array) { - writer.WriteArrayHeader(0); + writer.Write(value); } } + else + { + writer.WriteArrayHeader(0); + } + } - private static void WritePayloads(ref MessagePackWriter writer, IDictionary> payloads) + private static void WritePayloads(ref MessagePackWriter writer, IDictionary> payloads) + { + if (payloads?.Count > 0) { - if (payloads?.Count > 0) + writer.WriteMapHeader(payloads.Count); + foreach (var payload in payloads) { - writer.WriteMapHeader(payloads.Count); - foreach (var payload in payloads) - { - writer.Write(payload.Key); + writer.Write(payload.Key); - /*********************************************************************************/ - writer.Write(payload.Value.Span); - /*********************************************************************************/ - // REVIEW : PREVIOUS CODE WAS : - //bool isArray = MemoryMarshal.TryGetArray(payload.Value, out var segment); - //Debug.Assert(isArray, "We're not using managed memory"); + /*********************************************************************************/ + writer.Write(payload.Value.Span); + /*********************************************************************************/ + // REVIEW : PREVIOUS CODE WAS : + //bool isArray = MemoryMarshal.TryGetArray(payload.Value, out var segment); + //Debug.Assert(isArray, "We're not using managed memory"); - // writer.WriteBytes(segment.Array, segment.Offset, segment.Count); - /*********************************************************************************/ - } - } - else - { - writer.WriteMapHeader(0); + // writer.WriteBytes(segment.Array, segment.Offset, segment.Count); + /*********************************************************************************/ } } + else + { + writer.WriteMapHeader(0); + } + } - private static void WriteHeaders(ref MessagePackWriter writer, IDictionary headers) + private static void WriteHeaders(ref MessagePackWriter writer, IDictionary headers) + { + if (headers?.Count > 0) { - if (headers?.Count > 0) + writer.WriteMapHeader(headers.Count); + foreach (var header in headers) { - writer.WriteMapHeader(headers.Count); - foreach (var header in headers) + writer.Write(header.Key); + writer.WriteArrayHeader(header.Value.Count); + foreach (var stringValue in header.Value) { - writer.Write(header.Key); - writer.WriteArrayHeader(header.Value.Count); - foreach (var stringValue in header.Value) - { - writer.Write(stringValue); - } + writer.Write(stringValue); } } - else - { - writer.WriteMapHeader(0); - } } - - private static AccessKeyRequestMessage CreateAccessKeyRequestMessage(ref MessagePackReader reader, int arrayLength) + else { - var message = new AccessKeyRequestMessage() - { - Token = ReadString(ref reader, "token"), - Kid = ReadString(ref reader, "kid"), - }; - message.ReadExtensionMembers(ref reader); - return message; + writer.WriteMapHeader(0); } + } - private static AccessKeyResponseMessage CreateAccessKeyResponseMessage(ref MessagePackReader reader, int arrayLength) + private static AccessKeyRequestMessage CreateAccessKeyRequestMessage(ref MessagePackReader reader, int arrayLength) + { + var message = new AccessKeyRequestMessage() { - var message = new AccessKeyResponseMessage() - { - Kid = ReadString(ref reader, "kid"), - AccessKey = ReadString(ref reader, "accessKey"), - ErrorType = ReadString(ref reader, "errorType"), - ErrorMessage = ReadString(ref reader, "errorMessage"), - }; - message.ReadExtensionMembers(ref reader); - return message; - } + Token = ReadString(ref reader, "token"), + Kid = ReadString(ref reader, "kid"), + }; + message.ReadExtensionMembers(ref reader); + return message; + } - private static HandshakeRequestMessage CreateHandshakeRequestMessage(ref MessagePackReader reader, int arrayLength) - { - var version = ReadInt32(ref reader, "version"); - var result = new HandshakeRequestMessage(version); - if (arrayLength >= 4) - { - result.ConnectionType = ReadInt32(ref reader, "connectionType"); - result.Target = ReadString(ref reader, "target"); - } - result.MigrationLevel = arrayLength >= 5 ? ReadInt32(ref reader, "migratableStatus") : 0; - if (arrayLength >= 6) - { - result.ReadExtensionMembers(ref reader); - } - if (arrayLength >= 7) - { - result.AllowStatefulReconnects = ReadBoolean(ref reader, "enableStatefulReconnects"); - } - return result; - } + private static AccessKeyResponseMessage CreateAccessKeyResponseMessage(ref MessagePackReader reader, int arrayLength) + { + var message = new AccessKeyResponseMessage() + { + Kid = ReadString(ref reader, "kid"), + AccessKey = ReadString(ref reader, "accessKey"), + ErrorType = ReadString(ref reader, "errorType"), + ErrorMessage = ReadString(ref reader, "errorMessage"), + }; + message.ReadExtensionMembers(ref reader); + return message; + } - private static HandshakeResponseMessage CreateHandshakeResponseMessage(ref MessagePackReader reader, int arrayLength) + private static HandshakeRequestMessage CreateHandshakeRequestMessage(ref MessagePackReader reader, int arrayLength) + { + var version = ReadInt32(ref reader, "version"); + var result = new HandshakeRequestMessage(version); + if (arrayLength >= 4) { - var errorMessage = ReadString(ref reader, "errorMessage"); - var result = new HandshakeResponseMessage(errorMessage); - if (arrayLength >= 3) - { - result.ReadExtensionMembers(ref reader); - } - if (arrayLength >= 4) - { - result.ConnectionId = ReadString(ref reader, "connectionId"); - } - return result; + result.ConnectionType = ReadInt32(ref reader, "connectionType"); + result.Target = ReadString(ref reader, "target"); } - - private static PingMessage CreatePingMessage(ref MessagePackReader reader, int arrayLength) + result.MigrationLevel = arrayLength >= 5 ? ReadInt32(ref reader, "migratableStatus") : 0; + if (arrayLength >= 6) { - if (arrayLength > 1) - { - var length = arrayLength - 1; - var values = new string[length]; - for (int i = 0; i < length; i++) - { - values[i] = ReadString(ref reader, "messages[{0}]", i); - } - - return new PingMessage { Messages = values }; - } - return PingMessage.Instance; + result.ReadExtensionMembers(ref reader); } - - private static OpenConnectionMessage CreateOpenConnectionMessage(ref MessagePackReader reader, int arrayLength) + if (arrayLength >= 7) { - var connectionId = ReadString(ref reader, "connectionId"); - var claims = ReadClaims(ref reader); - - // Backward compatible with old versions - if (arrayLength >= 5) - { - var headers = ReadHeaders(ref reader); - var queryString = ReadString(ref reader, "queryString"); - var result = new OpenConnectionMessage(connectionId, claims, headers, queryString); - if (arrayLength >= 6) - { - result.ReadExtensionMembers(ref reader); - } - return result; - } - else - { - return new OpenConnectionMessage(connectionId, claims); - } + result.AllowStatefulReconnects = ReadBoolean(ref reader, "enableStatefulReconnects"); } + return result; + } - private static CloseConnectionMessage CreateCloseConnectionMessage(ref MessagePackReader reader, int arrayLength) + private static HandshakeResponseMessage CreateHandshakeResponseMessage(ref MessagePackReader reader, int arrayLength) + { + var errorMessage = ReadString(ref reader, "errorMessage"); + var result = new HandshakeResponseMessage(errorMessage); + if (arrayLength >= 3) { - var connectionId = ReadString(ref reader, "connectionId"); - var errorMessage = ReadString(ref reader, "errorMessage"); - var headers = arrayLength >= 4 ? ReadHeaders(ref reader) : new Dictionary(); - var result = new CloseConnectionMessage(connectionId, errorMessage, headers); - if (arrayLength >= 5) - { - result.ReadExtensionMembers(ref reader); - } - return result; + result.ReadExtensionMembers(ref reader); } - - [Obsolete] - private static CloseConnectionWithAckMessage CreateCloseConnectionWithAckMessage(ref MessagePackReader reader, int arrayLength) + if (arrayLength >= 4) { - var connectionId = ReadString(ref reader, "connectionId"); - var reason = ReadString(ref reader, "reason"); - var ackId = ReadInt32(ref reader, "ackId"); - var result = new CloseConnectionWithAckMessage(connectionId, ackId) - { - Reason = reason - }; - if (arrayLength >= 5) - { - result.ReadExtensionMembers(ref reader); - } - return result; + result.ConnectionId = ReadString(ref reader, "connectionId"); } + return result; + } - [Obsolete] - private static CloseConnectionsWithAckMessage CreateCloseConnectionsWithAckMessage(ref MessagePackReader reader, int arrayLength) + private static PingMessage CreatePingMessage(ref MessagePackReader reader, int arrayLength) + { + if (arrayLength > 1) { - var reason = ReadString(ref reader, "reason"); - var ackId = ReadInt32(ref reader, "ackId"); - var excluded = ReadStringArray(ref reader, "excluded"); - - var result = new CloseConnectionsWithAckMessage(ackId) - { - Reason = reason, - ExcludedList = excluded - }; - if (arrayLength >= 5) + var length = arrayLength - 1; + var values = new string[length]; + for (int i = 0; i < length; i++) { - result.ReadExtensionMembers(ref reader); + values[i] = ReadString(ref reader, "messages[{0}]", i); } - return result; + + return new PingMessage { Messages = values }; } + return PingMessage.Instance; + } - private static CloseUserConnectionsWithAckMessage CreateCloseUserConnectionsWithAckMessage(ref MessagePackReader reader, int arrayLength) - { - var userId = ReadString(ref reader, "userId"); - var reason = ReadString(ref reader, "reason"); - var ackId = ReadInt32(ref reader, "ackId"); - var excluded = ReadStringArray(ref reader, "excluded"); + private static OpenConnectionMessage CreateOpenConnectionMessage(ref MessagePackReader reader, int arrayLength) + { + var connectionId = ReadString(ref reader, "connectionId"); + var claims = ReadClaims(ref reader); - var result = new CloseUserConnectionsWithAckMessage(userId, ackId) - { - Reason = reason, - ExcludedList = excluded - }; + // Backward compatible with old versions + if (arrayLength >= 5) + { + var headers = ReadHeaders(ref reader); + var queryString = ReadString(ref reader, "queryString"); + var result = new OpenConnectionMessage(connectionId, claims, headers, queryString); if (arrayLength >= 6) { result.ReadExtensionMembers(ref reader); } return result; } - - private static CloseGroupConnectionsWithAckMessage CreateCloseGroupConnectionsWithAckMessage(ref MessagePackReader reader, int arrayLength) + else { - var group = ReadString(ref reader, "group"); - var reason = ReadString(ref reader, "reason"); - var ackId = ReadInt32(ref reader, "ackId"); - var excluded = ReadStringArray(ref reader, "excluded"); - - var result = new CloseGroupConnectionsWithAckMessage(group, ackId) - { - Reason = reason, - ExcludedList = excluded - }; - if (arrayLength >= 6) - { - result.ReadExtensionMembers(ref reader); - } - return result; + return new OpenConnectionMessage(connectionId, claims); } + } - private static ConnectionDataMessage CreateConnectionDataMessage(ref MessagePackReader reader, int arrayLength) + private static CloseConnectionMessage CreateCloseConnectionMessage(ref MessagePackReader reader, int arrayLength) + { + var connectionId = ReadString(ref reader, "connectionId"); + var errorMessage = ReadString(ref reader, "errorMessage"); + var headers = arrayLength >= 4 ? ReadHeaders(ref reader) : new Dictionary(); + var result = new CloseConnectionMessage(connectionId, errorMessage, headers); + if (arrayLength >= 5) { - var connectionId = ReadString(ref reader, "connectionId"); - var payload = ReadBytes(ref reader, "payload"); - - var result = new ConnectionDataMessage(connectionId, payload); - if (arrayLength >= 4) - { - result.ReadExtensionMembers(ref reader); - } - return result; + result.ReadExtensionMembers(ref reader); } + return result; + } - private static ConnectionReconnectMessage CreateConnectionReconnectMessage(ref MessagePackReader reader, int arrayLength) + [Obsolete] + private static CloseConnectionWithAckMessage CreateCloseConnectionWithAckMessage(ref MessagePackReader reader, int arrayLength) + { + var connectionId = ReadString(ref reader, "connectionId"); + var reason = ReadString(ref reader, "reason"); + var ackId = ReadInt32(ref reader, "ackId"); + var result = new CloseConnectionWithAckMessage(connectionId, ackId) + { + Reason = reason + }; + if (arrayLength >= 5) { - var connectionId = ReadString(ref reader, "connectionId"); - - var result = new ConnectionReconnectMessage(connectionId); result.ReadExtensionMembers(ref reader); - return result; } + return result; + } - private static MultiConnectionDataMessage CreateMultiConnectionDataMessage(ref MessagePackReader reader, int arrayLength) - { - var connectionList = ReadStringArray(ref reader, "connectionList"); - var payloads = ReadPayloads(ref reader); + [Obsolete] + private static CloseConnectionsWithAckMessage CreateCloseConnectionsWithAckMessage(ref MessagePackReader reader, int arrayLength) + { + var reason = ReadString(ref reader, "reason"); + var ackId = ReadInt32(ref reader, "ackId"); + var excluded = ReadStringArray(ref reader, "excluded"); - var result = new MultiConnectionDataMessage(connectionList, payloads); - if (arrayLength >= 4) - { - result.ReadExtensionMembers(ref reader); - } - return result; + var result = new CloseConnectionsWithAckMessage(ackId) + { + Reason = reason, + ExcludedList = excluded + }; + if (arrayLength >= 5) + { + result.ReadExtensionMembers(ref reader); } + return result; + } - private static ServiceMessage CreateUserDataMessage(ref MessagePackReader reader, int arrayLength) - { - var userId = ReadString(ref reader, "userId"); - var payloads = ReadPayloads(ref reader); + private static CloseUserConnectionsWithAckMessage CreateCloseUserConnectionsWithAckMessage(ref MessagePackReader reader, int arrayLength) + { + var userId = ReadString(ref reader, "userId"); + var reason = ReadString(ref reader, "reason"); + var ackId = ReadInt32(ref reader, "ackId"); + var excluded = ReadStringArray(ref reader, "excluded"); - var result = new UserDataMessage(userId, payloads); - if (arrayLength >= 4) - { - result.ReadExtensionMembers(ref reader); - } - return result; + var result = new CloseUserConnectionsWithAckMessage(userId, ackId) + { + Reason = reason, + ExcludedList = excluded + }; + if (arrayLength >= 6) + { + result.ReadExtensionMembers(ref reader); } + return result; + } - private static MultiUserDataMessage CreateMultiUserDataMessage(ref MessagePackReader reader, int arrayLength) - { - var userList = ReadStringArray(ref reader, "userList"); - var payloads = ReadPayloads(ref reader); + private static CloseGroupConnectionsWithAckMessage CreateCloseGroupConnectionsWithAckMessage(ref MessagePackReader reader, int arrayLength) + { + var group = ReadString(ref reader, "group"); + var reason = ReadString(ref reader, "reason"); + var ackId = ReadInt32(ref reader, "ackId"); + var excluded = ReadStringArray(ref reader, "excluded"); - var result = new MultiUserDataMessage(userList, payloads); - if (arrayLength >= 4) - { - result.ReadExtensionMembers(ref reader); - } - return result; + var result = new CloseGroupConnectionsWithAckMessage(group, ackId) + { + Reason = reason, + ExcludedList = excluded + }; + if (arrayLength >= 6) + { + result.ReadExtensionMembers(ref reader); } + return result; + } - private static BroadcastDataMessage CreateBroadcastDataMessage(ref MessagePackReader reader, int arrayLength) - { - var excludedList = ReadStringArray(ref reader, "excludedList"); - var payloads = ReadPayloads(ref reader); + private static ConnectionDataMessage CreateConnectionDataMessage(ref MessagePackReader reader, int arrayLength) + { + var connectionId = ReadString(ref reader, "connectionId"); + var payload = ReadBytes(ref reader, "payload"); - var result = new BroadcastDataMessage(excludedList, payloads); - if (arrayLength >= 4) - { - result.ReadExtensionMembers(ref reader); - } - return result; + var result = new ConnectionDataMessage(connectionId, payload); + if (arrayLength >= 4) + { + result.ReadExtensionMembers(ref reader); } + return result; + } + + private static ConnectionReconnectMessage CreateConnectionReconnectMessage(ref MessagePackReader reader, int arrayLength) + { + var connectionId = ReadString(ref reader, "connectionId"); + + var result = new ConnectionReconnectMessage(connectionId); + result.ReadExtensionMembers(ref reader); + return result; + } + + private static MultiConnectionDataMessage CreateMultiConnectionDataMessage(ref MessagePackReader reader, int arrayLength) + { + var connectionList = ReadStringArray(ref reader, "connectionList"); + var payloads = ReadPayloads(ref reader); - private static JoinGroupMessage CreateJoinGroupMessage(ref MessagePackReader reader, int arrayLength) + var result = new MultiConnectionDataMessage(connectionList, payloads); + if (arrayLength >= 4) { - var connectionId = ReadString(ref reader, "connectionId"); - var groupName = ReadString(ref reader, "groupName"); + result.ReadExtensionMembers(ref reader); + } + return result; + } - var result = new JoinGroupMessage(connectionId, groupName); - if (arrayLength >= 4) - { - result.ReadExtensionMembers(ref reader); - } - return result; + private static ServiceMessage CreateUserDataMessage(ref MessagePackReader reader, int arrayLength) + { + var userId = ReadString(ref reader, "userId"); + var payloads = ReadPayloads(ref reader); + + var result = new UserDataMessage(userId, payloads); + if (arrayLength >= 4) + { + result.ReadExtensionMembers(ref reader); } + return result; + } - private static LeaveGroupMessage CreateLeaveGroupMessage(ref MessagePackReader reader, int arrayLength) - { - var connectionId = ReadString(ref reader, "connectionId"); - var groupName = ReadString(ref reader, "groupName"); + private static MultiUserDataMessage CreateMultiUserDataMessage(ref MessagePackReader reader, int arrayLength) + { + var userList = ReadStringArray(ref reader, "userList"); + var payloads = ReadPayloads(ref reader); - var result = new LeaveGroupMessage(connectionId, groupName); - if (arrayLength >= 4) - { - result.ReadExtensionMembers(ref reader); - } - return result; + var result = new MultiUserDataMessage(userList, payloads); + if (arrayLength >= 4) + { + result.ReadExtensionMembers(ref reader); } + return result; + } - private static UserJoinGroupMessage CreateUserJoinGroupMessage(ref MessagePackReader reader, int arrayLength) - { - var userId = ReadString(ref reader, "userId"); - var groupName = ReadString(ref reader, "groupName"); + private static BroadcastDataMessage CreateBroadcastDataMessage(ref MessagePackReader reader, int arrayLength) + { + var excludedList = ReadStringArray(ref reader, "excludedList"); + var payloads = ReadPayloads(ref reader); - var result = new UserJoinGroupMessage(userId, groupName); - if (arrayLength >= 4) - { - result.ReadExtensionMembers(ref reader); - } - return result; + var result = new BroadcastDataMessage(excludedList, payloads); + if (arrayLength >= 4) + { + result.ReadExtensionMembers(ref reader); } + return result; + } - private static UserLeaveGroupMessage CreateUserLeaveGroupMessage(ref MessagePackReader reader, int arrayLength) - { - var userId = ReadString(ref reader, "userId"); - var groupName = ReadString(ref reader, "groupName"); + private static JoinGroupMessage CreateJoinGroupMessage(ref MessagePackReader reader, int arrayLength) + { + var connectionId = ReadString(ref reader, "connectionId"); + var groupName = ReadString(ref reader, "groupName"); - var result = new UserLeaveGroupMessage(userId, groupName); - if (arrayLength >= 4) - { - result.ReadExtensionMembers(ref reader); - } - return result; + var result = new JoinGroupMessage(connectionId, groupName); + if (arrayLength >= 4) + { + result.ReadExtensionMembers(ref reader); } + return result; + } - private static UserJoinGroupWithAckMessage CreateUserJoinGroupWithAckMessage(ref MessagePackReader reader, int arrayLength) - { - var userId = ReadString(ref reader, "userId"); - var groupName = ReadString(ref reader, "groupName"); - var ackId = ReadInt32(ref reader, "ackId"); + private static LeaveGroupMessage CreateLeaveGroupMessage(ref MessagePackReader reader, int arrayLength) + { + var connectionId = ReadString(ref reader, "connectionId"); + var groupName = ReadString(ref reader, "groupName"); - var result = new UserJoinGroupWithAckMessage(userId, groupName, ackId); + var result = new LeaveGroupMessage(connectionId, groupName); + if (arrayLength >= 4) + { result.ReadExtensionMembers(ref reader); - return result; } + return result; + } - private static UserLeaveGroupWithAckMessage CreateUserLeaveGroupWithAckMessage(ref MessagePackReader reader, int arrayLength) - { - var userId = ReadString(ref reader, "userId"); - var groupName = ReadString(ref reader, "groupName"); - var ackId = ReadInt32(ref reader, "ackId"); + private static UserJoinGroupMessage CreateUserJoinGroupMessage(ref MessagePackReader reader, int arrayLength) + { + var userId = ReadString(ref reader, "userId"); + var groupName = ReadString(ref reader, "groupName"); - var result = new UserLeaveGroupWithAckMessage(userId, groupName, ackId); + var result = new UserJoinGroupMessage(userId, groupName); + if (arrayLength >= 4) + { result.ReadExtensionMembers(ref reader); - return result; } + return result; + } + + private static UserLeaveGroupMessage CreateUserLeaveGroupMessage(ref MessagePackReader reader, int arrayLength) + { + var userId = ReadString(ref reader, "userId"); + var groupName = ReadString(ref reader, "groupName"); - private static GroupBroadcastDataMessage CreateGroupBroadcastDataMessage(ref MessagePackReader reader, int arrayLength) + var result = new UserLeaveGroupMessage(userId, groupName); + if (arrayLength >= 4) { - var groupName = ReadString(ref reader, "groupName"); - var excludedList = ReadStringArray(ref reader, "excludedList"); - var payloads = ReadPayloads(ref reader); + result.ReadExtensionMembers(ref reader); + } + return result; + } - var result = new GroupBroadcastDataMessage(groupName, excludedList, payloads); - if (arrayLength >= 5) - { - result.ReadExtensionMembers(ref reader); - } + private static UserJoinGroupWithAckMessage CreateUserJoinGroupWithAckMessage(ref MessagePackReader reader, int arrayLength) + { + var userId = ReadString(ref reader, "userId"); + var groupName = ReadString(ref reader, "groupName"); + var ackId = ReadInt32(ref reader, "ackId"); - if (arrayLength >= 7) - { - result.ExcludedUserList = ReadStringArray(ref reader, "excludedUserList"); - result.CallerUserId = ReadString(ref reader, "callerUserId"); - } + var result = new UserJoinGroupWithAckMessage(userId, groupName, ackId); + result.ReadExtensionMembers(ref reader); + return result; + } - return result; - } + private static UserLeaveGroupWithAckMessage CreateUserLeaveGroupWithAckMessage(ref MessagePackReader reader, int arrayLength) + { + var userId = ReadString(ref reader, "userId"); + var groupName = ReadString(ref reader, "groupName"); + var ackId = ReadInt32(ref reader, "ackId"); - private static MultiGroupBroadcastDataMessage CreateMultiGroupBroadcastDataMessage(ref MessagePackReader reader, int arrayLength) - { - var groupList = ReadStringArray(ref reader, "groupList"); - var payloads = ReadPayloads(ref reader); + var result = new UserLeaveGroupWithAckMessage(userId, groupName, ackId); + result.ReadExtensionMembers(ref reader); + return result; + } - var result = new MultiGroupBroadcastDataMessage(groupList, payloads); - if (arrayLength >= 4) - { - result.ReadExtensionMembers(ref reader); - } - return result; - } + private static GroupBroadcastDataMessage CreateGroupBroadcastDataMessage(ref MessagePackReader reader, int arrayLength) + { + var groupName = ReadString(ref reader, "groupName"); + var excludedList = ReadStringArray(ref reader, "excludedList"); + var payloads = ReadPayloads(ref reader); - private static ServiceErrorMessage CreateServiceErrorMessage(ref MessagePackReader reader) + var result = new GroupBroadcastDataMessage(groupName, excludedList, payloads); + if (arrayLength >= 5) { - var errorMessage = ReadString(ref reader, "errorMessage"); - - return new ServiceErrorMessage(errorMessage); + result.ReadExtensionMembers(ref reader); } - private static ServiceEventMessage CreateServiceEventMessage(ref MessagePackReader reader) + if (arrayLength >= 7) { - var type = ReadInt32(ref reader, "type"); - var id = ReadString(ref reader, "id"); - var kind = ReadInt32(ref reader, "kind"); - var message = ReadString(ref reader, "message"); - var result = new ServiceEventMessage((ServiceEventObjectType)type, id, (ServiceEventKind)kind, message); - result.ReadExtensionMembers(ref reader); - return result; + result.ExcludedUserList = ReadStringArray(ref reader, "excludedUserList"); + result.CallerUserId = ReadString(ref reader, "callerUserId"); } - private static JoinGroupWithAckMessage CreateJoinGroupWithAckMessage(ref MessagePackReader reader, int arrayLength) - { - var connectionId = ReadString(ref reader, "connectionId"); - var groupName = ReadString(ref reader, "groupName"); - var ackId = ReadInt32(ref reader, "ackId"); + return result; + } - var result = new JoinGroupWithAckMessage(connectionId, groupName, ackId); - if (arrayLength >= 5) - { - result.ReadExtensionMembers(ref reader); - } - return result; - } + private static MultiGroupBroadcastDataMessage CreateMultiGroupBroadcastDataMessage(ref MessagePackReader reader, int arrayLength) + { + var groupList = ReadStringArray(ref reader, "groupList"); + var payloads = ReadPayloads(ref reader); - private static LeaveGroupWithAckMessage CreateLeaveGroupWithAckMessage(ref MessagePackReader reader, int arrayLength) + var result = new MultiGroupBroadcastDataMessage(groupList, payloads); + if (arrayLength >= 4) { - var connectionId = ReadString(ref reader, "connectionId"); - var groupName = ReadString(ref reader, "groupName"); - var ackId = ReadInt32(ref reader, "ackId"); - - var result = new LeaveGroupWithAckMessage(connectionId, groupName, ackId); - if (arrayLength >= 5) - { - result.ReadExtensionMembers(ref reader); - } - return result; + result.ReadExtensionMembers(ref reader); } + return result; + } - private static CheckUserInGroupWithAckMessage CreateCheckUserInGroupWithAckMessage(ref MessagePackReader reader, int arrayLength) - { - var userId = ReadString(ref reader, "userId"); - var groupName = ReadString(ref reader, "groupName"); - var ackId = ReadInt32(ref reader, "ackId"); + private static ServiceErrorMessage CreateServiceErrorMessage(ref MessagePackReader reader) + { + var errorMessage = ReadString(ref reader, "errorMessage"); - var result = new CheckUserInGroupWithAckMessage(userId, groupName, ackId); - if (arrayLength >= 5) - { - result.ReadExtensionMembers(ref reader); - } - return result; - } + return new ServiceErrorMessage(errorMessage); + } - private static CheckGroupExistenceWithAckMessage CreateGroupExistenceWithAckMessage(ref MessagePackReader reader, int arrayLength) - { - var groupName = ReadString(ref reader, "groupName"); - var ackId = ReadInt32(ref reader, "ackId"); + private static ServiceEventMessage CreateServiceEventMessage(ref MessagePackReader reader) + { + var type = ReadInt32(ref reader, "type"); + var id = ReadString(ref reader, "id"); + var kind = ReadInt32(ref reader, "kind"); + var message = ReadString(ref reader, "message"); + var result = new ServiceEventMessage((ServiceEventObjectType)type, id, (ServiceEventKind)kind, message); + result.ReadExtensionMembers(ref reader); + return result; + } + + private static JoinGroupWithAckMessage CreateJoinGroupWithAckMessage(ref MessagePackReader reader, int arrayLength) + { + var connectionId = ReadString(ref reader, "connectionId"); + var groupName = ReadString(ref reader, "groupName"); + var ackId = ReadInt32(ref reader, "ackId"); - var result = new CheckGroupExistenceWithAckMessage(groupName, ackId); + var result = new JoinGroupWithAckMessage(connectionId, groupName, ackId); + if (arrayLength >= 5) + { result.ReadExtensionMembers(ref reader); - return result; } + return result; + } - private static CheckConnectionExistenceWithAckMessage CreateCheckConnectionExistenceWithAckMessage(ref MessagePackReader reader, int arrayLength) - { - var connectionId = ReadString(ref reader, "connectionId"); - var ackId = ReadInt32(ref reader, "ackId"); + private static LeaveGroupWithAckMessage CreateLeaveGroupWithAckMessage(ref MessagePackReader reader, int arrayLength) + { + var connectionId = ReadString(ref reader, "connectionId"); + var groupName = ReadString(ref reader, "groupName"); + var ackId = ReadInt32(ref reader, "ackId"); - var result = new CheckConnectionExistenceWithAckMessage(connectionId, ackId); + var result = new LeaveGroupWithAckMessage(connectionId, groupName, ackId); + if (arrayLength >= 5) + { result.ReadExtensionMembers(ref reader); - return result; } + return result; + } - private static CheckUserExistenceWithAckMessage CreateCheckUserExistenceWithAckMessage(ref MessagePackReader reader, int arrayLength) - { - var userId = ReadString(ref reader, "userId"); - var ackId = ReadInt32(ref reader, "ackId"); + private static CheckUserInGroupWithAckMessage CreateCheckUserInGroupWithAckMessage(ref MessagePackReader reader, int arrayLength) + { + var userId = ReadString(ref reader, "userId"); + var groupName = ReadString(ref reader, "groupName"); + var ackId = ReadInt32(ref reader, "ackId"); - var result = new CheckUserExistenceWithAckMessage(userId, ackId); + var result = new CheckUserInGroupWithAckMessage(userId, groupName, ackId); + if (arrayLength >= 5) + { result.ReadExtensionMembers(ref reader); - return result; } + return result; + } - private static AckMessage CreateAckMessage(ref MessagePackReader reader, int arrayLength) - { - var ackId = ReadInt32(ref reader, "ackId"); - var status = ReadInt32(ref reader, "status"); - var message = ReadString(ref reader, "message"); + private static CheckGroupExistenceWithAckMessage CreateGroupExistenceWithAckMessage(ref MessagePackReader reader, int arrayLength) + { + var groupName = ReadString(ref reader, "groupName"); + var ackId = ReadInt32(ref reader, "ackId"); - var result = new AckMessage(ackId, status, message); - if (arrayLength >= 5) - { - result.ReadExtensionMembers(ref reader); - } - return result; - } + var result = new CheckGroupExistenceWithAckMessage(groupName, ackId); + result.ReadExtensionMembers(ref reader); + return result; + } - private static ClientInvocationMessage CreateClientInvocationMessage(ref MessagePackReader reader, int arrayLength) - { - var invocationId = ReadString(ref reader, "invocationId"); - var connectionId = ReadString(ref reader, "connectionId"); - var callerServerId = ReadString(ref reader, "callerServerId"); - var payloads = ReadPayloads(ref reader); + private static CheckConnectionExistenceWithAckMessage CreateCheckConnectionExistenceWithAckMessage(ref MessagePackReader reader, int arrayLength) + { + var connectionId = ReadString(ref reader, "connectionId"); + var ackId = ReadInt32(ref reader, "ackId"); - var result = new ClientInvocationMessage(invocationId, connectionId, callerServerId, payloads); - result.ReadExtensionMembers(ref reader); - return result; - } + var result = new CheckConnectionExistenceWithAckMessage(connectionId, ackId); + result.ReadExtensionMembers(ref reader); + return result; + } - private static ClientCompletionMessage CreateClientCompletionMessage(ref MessagePackReader reader, int arrayLength) - { - var invocationId = ReadString(ref reader, "invocationId"); - var connectionId = ReadString(ref reader, "connectionId"); - var callerServerId = ReadString(ref reader, "callerServerId"); - var protocol = ReadString(ref reader, "protocol"); - var payload = ReadBytes(ref reader, "payload"); + private static CheckUserExistenceWithAckMessage CreateCheckUserExistenceWithAckMessage(ref MessagePackReader reader, int arrayLength) + { + var userId = ReadString(ref reader, "userId"); + var ackId = ReadInt32(ref reader, "ackId"); + + var result = new CheckUserExistenceWithAckMessage(userId, ackId); + result.ReadExtensionMembers(ref reader); + return result; + } - var result = new ClientCompletionMessage(invocationId, connectionId, callerServerId, protocol, payload); + private static AckMessage CreateAckMessage(ref MessagePackReader reader, int arrayLength) + { + var ackId = ReadInt32(ref reader, "ackId"); + var status = ReadInt32(ref reader, "status"); + var message = ReadString(ref reader, "message"); + var result = new AckMessage(ackId, status, message); + if (arrayLength >= 5) + { result.ReadExtensionMembers(ref reader); - return result; } + return result; + } - private static ErrorCompletionMessage CreateErrorCompletionMessage(ref MessagePackReader reader, int arrayLength) - { - var invocationId = ReadString(ref reader, "invocationId"); - var connectionId = ReadString(ref reader, "connectionId"); - var callerServerId = ReadString(ref reader, "callerServerId"); - var error = ReadString(ref reader, "error"); + private static ClientInvocationMessage CreateClientInvocationMessage(ref MessagePackReader reader, int arrayLength) + { + var invocationId = ReadString(ref reader, "invocationId"); + var connectionId = ReadString(ref reader, "connectionId"); + var callerServerId = ReadString(ref reader, "callerServerId"); + var payloads = ReadPayloads(ref reader); + + var result = new ClientInvocationMessage(invocationId, connectionId, callerServerId, payloads); + result.ReadExtensionMembers(ref reader); + return result; + } - var result = new ErrorCompletionMessage(invocationId, connectionId, callerServerId, error); + private static ClientCompletionMessage CreateClientCompletionMessage(ref MessagePackReader reader, int arrayLength) + { + var invocationId = ReadString(ref reader, "invocationId"); + var connectionId = ReadString(ref reader, "connectionId"); + var callerServerId = ReadString(ref reader, "callerServerId"); + var protocol = ReadString(ref reader, "protocol"); + var payload = ReadBytes(ref reader, "payload"); - result.ReadExtensionMembers(ref reader); - return result; - } + var result = new ClientCompletionMessage(invocationId, connectionId, callerServerId, protocol, payload); - private static ServiceMappingMessage CreateServiceMappingMessage(ref MessagePackReader reader, int arrayLength) - { - var invocationId = ReadString(ref reader, "invocationId"); - var connectionId = ReadString(ref reader, "connectionId"); - var instanceId = ReadString(ref reader, "instanceId"); + result.ReadExtensionMembers(ref reader); + return result; + } - var result = new ServiceMappingMessage(invocationId, connectionId, instanceId); + private static ErrorCompletionMessage CreateErrorCompletionMessage(ref MessagePackReader reader, int arrayLength) + { + var invocationId = ReadString(ref reader, "invocationId"); + var connectionId = ReadString(ref reader, "connectionId"); + var callerServerId = ReadString(ref reader, "callerServerId"); + var error = ReadString(ref reader, "error"); - result.ReadExtensionMembers(ref reader); - return result; - } + var result = new ErrorCompletionMessage(invocationId, connectionId, callerServerId, error); - private static ConnectionFlowControlMessage CreateConnectionFlowControlMessage(ref MessagePackReader reader, int arrayLength) - { - var connectionId = ReadString(ref reader, "connectionId"); - var connectionType = ReadInt32(ref reader, "connectionType"); - var operation = ReadInt32(ref reader, "operation"); + result.ReadExtensionMembers(ref reader); + return result; + } - switch (connectionType) - { - case (int)ConnectionType.Client: - case (int)ConnectionType.Server: - break; - default: - throw new InvalidDataException($"Unsupported connection type: {connectionType}"); - } + private static ServiceMappingMessage CreateServiceMappingMessage(ref MessagePackReader reader, int arrayLength) + { + var invocationId = ReadString(ref reader, "invocationId"); + var connectionId = ReadString(ref reader, "connectionId"); + var instanceId = ReadString(ref reader, "instanceId"); - switch (operation) - { - case (int)ConnectionFlowControlOperation.Pause: - case (int)ConnectionFlowControlOperation.PauseAck: - case (int)ConnectionFlowControlOperation.Resume: - case (int)ConnectionFlowControlOperation.Offline: - break; - default: - throw new InvalidDataException($"Unsupported operation: {operation}"); - } + var result = new ServiceMappingMessage(invocationId, connectionId, instanceId); - var result = new ConnectionFlowControlMessage( - connectionId, - (ConnectionFlowControlOperation)operation, - (ConnectionType)connectionType); - return result; + result.ReadExtensionMembers(ref reader); + return result; + } + + private static ConnectionFlowControlMessage CreateConnectionFlowControlMessage(ref MessagePackReader reader, int arrayLength) + { + var connectionId = ReadString(ref reader, "connectionId"); + var connectionType = ReadInt32(ref reader, "connectionType"); + var operation = ReadInt32(ref reader, "operation"); + + switch (connectionType) + { + case (int)ConnectionType.Client: + case (int)ConnectionType.Server: + break; + default: + throw new InvalidDataException($"Unsupported connection type: {connectionType}"); } - private static Claim[] ReadClaims(ref MessagePackReader reader) + switch (operation) { - var claimCount = ReadMapLength(ref reader, "claims"); - if (claimCount > 0) - { - var claims = new Claim[claimCount]; + case (int)ConnectionFlowControlOperation.Pause: + case (int)ConnectionFlowControlOperation.PauseAck: + case (int)ConnectionFlowControlOperation.Resume: + case (int)ConnectionFlowControlOperation.Offline: + break; + default: + throw new InvalidDataException($"Unsupported operation: {operation}"); + } - for (var i = 0; i < claimCount; i++) - { - var type = ReadString(ref reader, "claims[{0}].Type", i); - var value = ReadString(ref reader, "claims[{0}].Value", i); - claims[i] = new Claim(type, value); - } + var result = new ConnectionFlowControlMessage( + connectionId, + (ConnectionFlowControlOperation)operation, + (ConnectionType)connectionType); + return result; + } + + private static Claim[] ReadClaims(ref MessagePackReader reader) + { + var claimCount = ReadMapLength(ref reader, "claims"); + if (claimCount > 0) + { + var claims = new Claim[claimCount]; - return claims; + for (var i = 0; i < claimCount; i++) + { + var type = ReadString(ref reader, "claims[{0}].Type", i); + var value = ReadString(ref reader, "claims[{0}].Value", i); + claims[i] = new Claim(type, value); } - return []; + return claims; } - private static IDictionary> ReadPayloads(ref MessagePackReader reader) + return []; + } + + private static IDictionary> ReadPayloads(ref MessagePackReader reader) + { + var payloadCount = ReadMapLength(ref reader, "payloads"); + if (payloadCount > 0) { - var payloadCount = ReadMapLength(ref reader, "payloads"); - if (payloadCount > 0) + var payloads = new ArrayDictionary>((int)payloadCount, StringComparer.OrdinalIgnoreCase); + for (var i = 0; i < payloadCount; i++) { - var payloads = new ArrayDictionary>((int)payloadCount, StringComparer.OrdinalIgnoreCase); - for (var i = 0; i < payloadCount; i++) - { - var key = ReadString(ref reader, "payloads[{0}].key", i); - var value = ReadBytes(ref reader, "payloads[{0}].value", i); - payloads.Add(key, value); - } - - return payloads; + var key = ReadString(ref reader, "payloads[{0}].key", i); + var value = ReadBytes(ref reader, "payloads[{0}].value", i); + payloads.Add(key, value); } - return EmptyReadOnlyMemoryDictionary; + return payloads; } - private static IDictionary ReadHeaders(ref MessagePackReader reader) + return EmptyReadOnlyMemoryDictionary; + } + + private static IDictionary ReadHeaders(ref MessagePackReader reader) + { + var headerCount = ReadMapLength(ref reader, "headers"); + if (headerCount > 0) { - var headerCount = ReadMapLength(ref reader, "headers"); - if (headerCount > 0) + var headers = new Dictionary((int)headerCount, StringComparer.OrdinalIgnoreCase); + for (var i = 0; i < headerCount; i++) { - var headers = new Dictionary((int)headerCount, StringComparer.OrdinalIgnoreCase); - for (var i = 0; i < headerCount; i++) + var key = ReadString(ref reader, $"headers[{i}].key"); + var count = ReadArrayLength(ref reader, $"headers[{i}].value.length"); + var stringValues = new string[count]; + for (var j = 0; j < count; j++) { - var key = ReadString(ref reader, $"headers[{i}].key"); - var count = ReadArrayLength(ref reader, $"headers[{i}].value.length"); - var stringValues = new string[count]; - for (var j = 0; j < count; j++) - { - stringValues[j] = ReadString(ref reader, $"headers[{i}].value[{j}]"); - } - headers.Add(key, stringValues); + stringValues[j] = ReadString(ref reader, $"headers[{i}].value[{j}]"); } - - return headers; + headers.Add(key, stringValues); } - return EmptyStringValuesDictionaryIgnoreCase; + return headers; } - private static bool ReadBoolean(ref MessagePackReader reader, string field) + return EmptyStringValuesDictionaryIgnoreCase; + } + + private static bool ReadBoolean(ref MessagePackReader reader, string field) + { + try { - try - { - return reader.ReadBoolean(); - } - catch (Exception ex) - { - throw new InvalidDataException($"Reading '{field}' as Boolean failed.", ex); - } + return reader.ReadBoolean(); } - - private static int ReadInt32(ref MessagePackReader reader, string field) + catch (Exception ex) { - try - { - return reader.ReadInt32(); - } - catch (Exception ex) - { - throw new InvalidDataException($"Reading '{field}' as Int32 failed.", ex); - } + throw new InvalidDataException($"Reading '{field}' as Boolean failed.", ex); + } + } + private static int ReadInt32(ref MessagePackReader reader, string field) + { + try + { + return reader.ReadInt32(); + } + catch (Exception ex) + { + throw new InvalidDataException($"Reading '{field}' as Int32 failed.", ex); } - private static string ReadString(ref MessagePackReader reader, string field) + } + + private static string ReadString(ref MessagePackReader reader, string field) + { + try { - try - { - return reader.ReadString(); - } - catch (Exception ex) - { - throw new InvalidDataException($"Reading '{field}' as String failed.", ex); - } + return reader.ReadString(); + } + catch (Exception ex) + { + throw new InvalidDataException($"Reading '{field}' as String failed.", ex); } + } - private static string ReadString(ref MessagePackReader reader, string formatField, int param) + private static string ReadString(ref MessagePackReader reader, string formatField, int param) + { + try { - try - { - return reader.ReadString(); - } - catch (Exception ex) - { - throw new InvalidDataException($"Reading '{string.Format(formatField, param)}' as String failed.", ex); - } + return reader.ReadString(); + } + catch (Exception ex) + { + throw new InvalidDataException($"Reading '{string.Format(formatField, param)}' as String failed.", ex); } + } - private static string ReadString(ref MessagePackReader reader, string formatField, string param1, int param2) + private static string ReadString(ref MessagePackReader reader, string formatField, string param1, int param2) + { + try { - try - { - return reader.ReadString(); - } - catch (Exception ex) - { - throw new InvalidDataException($"Reading '{string.Format(formatField, param1, param2)}' as String failed.", ex); - } + return reader.ReadString(); + } + catch (Exception ex) + { + throw new InvalidDataException($"Reading '{string.Format(formatField, param1, param2)}' as String failed.", ex); } + } - private static string[] ReadStringArray(ref MessagePackReader reader, string field) + private static string[] ReadStringArray(ref MessagePackReader reader, string field) + { + var arrayLength = ReadArrayLength(ref reader, field); + if (arrayLength > 0) { - var arrayLength = ReadArrayLength(ref reader, field); - if (arrayLength > 0) + var array = new string[arrayLength]; + for (int i = 0; i < arrayLength; i++) { - var array = new string[arrayLength]; - for (int i = 0; i < arrayLength; i++) - { - array[i] = ReadString(ref reader, "{0}[{1}]", field, i); - } - - return array; + array[i] = ReadString(ref reader, "{0}[{1}]", field, i); } - return []; + return array; } - private static byte[] ReadBytes(ref MessagePackReader reader, string field) + return []; + } + + private static byte[] ReadBytes(ref MessagePackReader reader, string field) + { + try { - try - { - return reader.ReadBytes()?.ToArray() ?? Array.Empty(); - } - catch (Exception ex) - { - throw new InvalidDataException($"Reading '{field}' as Byte[] failed.", ex); - } + return reader.ReadBytes()?.ToArray() ?? Array.Empty(); } - - private static byte[] ReadBytes(ref MessagePackReader reader, string formatField, int param) + catch (Exception ex) { - try - { - return reader.ReadBytes()?.ToArray() ?? Array.Empty(); - } - catch (Exception ex) - { - throw new InvalidDataException($"Reading '{string.Format(formatField, param)}' as Byte[] failed.", ex); - } + throw new InvalidDataException($"Reading '{field}' as Byte[] failed.", ex); } + } - private static long ReadMapLength(ref MessagePackReader reader, string field) + private static byte[] ReadBytes(ref MessagePackReader reader, string formatField, int param) + { + try { - try - { - return reader.ReadMapHeader(); - } - catch (Exception ex) - { - throw new InvalidDataException($"Reading map length for '{field}' failed.", ex); - } + return reader.ReadBytes()?.ToArray() ?? Array.Empty(); + } + catch (Exception ex) + { + throw new InvalidDataException($"Reading '{string.Format(formatField, param)}' as Byte[] failed.", ex); } + } - private static long ReadArrayLength(ref MessagePackReader reader, string field) + private static long ReadMapLength(ref MessagePackReader reader, string field) + { + try { - try - { - return reader.ReadArrayHeader(); - } - catch (Exception ex) - { - throw new InvalidDataException($"Reading array length for '{field}' failed.", ex); - } + return reader.ReadMapHeader(); + } + catch (Exception ex) + { + throw new InvalidDataException($"Reading map length for '{field}' failed.", ex); + } + } + private static long ReadArrayLength(ref MessagePackReader reader, string field) + { + try + { + return reader.ReadArrayHeader(); + } + catch (Exception ex) + { + throw new InvalidDataException($"Reading array length for '{field}' failed.", ex); } + } } \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR/EndpointProvider/DefaultServiceEndpointGenerator.cs b/src/Microsoft.Azure.SignalR/EndpointProvider/DefaultServiceEndpointGenerator.cs index 07d75055d..11d9a5214 100644 --- a/src/Microsoft.Azure.SignalR/EndpointProvider/DefaultServiceEndpointGenerator.cs +++ b/src/Microsoft.Azure.SignalR/EndpointProvider/DefaultServiceEndpointGenerator.cs @@ -4,63 +4,67 @@ using System.Net; using System.Text; -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +#nullable enable + +internal sealed class DefaultServiceEndpointGenerator : IServiceEndpointGenerator { - internal sealed class DefaultServiceEndpointGenerator : IServiceEndpointGenerator - { - private const string ClientPath = "client"; - private const string ServerPath = "server"; + private const string ClientPath = "client"; - public string Version { get; } + private const string ServerPath = "server"; - public string AudienceBaseUrl { get; } - public string ClientEndpoint { get; } - public string ServerEndpoint { get; } + public string? Version { get; } - public DefaultServiceEndpointGenerator(ServiceEndpoint endpoint) - { - Version = endpoint.Version; - AudienceBaseUrl = endpoint.AudienceBaseUrl; - ClientEndpoint = endpoint.ClientEndpoint.AbsoluteUri; - ServerEndpoint = endpoint.ServerEndpoint.AbsoluteUri; - } + public string AudienceBaseUrl { get; } - public string GetClientAudience(string hubName, string applicationName) => - InternalGetUri(ClientPath, hubName, applicationName, AudienceBaseUrl); + public string ClientEndpoint { get; } - public string GetClientEndpoint(string hubName, string applicationName, string originalPath, string queryString) - { - var queryBuilder = new StringBuilder(); - if (!string.IsNullOrEmpty(originalPath)) - { - queryBuilder.Append("&") - .Append(Constants.QueryParameter.OriginalPath) - .Append("=") - .Append(WebUtility.UrlEncode(originalPath)); - } - - if (!string.IsNullOrEmpty(queryString)) - { - queryBuilder.Append("&").Append(queryString); - } - - return $"{InternalGetUri(ClientPath, hubName, applicationName, ClientEndpoint)}{queryBuilder}"; - } + public string ServerEndpoint { get; } - public string GetServerAudience(string hubName, string applicationName) => - InternalGetUri(ServerPath, hubName, applicationName, AudienceBaseUrl); + public DefaultServiceEndpointGenerator(ServiceEndpoint endpoint) + { + Version = endpoint.Version; + AudienceBaseUrl = endpoint.AudienceBaseUrl; + ClientEndpoint = endpoint.ClientEndpoint.AbsoluteUri; + ServerEndpoint = endpoint.ServerEndpoint.AbsoluteUri; + } - public string GetServerEndpoint(string hubName, string applicationName) => - InternalGetUri(ServerPath, hubName, applicationName, ServerEndpoint); + public string GetClientAudience(string hubName, string applicationName) => + InternalGetUri(ClientPath, hubName, applicationName, AudienceBaseUrl); - private string GetPrefixedHubName(string applicationName, string hubName) + public string GetClientEndpoint(string hubName, string applicationName, string originalPath, string queryString) + { + var queryBuilder = new StringBuilder(); + if (!string.IsNullOrEmpty(originalPath)) { - return string.IsNullOrEmpty(applicationName) ? hubName.ToLower() : $"{applicationName.ToLower()}_{hubName.ToLower()}"; + queryBuilder.Append('&') + .Append(Constants.QueryParameter.OriginalPath) + .Append('=') + .Append(WebUtility.UrlEncode(originalPath)); } - private string InternalGetUri(string path, string hubName, string applicationName, string target) + if (!string.IsNullOrEmpty(queryString)) { - return $"{target}{path}/?hub={GetPrefixedHubName(applicationName, hubName)}"; + queryBuilder.Append('&').Append(queryString); } + + return $"{InternalGetUri(ClientPath, hubName, applicationName, ClientEndpoint)}{queryBuilder}"; + } + + public string GetServerAudience(string hubName, string applicationName) => + InternalGetUri(ServerPath, hubName, applicationName, AudienceBaseUrl); + + public string GetServerEndpoint(string hubName, string applicationName) => + InternalGetUri(ServerPath, hubName, applicationName, ServerEndpoint); + + private static string GetPrefixedHubName(string applicationName, string hubName) + { + return string.IsNullOrEmpty(applicationName) ? hubName.ToLower() : $"{applicationName.ToLower()}_{hubName.ToLower()}"; + } + + private static string InternalGetUri(string path, string hubName, string applicationName, string target) + { + return $"{target}{path}/?hub={GetPrefixedHubName(applicationName, hubName)}"; } } diff --git a/src/Microsoft.Azure.SignalR/EndpointProvider/IServiceEndpointGenerator.cs b/src/Microsoft.Azure.SignalR/EndpointProvider/IServiceEndpointGenerator.cs index ddb775383..6aa9b68a8 100644 --- a/src/Microsoft.Azure.SignalR/EndpointProvider/IServiceEndpointGenerator.cs +++ b/src/Microsoft.Azure.SignalR/EndpointProvider/IServiceEndpointGenerator.cs @@ -1,13 +1,15 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +internal interface IServiceEndpointGenerator { - internal interface IServiceEndpointGenerator - { - string GetClientAudience(string hubName, string applicationName); - string GetClientEndpoint(string hubName, string applicationName, string originalPath, string queryString); - string GetServerAudience(string hubName, string applicationName); - string GetServerEndpoint(string hubName, string applicationName); - } + string GetClientAudience(string hubName, string applicationName); + + string GetClientEndpoint(string hubName, string applicationName, string originalPath, string queryString); + + string GetServerAudience(string hubName, string applicationName); + + string GetServerEndpoint(string hubName, string applicationName); } diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs index a5f7d085d..411bcae65 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs @@ -46,13 +46,23 @@ public void InvalidAzureApplication(string connectionString) [Theory] [InlineData("endpoint=https://aaa;clientEndpoint=aaa;AccessKey=bbb;")] - [InlineData("endpoint=https://aaa;ClientEndpoint=endpoint=aaa;AccessKey=bbb;")] + [InlineData("endpoint=https://aaa;ClientEndpoint=aaa;AccessKey=bbb;")] public void InvalidClientEndpoint(string connectionString) { var exception = Assert.Throws(() => ConnectionStringParser.Parse(connectionString)); Assert.Contains("Invalid value for clientEndpoint property, it must be a valid URI. (Parameter 'clientEndpoint')", exception.Message); } + [Theory] + [InlineData("endpoint=https://aaa;serverEndpoint=aaa;AccessKey=bbb;")] + [InlineData("endpoint=https://aaa;ServerEndpoint=aaa;AccessKey=bbb;")] + public void InvalidServerEndpoint(string connectionString) + { + var exception = Assert.Throws(() => ConnectionStringParser.Parse(connectionString)); + Assert.Contains("Invalid value for serverEndpoint property, it must be a valid URI. (Parameter 'serverEndpoint')", exception.Message); + } + + [Theory] [InlineData("Endpoint=xxx")] [InlineData("AccessKey=xxx")] @@ -80,7 +90,7 @@ public void InvalidEndpoint(string connectionString) public void InvalidPort(string connectionString) { var exception = Assert.Throws(() => ConnectionStringParser.Parse(connectionString)); - Assert.Contains("Invalid value for port property, it must be an positive integer between (0, 65536) (Parameter 'port')", exception.Message); + Assert.Contains("Invalid value for port property, it must be an positive integer between (0, 65536). (Parameter 'port')", exception.Message); } [Theory] diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/ServiceEndpointFacts.cs b/test/Microsoft.Azure.SignalR.Common.Tests/ServiceEndpointFacts.cs index 64257482b..7b2875f6f 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/ServiceEndpointFacts.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/ServiceEndpointFacts.cs @@ -4,365 +4,443 @@ using System; using System.Collections; using System.Collections.Generic; - using Azure.Identity; using Xunit; -namespace Microsoft.Azure.SignalR.Common.Tests +#nullable enable + +namespace Microsoft.Azure.SignalR.Common.Tests; + +public class ServiceEndpointFacts { - public class ServiceEndpointFacts - { - private const string HttpEndpoint = "http://aaa"; + private const string HttpEndpoint = "http://aaa"; - private const string HttpsEndpoint = "https://aaa"; + private const string HttpsEndpoint = "https://aaa"; - private const string HttpClientEndpoint = "http://bbb"; + private const string HttpClientEndpoint = "http://bbb"; - private const string HttpsClientEndpoint = "http://bbb"; + private const string HttpsClientEndpoint = "http://bbb"; - private const string DefaultKey = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + private const string DefaultKey = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; - [Theory] - [ClassData(typeof(EndpointAndPortTestData))] - public void TestEndpointAndAudience(string connectionString, string expectedAudience, string expectedEndpoint) - { - var endpoint = new ServiceEndpoint(connectionString); - Assert.Equal(expectedAudience, endpoint.AudienceBaseUrl); - Assert.Equal(expectedEndpoint, endpoint.Endpoint); - Assert.Equal(new Uri(expectedEndpoint), endpoint.ClientEndpoint); - Assert.Equal(new Uri(expectedEndpoint), endpoint.ServerEndpoint); - } + private const string DefaultServiceEndpoint = "https://test.service.signalr.net"; - [Theory] - [ClassData(typeof(ClientEndpointTestData))] - public void TestClientEndpoint(string connectionString, string expectedClientEndpoint) + public static IEnumerable ServerEndpointTestData + { + get { - var endpoint = new ServiceEndpoint(connectionString); - Assert.Equal(new Uri(expectedClientEndpoint), endpoint.ClientEndpoint); + yield return new object[] { $"endpoint={HttpEndpoint};authType=aad;serverEndpoint={HttpClientEndpoint}", HttpClientEndpoint }; + yield return new object[] { $"endpoint={HttpEndpoint};authType=aad;serverEndpoint={HttpClientEndpoint}:80", HttpClientEndpoint + ":80" }; + yield return new object[] { $"endpoint={HttpEndpoint};authType=aad;serverEndpoint={HttpClientEndpoint}:80/abc", HttpClientEndpoint + ":80/abc" }; + yield return new object[] { $"endpoint={HttpsEndpoint};authType=aad;serverEndpoint={HttpsClientEndpoint}", HttpsClientEndpoint }; + yield return new object[] { $"endpoint={HttpsEndpoint};authType=aad;serverEndpoint={HttpsClientEndpoint}:443", HttpsClientEndpoint + ":443" }; } + } - [Theory] - [MemberData(nameof(ServerEndpointTestData))] - public void TestServerEndpoint(string connectionString, string expectedServerEndpoint) - { - var endpoint = new ServiceEndpoint(connectionString); - Assert.Equal(new Uri(expectedServerEndpoint), endpoint.ServerEndpoint); - } + private static Uri DefaultServiceEndpointUri { get; } = new Uri(DefaultServiceEndpoint); - [Theory] - [ClassData(typeof(EndpointEndWithSlash))] - public void TestEndpointEndWithSlash(string connectionString, string expectedEndpoint) - { - var endpoint = new ServiceEndpoint(connectionString); - Assert.Equal(new Uri(expectedEndpoint), new Uri(endpoint.Endpoint)); - } + [Theory] + [ClassData(typeof(EndpointAndPortTestData))] + public void TestEndpointAndAudience(string connectionString, string expectedAudience, string expectedEndpoint) + { + var endpoint = new ServiceEndpoint(connectionString); + Assert.Equal(expectedAudience, endpoint.AudienceBaseUrl); + Assert.Equal(expectedEndpoint, endpoint.Endpoint); + Assert.Equal(new Uri(expectedEndpoint), endpoint.ClientEndpoint); + Assert.Equal(new Uri(expectedEndpoint), endpoint.ServerEndpoint); + } - [Fact] - public void TestCustomizeEndpointInConstructor() - { - var clientEndpoint = new Uri("https://clientEndpoint/path"); - var serverEndpoint = new Uri("http://serverEndpoint:123/path"); - var endpoint = "https://test.service.signalr.net"; - var serviceEndpoints = new ServiceEndpoint[]{ - new ServiceEndpoint(new Uri(endpoint), new DefaultAzureCredential()) - { - ClientEndpoint = clientEndpoint, - ServerEndpoint = serverEndpoint - }, - new ServiceEndpoint($"Endpoint={endpoint};AccessKey={DefaultKey}") - { - ClientEndpoint = clientEndpoint, - ServerEndpoint = serverEndpoint - } - }; - foreach (var serviceEndpoint in serviceEndpoints) + [Theory] + [ClassData(typeof(ClientEndpointTestData))] + public void TestClientEndpoint(string connectionString, string expectedClientEndpoint) + { + var endpoint = new ServiceEndpoint(connectionString); + Assert.Equal(new Uri(expectedClientEndpoint), endpoint.ClientEndpoint); + } + + [Theory] + [MemberData(nameof(ServerEndpointTestData))] + public void TestServerEndpoint(string connectionString, string expectedServerEndpoint) + { + var endpoint = new ServiceEndpoint(connectionString); + Assert.Equal(new Uri(expectedServerEndpoint), endpoint.ServerEndpoint); + } + + [Theory] + [ClassData(typeof(EndpointEndWithSlash))] + public void TestEndpointEndWithSlash(string connectionString, string expectedEndpoint) + { + var endpoint = new ServiceEndpoint(connectionString); + Assert.Equal(new Uri(expectedEndpoint), new Uri(endpoint.Endpoint)); + } + + [Fact] + public void TestCustomizeEndpointInConstructor() + { + var clientEndpoint = new Uri("https://clientEndpoint/path"); + var serverEndpoint = new Uri("http://serverEndpoint:123/path"); + var endpoint = "https://test.service.signalr.net"; + var serviceEndpoints = new ServiceEndpoint[]{ + new ServiceEndpoint(new Uri(endpoint), new DefaultAzureCredential()) + { + ClientEndpoint = clientEndpoint, + ServerEndpoint = serverEndpoint + }, + new ServiceEndpoint($"Endpoint={endpoint};AccessKey={DefaultKey}") { - Assert.Equal(endpoint, serviceEndpoint.Endpoint); - Assert.Equal(clientEndpoint, serviceEndpoint.ClientEndpoint); - Assert.Equal(serverEndpoint, serviceEndpoint.ServerEndpoint); + ClientEndpoint = clientEndpoint, + ServerEndpoint = serverEndpoint } + }; + foreach (var serviceEndpoint in serviceEndpoints) + { + Assert.Equal(endpoint, serviceEndpoint.Endpoint); + Assert.Equal(clientEndpoint, serviceEndpoint.ClientEndpoint); + Assert.Equal(serverEndpoint, serviceEndpoint.ServerEndpoint); } + } - [Fact] - public void TestCreateServiceEndpointFromAnother() + [Fact] + public void TestCreateServiceEndpointFromAnother() + { + var serviceEndpoint = new ServiceEndpoint($"Endpoint=http://abc.service.signalr.net;AccessKey={DefaultKey}", EndpointType.Secondary, "name1") { - var serviceEndpoint = new ServiceEndpoint($"Endpoint=http://abc.service.signalr.net;AccessKey={DefaultKey}", EndpointType.Secondary, "name1") - { - ClientEndpoint = new Uri("https://clientEndpoint/path"), - ServerEndpoint = new Uri("http://serverEndpoint:123/path") - }; - var cloned = new ServiceEndpoint(serviceEndpoint); - Assert.Equal(serviceEndpoint.Endpoint, cloned.Endpoint); - Assert.Equal(serviceEndpoint.ServerEndpoint, cloned.ServerEndpoint); - Assert.Equal(serviceEndpoint.ClientEndpoint, cloned.ClientEndpoint); - Assert.Equal(serviceEndpoint.EndpointType, cloned.EndpointType); - Assert.Equal(serviceEndpoint.Name, cloned.Name); - Assert.Equal(serviceEndpoint.ConnectionString, cloned.ConnectionString); - Assert.Equal(serviceEndpoint.AudienceBaseUrl, cloned.AudienceBaseUrl); - Assert.Equal(serviceEndpoint.Version, cloned.Version); - Assert.Equal(serviceEndpoint.AccessKey, cloned.AccessKey); - } + ClientEndpoint = new Uri("https://clientEndpoint/path"), + ServerEndpoint = new Uri("http://serverEndpoint:123/path") + }; + var cloned = new ServiceEndpoint(serviceEndpoint); + Assert.Equal(serviceEndpoint.Endpoint, cloned.Endpoint); + Assert.Equal(serviceEndpoint.ServerEndpoint, cloned.ServerEndpoint); + Assert.Equal(serviceEndpoint.ClientEndpoint, cloned.ClientEndpoint); + Assert.Equal(serviceEndpoint.EndpointType, cloned.EndpointType); + Assert.Equal(serviceEndpoint.Name, cloned.Name); + Assert.Equal(serviceEndpoint.ConnectionString, cloned.ConnectionString); + Assert.Equal(serviceEndpoint.AudienceBaseUrl, cloned.AudienceBaseUrl); + Assert.Equal(serviceEndpoint.Version, cloned.Version); + Assert.Equal(serviceEndpoint.AccessKey, cloned.AccessKey); + } + [Theory] + [InlineData("http://localhost", "http://localhost", 80)] + [InlineData("https://localhost", "https://localhost", 443)] + [InlineData("http://localhost:5050", "http://localhost:5050", 5050)] + [InlineData("https://localhost:5050", "https://localhost:5050", 5050)] + [InlineData("http://localhost/", "http://localhost", 80)] + [InlineData("http://localhost/foo", "http://localhost", 80)] + [InlineData("https://localhost/foo/", "https://localhost", 443)] + public void TestAzureCredentialConstructor(string url, string expectedEndpoint, int port) + { + var uri = new Uri(url); + var serviceEndpoint = new ServiceEndpoint(uri, new DefaultAzureCredential()); + Assert.IsType(serviceEndpoint.AccessKey); + Assert.Equal(expectedEndpoint, serviceEndpoint.Endpoint); + Assert.Equal("", serviceEndpoint.Name); + Assert.Equal(port, serviceEndpoint.AccessKey.Endpoint.Port); + Assert.Equal(EndpointType.Primary, serviceEndpoint.EndpointType); + TestCopyConstructor(serviceEndpoint); + } - [Theory] - [InlineData("http://localhost", "http://localhost", 80)] - [InlineData("https://localhost", "https://localhost", 443)] - [InlineData("http://localhost:5050", "http://localhost:5050", 5050)] - [InlineData("https://localhost:5050", "https://localhost:5050", 5050)] - [InlineData("http://localhost/", "http://localhost", 80)] - [InlineData("http://localhost/foo", "http://localhost", 80)] - [InlineData("https://localhost/foo/", "https://localhost", 443)] - public void TestAzureADConstructor(string url, string expectedEndpoint, int port) - { - var uri = new Uri(url); - var serviceEndpoint = new ServiceEndpoint(uri, new DefaultAzureCredential()); - Assert.IsType(serviceEndpoint.AccessKey); - Assert.Equal(expectedEndpoint, serviceEndpoint.Endpoint); - Assert.Equal("", serviceEndpoint.Name); - Assert.Equal(port, serviceEndpoint.AccessKey.Endpoint.Port); - Assert.Equal(EndpointType.Primary, serviceEndpoint.EndpointType); - TestCopyConstructor(serviceEndpoint); - } + [Theory] + [InlineData("ftp://localhost")] + [InlineData("ws://localhost")] + [InlineData("localhost:5050")] + public void TestAzureCredentialConstructorThrowsError(string url) + { + var uri = new Uri(url); + var e = Assert.Throws(() => new ServiceEndpoint(uri, new DefaultAzureCredential())); + Assert.Equal("Endpoint scheme must be 'http://' or 'https://'", e.Message); + } + + [Theory] + [InlineData("", "", EndpointType.Primary)] + [InlineData("foo", "foo", EndpointType.Primary)] + [InlineData("foo:primary", "foo", EndpointType.Primary)] + [InlineData("foo:secondary", "foo", EndpointType.Secondary)] + [InlineData("foo:SECONDARY", "foo", EndpointType.Secondary)] + [InlineData("foo:bar", "foo:bar", EndpointType.Primary)] + [InlineData(":", ":", EndpointType.Primary)] + [InlineData(":bar", ":bar", EndpointType.Primary)] + [InlineData(":primary", "", EndpointType.Primary)] + [InlineData(":secondary", "", EndpointType.Secondary)] + public void TestAzureCredentialConstructorWithKey(string key, string name, EndpointType type) + { + var uri = new Uri("http://localhost"); + var serviceEndpoint = new ServiceEndpoint(key, uri, new DefaultAzureCredential()); + Assert.IsType(serviceEndpoint.AccessKey); + Assert.Equal(name, serviceEndpoint.Name); + Assert.Equal(type, serviceEndpoint.EndpointType); + TestCopyConstructor(serviceEndpoint); + } + + [Fact] + public void TestAzureCredentialConstructorWithServerEndpoint() + { + var credential = new DefaultAzureCredential(); - [Theory] - [InlineData("ftp://localhost")] - [InlineData("ws://localhost")] - [InlineData("localhost:5050")] - public void TestAzureADConstructorThrowsError(string url) + // use default endpoint + var endpoint = new ServiceEndpoint(new Uri(DefaultServiceEndpoint), credential); + Assert.Equal(ServiceEndpoint.BuildEndpointString(DefaultServiceEndpointUri), endpoint.Endpoint); + var key = Assert.IsType(endpoint.AccessKey); + Assert.Equal(DefaultServiceEndpointUri, endpoint.ServerEndpoint); + Assert.Equal($"{DefaultServiceEndpoint}/api/v1/auth/accessKey", key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); + + // throw if server endpoint is invalid + var invalidUri = new Uri("ftp://foo"); + Assert.Throws(() => new ServiceEndpoint(invalidUri, credential)); + Assert.Throws(() => new ServiceEndpoint(new Uri(DefaultServiceEndpoint), credential, clientEndpoint: invalidUri)); + Assert.Throws(() => new ServiceEndpoint(new Uri(DefaultServiceEndpoint), credential) { - var uri = new Uri(url); - Assert.Throws(() => new ServiceEndpoint(uri, new DefaultAzureCredential())); - } + ClientEndpoint = invalidUri, + }); - [Theory] - [InlineData("", "", EndpointType.Primary)] - [InlineData("foo", "foo", EndpointType.Primary)] - [InlineData("foo:primary", "foo", EndpointType.Primary)] - [InlineData("foo:secondary", "foo", EndpointType.Secondary)] - [InlineData("foo:SECONDARY", "foo", EndpointType.Secondary)] - [InlineData("foo:bar", "foo:bar", EndpointType.Primary)] - [InlineData(":", ":", EndpointType.Primary)] - [InlineData(":bar", ":bar", EndpointType.Primary)] - [InlineData(":primary", "", EndpointType.Primary)] - [InlineData(":secondary", "", EndpointType.Secondary)] - public void TestAzureADConstructorWithKey(string key, string name, EndpointType type) + // use constructor param + var serverEndpoint1 = new Uri("http://serverEndpoint:123"); + endpoint = new ServiceEndpoint(new Uri(DefaultServiceEndpoint), credential, serverEndpoint: serverEndpoint1); + Assert.Equal(ServiceEndpoint.BuildEndpointString(DefaultServiceEndpointUri), endpoint.Endpoint); + key = Assert.IsType(endpoint.AccessKey); + Assert.Equal(serverEndpoint1, endpoint.ServerEndpoint); + Assert.Equal("http://serverEndpoint:123/api/v1/auth/accessKey", key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); + + // use object initializer + var serverEndpoint2 = new Uri("http://serverEndpoint:123/path"); + endpoint = new ServiceEndpoint(new Uri(DefaultServiceEndpoint), credential) { - var uri = new Uri("http://localhost"); - var serviceEndpoint = new ServiceEndpoint(key, uri, new DefaultAzureCredential()); - Assert.IsType(serviceEndpoint.AccessKey); - Assert.Equal(name, serviceEndpoint.Name); - Assert.Equal(type, serviceEndpoint.EndpointType); - TestCopyConstructor(serviceEndpoint); - } + ServerEndpoint = serverEndpoint2 + }; + Assert.Equal(ServiceEndpoint.BuildEndpointString(DefaultServiceEndpointUri), endpoint.Endpoint); + key = Assert.IsType(endpoint.AccessKey); + Assert.Equal(serverEndpoint2, endpoint.ServerEndpoint); + Assert.Equal("http://serverEndpoint:123/path/api/v1/auth/accessKey", key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); - [Fact] - public void TestAzureADConstructorWithServerEndpoint() + // object initializer should override constructor param + endpoint = new ServiceEndpoint(new Uri(DefaultServiceEndpoint), credential, serverEndpoint: serverEndpoint1) { - var serverEndpoint1 = new Uri("http://serverEndpoint:123"); - var serverEndpoint2 = new Uri("http://serverEndpoint:123/path"); - var serviceEndpoint = "https://test.service.signalr.net"; - var endpoint = new ServiceEndpoint(new Uri(serviceEndpoint), new DefaultAzureCredential()) - { - ServerEndpoint = serverEndpoint1 - }; - var key = Assert.IsType(endpoint.AccessKey); - Assert.Same(key, endpoint.AccessKey); - Assert.Equal("http://serverEndpoint:123/api/v1/auth/accessKey", key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); + ServerEndpoint = serverEndpoint2 + }; + Assert.Equal(ServiceEndpoint.BuildEndpointString(DefaultServiceEndpointUri), endpoint.Endpoint); + key = Assert.IsType(endpoint.AccessKey); + Assert.Equal(serverEndpoint2, endpoint.ServerEndpoint); + Assert.Equal("http://serverEndpoint:123/path/api/v1/auth/accessKey", key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); + } - endpoint = new ServiceEndpoint(new Uri(serviceEndpoint), new DefaultAzureCredential(), serverEndpoint: serverEndpoint2); - key = Assert.IsType(endpoint.AccessKey); - Assert.Same(key, endpoint.AccessKey); - Assert.Equal("http://serverEndpoint:123/path/api/v1/auth/accessKey", key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); + [Fact] + public void TestAzureCredentialConstructorWithClientEndpoint() + { + var credential = new DefaultAzureCredential(); - endpoint = new ServiceEndpoint(new Uri(serviceEndpoint), new DefaultAzureCredential(), serverEndpoint: serverEndpoint1) - { - ServerEndpoint = serverEndpoint2 // property initialize should override constructor param. - }; - key = Assert.IsType(endpoint.AccessKey); - Assert.Same(key, endpoint.AccessKey); - Assert.Equal("http://serverEndpoint:123/path/api/v1/auth/accessKey", key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); - } + // use default endpoint + var endpoint = new ServiceEndpoint(new Uri(DefaultServiceEndpoint), credential); + Assert.Equal(ServiceEndpoint.BuildEndpointString(DefaultServiceEndpointUri), endpoint.Endpoint); + Assert.IsType(endpoint.AccessKey); + Assert.Equal(DefaultServiceEndpointUri, endpoint.ClientEndpoint); - [Theory] - [ClassData(typeof(EndpointEqualityTestData))] - public void TestEndpointsEquality(ServiceEndpoint first, ServiceEndpoint second, bool expected) + // throw if client endpoint is invalid + var invalidUri = new Uri("ftp://foo"); + Assert.Throws(() => new ServiceEndpoint(invalidUri, credential)); + Assert.Throws(() => new ServiceEndpoint(new Uri(DefaultServiceEndpoint), credential, clientEndpoint: invalidUri)); + Assert.Throws(() => new ServiceEndpoint(new Uri(DefaultServiceEndpoint), credential) { - Assert.Equal(expected, first.Equals(second)); - Assert.Equal(expected, first.GetHashCode() == second.GetHashCode()); - } + ClientEndpoint = invalidUri, + }); - private static void TestCopyConstructor(ServiceEndpoint endpoint) + // use constructor param. + var clientEndpoint1 = new Uri("https://clientEndpoint:123"); + endpoint = new ServiceEndpoint(new Uri(DefaultServiceEndpoint), credential, clientEndpoint: clientEndpoint1); + Assert.Equal(ServiceEndpoint.BuildEndpointString(DefaultServiceEndpointUri), endpoint.Endpoint); + Assert.IsType(endpoint.AccessKey); + Assert.Equal(clientEndpoint1, endpoint.ClientEndpoint); + + // use object initializer. + var clientEndpoint2 = new Uri("https://clientEndpoint:123/path"); + endpoint = new ServiceEndpoint(new Uri(DefaultServiceEndpoint), credential) { - var other = new ServiceEndpoint(endpoint); - Assert.Equal(endpoint.Name, other.Name); - Assert.Equal(endpoint.EndpointType, other.EndpointType); - Assert.Equal(endpoint.Endpoint, other.Endpoint); - Assert.Equal(endpoint.ClientEndpoint, other.ClientEndpoint); - Assert.Equal(endpoint.Version, other.Version); - Assert.Equal(endpoint.AccessKey, other.AccessKey); - } + ClientEndpoint = clientEndpoint2, + }; + Assert.Equal(ServiceEndpoint.BuildEndpointString(DefaultServiceEndpointUri), endpoint.Endpoint); + Assert.IsType(endpoint.AccessKey); + Assert.Equal(clientEndpoint2, endpoint.ClientEndpoint); - public class EndpointEqualityTestData : IEnumerable + // object initializer should override constructor param. + endpoint = new ServiceEndpoint(new Uri(DefaultServiceEndpoint), credential, clientEndpoint: clientEndpoint1) { - public IEnumerator GetEnumerator() - { - yield return new object[] - { - new ServiceEndpoint("a", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), - new ServiceEndpoint("Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), - false, - }; - yield return new object[] - { - new ServiceEndpoint("a", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), - new ServiceEndpoint(":primary", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080"), - false, - }; - yield return new object[] - { - new ServiceEndpoint("a", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), - new ServiceEndpoint("a:secondary", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"), - false, - }; - yield return new object[] - { - new ServiceEndpoint("a", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), - new ServiceEndpoint("b", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780"), - false, - }; - yield return new object[] - { - new ServiceEndpoint("Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), - new ServiceEndpoint(":secondary", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780"), - false, - }; - yield return new object[] - { - new ServiceEndpoint("Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0", EndpointType.Secondary), - new ServiceEndpoint("Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780", name: "Name1"), - false, - }; - yield return new object[] - { - new ServiceEndpoint("Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), - new ServiceEndpoint("Endpoint=https://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780"), - false, - }; - yield return new object[] - { - new ServiceEndpoint("Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), - new ServiceEndpoint("Endpoint=http://localhost;AccessKey=OPQRSTUVWXYZ0123456780ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780"), - false, // ports are different - }; - yield return new object[] - { - new ServiceEndpoint("Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), - new ServiceEndpoint(":primary", "Endpoint=http://localhost;AccessKey=OPQRSTUVWXYZ0123456780ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780"), - false, // ports are different - }; - yield return new object[] - { - new ServiceEndpoint(":secondary", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), - new ServiceEndpoint("Endpoint=http://localhost;AccessKey=OPQRSTUVWXYZ0123456780ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780", EndpointType.Secondary), - false, // ports are different - }; - yield return new object[] - { - new ServiceEndpoint("Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), - new ServiceEndpoint("Endpoint=http://localhost:8080;AccessKey=OPQRSTUVWXYZ0123456780ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780"), - true, - }; - yield return new object[] - { - new ServiceEndpoint("foo:bar", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Version=1.0"), - new ServiceEndpoint("Endpoint=http://localhost;AccessKey=OPQRSTUVWXYZ0123456780ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780", name : "foo:bar"), - true, - }; - yield return new object[] - { - new ServiceEndpoint(":primary", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Version=1.0"), - new ServiceEndpoint("Endpoint=http://localhost;AccessKey=OPQRSTUVWXYZ0123456780ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780"), - true, - }; - yield return new object[] - { - new ServiceEndpoint(":secondary", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Version=1.0"), - new ServiceEndpoint("Endpoint=http://localhost;AccessKey=OPQRSTUVWXYZ0123456780ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780", EndpointType.Secondary), - true, - }; - yield return new object[] - { - new ServiceEndpoint("foo:secondary", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Version=1.0"), - new ServiceEndpoint("Endpoint=http://localhost;AccessKey=OPQRSTUVWXYZ0123456780ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780", EndpointType.Secondary, "foo"), - true, - }; - } + ClientEndpoint = clientEndpoint2, + }; + Assert.Equal(ServiceEndpoint.BuildEndpointString(DefaultServiceEndpointUri), endpoint.Endpoint); + Assert.IsType(endpoint.AccessKey); + Assert.Equal(clientEndpoint2, endpoint.ClientEndpoint); + } - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - } + [Theory] + [ClassData(typeof(EndpointEqualityTestData))] + public void TestEndpointsEquality(ServiceEndpoint first, ServiceEndpoint second, bool expected) + { + Assert.Equal(expected, first.Equals(second)); + Assert.Equal(expected, first.GetHashCode() == second.GetHashCode()); + } + + private static void TestCopyConstructor(ServiceEndpoint endpoint) + { + var other = new ServiceEndpoint(endpoint); + Assert.Equal(endpoint.Name, other.Name); + Assert.Equal(endpoint.EndpointType, other.EndpointType); + Assert.Equal(endpoint.Endpoint, other.Endpoint); + Assert.Equal(endpoint.ClientEndpoint, other.ClientEndpoint); + Assert.Equal(endpoint.Version, other.Version); + Assert.Equal(endpoint.AccessKey, other.AccessKey); + } - public class EndpointEndWithSlash : IEnumerable + public class EndpointEqualityTestData : IEnumerable + { + public IEnumerator GetEnumerator() { - public IEnumerator GetEnumerator() + yield return new object[] { - yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey}", HttpEndpoint }; - yield return new object[] { $"endpoint={HttpEndpoint}/;accesskey={DefaultKey}", HttpEndpoint }; - yield return new object[] { $"endpoint={HttpsEndpoint};accesskey={DefaultKey}", HttpsEndpoint }; - yield return new object[] { $"endpoint={HttpsEndpoint}/;accesskey={DefaultKey}", HttpsEndpoint }; - } - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + new ServiceEndpoint("a", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), + new ServiceEndpoint("Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), + false, + }; + yield return new object[] + { + new ServiceEndpoint("a", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), + new ServiceEndpoint(":primary", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080"), + false, + }; + yield return new object[] + { + new ServiceEndpoint("a", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), + new ServiceEndpoint("a:secondary", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"), + false, + }; + yield return new object[] + { + new ServiceEndpoint("a", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), + new ServiceEndpoint("b", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780"), + false, + }; + yield return new object[] + { + new ServiceEndpoint("Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), + new ServiceEndpoint(":secondary", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780"), + false, + }; + yield return new object[] + { + new ServiceEndpoint("Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0", EndpointType.Secondary), + new ServiceEndpoint("Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780", name: "Name1"), + false, + }; + yield return new object[] + { + new ServiceEndpoint("Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), + new ServiceEndpoint("Endpoint=https://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780"), + false, + }; + yield return new object[] + { + new ServiceEndpoint("Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), + new ServiceEndpoint("Endpoint=http://localhost;AccessKey=OPQRSTUVWXYZ0123456780ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780"), + false, // ports are different + }; + yield return new object[] + { + new ServiceEndpoint("Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), + new ServiceEndpoint(":primary", "Endpoint=http://localhost;AccessKey=OPQRSTUVWXYZ0123456780ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780"), + false, // ports are different + }; + yield return new object[] + { + new ServiceEndpoint(":secondary", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), + new ServiceEndpoint("Endpoint=http://localhost;AccessKey=OPQRSTUVWXYZ0123456780ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780", EndpointType.Secondary), + false, // ports are different + }; + yield return new object[] + { + new ServiceEndpoint("Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0"), + new ServiceEndpoint("Endpoint=http://localhost:8080;AccessKey=OPQRSTUVWXYZ0123456780ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780"), + true, + }; + yield return new object[] + { + new ServiceEndpoint("foo:bar", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Version=1.0"), + new ServiceEndpoint("Endpoint=http://localhost;AccessKey=OPQRSTUVWXYZ0123456780ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780", name : "foo:bar"), + true, + }; + yield return new object[] + { + new ServiceEndpoint(":primary", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Version=1.0"), + new ServiceEndpoint("Endpoint=http://localhost;AccessKey=OPQRSTUVWXYZ0123456780ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780"), + true, + }; + yield return new object[] + { + new ServiceEndpoint(":secondary", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Version=1.0"), + new ServiceEndpoint("Endpoint=http://localhost;AccessKey=OPQRSTUVWXYZ0123456780ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780", EndpointType.Secondary), + true, + }; + yield return new object[] + { + new ServiceEndpoint("foo:secondary", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Version=1.0"), + new ServiceEndpoint("Endpoint=http://localhost;AccessKey=OPQRSTUVWXYZ0123456780ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456780", EndpointType.Secondary, "foo"), + true, + }; } - public class EndpointAndPortTestData : IEnumerable - { - public IEnumerator GetEnumerator() - { - // http - yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey}", HttpEndpoint + "/", HttpEndpoint }; - yield return new object[] { $"endpoint={HttpEndpoint}:80;accesskey={DefaultKey}", HttpEndpoint + "/", HttpEndpoint }; - yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey}", HttpEndpoint + "/", HttpEndpoint + ":500" }; - // https - yield return new object[] { $"endpoint={HttpsEndpoint};accesskey={DefaultKey}", HttpsEndpoint + "/", HttpsEndpoint }; - yield return new object[] { $"endpoint={HttpsEndpoint}:443;accesskey={DefaultKey}", HttpsEndpoint + "/", HttpsEndpoint }; - yield return new object[] { $"endpoint={HttpsEndpoint}:500;accesskey={DefaultKey}", HttpsEndpoint + "/", HttpsEndpoint + ":500" }; - // uppercase endpoint - yield return new object[] { $"endpoint={HttpEndpoint.ToUpper()};accesskey={DefaultKey}", HttpEndpoint + "/", HttpEndpoint }; - yield return new object[] { $"endpoint={HttpsEndpoint.ToUpper()};accesskey={DefaultKey}", HttpsEndpoint + "/", HttpsEndpoint }; - // port override - yield return new object[] { $"endpoint={HttpsEndpoint};accesskey={DefaultKey};port=500", HttpsEndpoint + "/", HttpsEndpoint + ":500" }; - yield return new object[] { $"endpoint={HttpsEndpoint}:500;accesskey={DefaultKey};port=443", HttpsEndpoint + "/", HttpsEndpoint }; - // uppercase property name - yield return new object[] { $"ENDPOINT={HttpEndpoint};ACCESSKEY={DefaultKey}", HttpEndpoint + "/", HttpEndpoint }; - yield return new object[] { $"ENDPOINT={HttpsEndpoint}:500;ACCESSKEY={DefaultKey};PORT=443", HttpsEndpoint + "/", HttpsEndpoint }; - } + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + public class EndpointEndWithSlash : IEnumerable + { + public IEnumerator GetEnumerator() + { + yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey}", HttpEndpoint }; + yield return new object[] { $"endpoint={HttpEndpoint}/;accesskey={DefaultKey}", HttpEndpoint }; + yield return new object[] { $"endpoint={HttpsEndpoint};accesskey={DefaultKey}", HttpsEndpoint }; + yield return new object[] { $"endpoint={HttpsEndpoint}/;accesskey={DefaultKey}", HttpsEndpoint }; } - public class ClientEndpointTestData : IEnumerable + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + + public class EndpointAndPortTestData : IEnumerable + { + public IEnumerator GetEnumerator() { - public IEnumerator GetEnumerator() - { - yield return new object[] { $"endpoint={HttpEndpoint};authType=aad;clientEndpoint={HttpClientEndpoint}", HttpClientEndpoint }; - yield return new object[] { $"endpoint={HttpEndpoint};authType=aad;clientEndpoint={HttpClientEndpoint}:80", HttpClientEndpoint + ":80" }; - yield return new object[] { $"endpoint={HttpsEndpoint};authType=aad;clientEndpoint={HttpsClientEndpoint}", HttpsClientEndpoint }; - yield return new object[] { $"endpoint={HttpsEndpoint};authType=aad;clientEndpoint={HttpsClientEndpoint}:443", HttpsClientEndpoint + ":443" }; - } + // http + yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey}", HttpEndpoint + "/", HttpEndpoint }; + yield return new object[] { $"endpoint={HttpEndpoint}:80;accesskey={DefaultKey}", HttpEndpoint + "/", HttpEndpoint }; + yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey}", HttpEndpoint + "/", HttpEndpoint + ":500" }; + + // https + yield return new object[] { $"endpoint={HttpsEndpoint};accesskey={DefaultKey}", HttpsEndpoint + "/", HttpsEndpoint }; + yield return new object[] { $"endpoint={HttpsEndpoint}:443;accesskey={DefaultKey}", HttpsEndpoint + "/", HttpsEndpoint }; + yield return new object[] { $"endpoint={HttpsEndpoint}:500;accesskey={DefaultKey}", HttpsEndpoint + "/", HttpsEndpoint + ":500" }; + + // uppercase endpoint + yield return new object[] { $"endpoint={HttpEndpoint.ToUpper()};accesskey={DefaultKey}", HttpEndpoint + "/", HttpEndpoint }; + yield return new object[] { $"endpoint={HttpsEndpoint.ToUpper()};accesskey={DefaultKey}", HttpsEndpoint + "/", HttpsEndpoint }; - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + // port override + yield return new object[] { $"endpoint={HttpsEndpoint};accesskey={DefaultKey};port=500", HttpsEndpoint + "/", HttpsEndpoint + ":500" }; + yield return new object[] { $"endpoint={HttpsEndpoint}:500;accesskey={DefaultKey};port=443", HttpsEndpoint + "/", HttpsEndpoint }; + + // uppercase property name + yield return new object[] { $"ENDPOINT={HttpEndpoint};ACCESSKEY={DefaultKey}", HttpEndpoint + "/", HttpEndpoint }; + yield return new object[] { $"ENDPOINT={HttpsEndpoint}:500;ACCESSKEY={DefaultKey};PORT=443", HttpsEndpoint + "/", HttpsEndpoint }; } - public static IEnumerable ServerEndpointTestData + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + + public class ClientEndpointTestData : IEnumerable + { + public IEnumerator GetEnumerator() { - get - { - yield return new object[] { $"endpoint={HttpEndpoint};authType=aad;serverEndpoint={HttpClientEndpoint}", HttpClientEndpoint }; - yield return new object[] { $"endpoint={HttpEndpoint};authType=aad;serverEndpoint={HttpClientEndpoint}:80", HttpClientEndpoint + ":80" }; - yield return new object[] { $"endpoint={HttpEndpoint};authType=aad;serverEndpoint={HttpClientEndpoint}:80/abc", HttpClientEndpoint + ":80/abc" }; - yield return new object[] { $"endpoint={HttpsEndpoint};authType=aad;serverEndpoint={HttpsClientEndpoint}", HttpsClientEndpoint }; - yield return new object[] { $"endpoint={HttpsEndpoint};authType=aad;serverEndpoint={HttpsClientEndpoint}:443", HttpsClientEndpoint + ":443" }; - } + yield return new object[] { $"endpoint={HttpEndpoint};authType=aad;clientEndpoint={HttpClientEndpoint}", HttpClientEndpoint }; + yield return new object[] { $"endpoint={HttpEndpoint};authType=aad;clientEndpoint={HttpClientEndpoint}:80", HttpClientEndpoint + ":80" }; + yield return new object[] { $"endpoint={HttpsEndpoint};authType=aad;clientEndpoint={HttpsClientEndpoint}", HttpsClientEndpoint }; + yield return new object[] { $"endpoint={HttpsEndpoint};authType=aad;clientEndpoint={HttpsClientEndpoint}:443", HttpsClientEndpoint + ":443" }; } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); } -} \ No newline at end of file +} diff --git a/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MessageOrderTests/MockServiceMessageOrderTestParams.cs b/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MessageOrderTests/MockServiceMessageOrderTestParams.cs index c75ff7e9b..076a213bb 100644 --- a/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MessageOrderTests/MockServiceMessageOrderTestParams.cs +++ b/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MessageOrderTests/MockServiceMessageOrderTestParams.cs @@ -1,20 +1,23 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -namespace Microsoft.Azure.SignalR.IntegrationTests.Infrastructure.MessageOrderTests +namespace Microsoft.Azure.SignalR.IntegrationTests.Infrastructure.MessageOrderTests; + +internal class MockServiceMessageOrderTestParams : IIntegrationTestStartupParameters { - internal class MockServiceMessageOrderTestParams : IIntegrationTestStartupParameters - { - public static int ConnectionCount = 2; - public static GracefulShutdownMode ShutdownMode = GracefulShutdownMode.WaitForClientsClose; - public static ServiceEndpoint[] ServiceEndpoints = new[] { - new ServiceEndpoint("Endpoint=http://127.0.0.1;AccessKey=AAAAAAAAAAAAAAAAAAAAAAAAAA0A2A4A6A8A;Version=1.0;Port=8080", type: EndpointType.Primary, name: "primary"), - new ServiceEndpoint("Endpoint=http://127.0.1.0;AccessKey=BBBBBBBBBBBBBBBBBBBBBBBBBB0B2B4B6B8B;Version=1.0;Port=8080", type: EndpointType.Secondary, name: "secondary1"), - new ServiceEndpoint("Endpoint=http://127.1.0.0;AccessKey=CCCCCCCCCCCCCCCCCCCCCCCCCCCC2C4C6C8C;Version=1.0;Port=8080", type: EndpointType.Secondary, name: "secondary2") - }; + public static int ConnectionCount = 2; + + public static GracefulShutdownMode ShutdownMode = GracefulShutdownMode.WaitForClientsClose; + + public static ServiceEndpoint[] ServiceEndpoints = [ + new ServiceEndpoint("Endpoint=http://127.0.0.1;AccessKey=AAAAAAAAAAAAAAAAAAAAAAAAAA0A2A4A6A8A;Version=1.0;Port=8080", type: EndpointType.Primary, name: "primary"), + new ServiceEndpoint("Endpoint=http://127.0.1.0;AccessKey=BBBBBBBBBBBBBBBBBBBBBBBBBB0B2B4B6B8B;Version=1.0;Port=8080", type: EndpointType.Secondary, name: "secondary1"), + new ServiceEndpoint("Endpoint=http://127.1.0.0;AccessKey=CCCCCCCCCCCCCCCCCCCCCCCCCCCC2C4C6C8C;Version=1.0;Port=8080", type: EndpointType.Secondary, name: "secondary2") + ]; + + int IIntegrationTestStartupParameters.ConnectionCount => ConnectionCount; + + ServiceEndpoint[] IIntegrationTestStartupParameters.ServiceEndpoints => ServiceEndpoints; - int IIntegrationTestStartupParameters.ConnectionCount => ConnectionCount; - ServiceEndpoint[] IIntegrationTestStartupParameters.ServiceEndpoints => ServiceEndpoints; - GracefulShutdownMode IIntegrationTestStartupParameters.ShutdownMode => GracefulShutdownMode.WaitForClientsClose; - } + GracefulShutdownMode IIntegrationTestStartupParameters.ShutdownMode => GracefulShutdownMode.WaitForClientsClose; } diff --git a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceEndpoint.cs b/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceEndpoint.cs index 5d7d3df9d..6f386e3ca 100644 --- a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceEndpoint.cs +++ b/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceEndpoint.cs @@ -3,13 +3,15 @@ namespace Microsoft.Azure.SignalR.Tests.Common; +#nullable enable + internal class TestServiceEndpoint : ServiceEndpoint { private static Uri DefaultEndpoint = new Uri("https://localhost"); - private const string _defaultConnectionString = "Endpoint=https://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ;Version=1.0"; + private const string DefaultConnectionString = "Endpoint=https://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ;Version=1.0"; - public TestServiceEndpoint(string name = "", string connectionString = null) : base(connectionString ?? _defaultConnectionString, name: name) + public TestServiceEndpoint(string name = "", string? connectionString = null) : base(connectionString ?? DefaultConnectionString, name: name) { } diff --git a/test/Microsoft.Azure.SignalR.Tests/ClientConnectionContextFacts.cs b/test/Microsoft.Azure.SignalR.Tests/ClientConnectionContextFacts.cs index 0bf317a21..497cf3417 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ClientConnectionContextFacts.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ClientConnectionContextFacts.cs @@ -371,7 +371,7 @@ public async Task StartAsync(string? target = null) { while (ServiceProtocol.TryParseMessage(ref buffer, out var message)) { - Messages.Add(message); + Messages.Add(message!); } } diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs index 4fff62867..b9f16c9f8 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs @@ -351,8 +351,8 @@ private ServiceEndpoint MockServiceEndpoint(string keyTypeName) case nameof(MicrosoftEntraAccessKey): var endpoint = new ServiceEndpoint(MicrosoftEntraConnectionString); - var p = typeof(ServiceEndpoint).GetProperty("AccessKey", BindingFlags.NonPublic | BindingFlags.Instance); - p.SetValue(endpoint, new TestAadAccessKey()); + var field = typeof(ServiceEndpoint).GetField("_accessKey", BindingFlags.NonPublic | BindingFlags.Instance); + field.SetValue(endpoint, new TestAadAccessKey()); return endpoint; default: