diff --git a/build/dependencies.props b/build/dependencies.props index 2134695c4..451dfddd2 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -13,7 +13,7 @@ 1.0.0 2.1.0 4.5.1 - 4.5.4 + 4.5.5 6.0.0 1.11.4 2.1.0 diff --git a/src/Microsoft.Azure.SignalR.Management/WebsocketsHubLifetimeManager.cs b/src/Microsoft.Azure.SignalR.Management/WebsocketsHubLifetimeManager.cs index ec20550e1..ee5983391 100644 --- a/src/Microsoft.Azure.SignalR.Management/WebsocketsHubLifetimeManager.cs +++ b/src/Microsoft.Azure.SignalR.Management/WebsocketsHubLifetimeManager.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. - +#nullable enable using System; using System.Threading; using System.Threading.Tasks; diff --git a/src/Microsoft.Azure.SignalR.Protocols/ConnectionMessage.cs b/src/Microsoft.Azure.SignalR.Protocols/ConnectionMessage.cs index 297170049..1e2dd9bff 100644 --- a/src/Microsoft.Azure.SignalR.Protocols/ConnectionMessage.cs +++ b/src/Microsoft.Azure.SignalR.Protocols/ConnectionMessage.cs @@ -54,7 +54,7 @@ public OpenConnectionMessage(string connectionId, Claim[]? claims) /// An array of associated with the connection. /// A associated with the connection. /// Query string associated with the connection. - public OpenConnectionMessage(string connectionId, Claim[]? claims, IDictionary headers, string queryString) + public OpenConnectionMessage(string connectionId, Claim[]? claims, IDictionary headers, string? queryString) : base(connectionId) { Claims = claims ?? []; @@ -75,7 +75,7 @@ public OpenConnectionMessage(string connectionId, Claim[]? claims, IDictionary /// Gets or sets the associated query string. /// - public string QueryString { get; set; } + public string? QueryString { get; set; } /// /// Gets or sets the protocol for new connection. @@ -99,7 +99,7 @@ public class CloseConnectionMessage : ConnectionMessage, IMessageWithTracingId /// The connection Id. /// Optional error message. /// A associated with the connection. - public CloseConnectionMessage(string connectionId, string errorMessage, IDictionary? headers = null) : base(connectionId) + public CloseConnectionMessage(string connectionId, string? errorMessage, IDictionary? headers = null) : base(connectionId) { ErrorMessage = errorMessage ?? ""; Headers = headers ?? new Dictionary(); @@ -230,7 +230,7 @@ public class ClientCompletionMessage : ServiceCompletionMessage, IHasProtocol /// The protocol of the connection. /// The payload of the completion result. /// The tracing Id of the message. - public ClientCompletionMessage(string invocationId, string connectionId, string callerServerId, string protocol, ReadOnlyMemory payload, ulong? tracingId = null) + public ClientCompletionMessage(string invocationId, string connectionId, string callerServerId, string? protocol, ReadOnlyMemory payload, ulong? tracingId = null) : base(invocationId, connectionId, callerServerId, tracingId) { Protocol = protocol; @@ -261,10 +261,10 @@ public class ErrorCompletionMessage : ServiceCompletionMessage /// The serverId that wrap the completion result. /// The error information about invacation failure. /// The tracing Id of the message. - public ErrorCompletionMessage(string invocationId, string connectionId, string callerServerId, string error, ulong? tracingId = null) + public ErrorCompletionMessage(string invocationId, string connectionId, string callerServerId, string? error, ulong? tracingId = null) : base(invocationId, connectionId, callerServerId, tracingId) { - Error = error; + Error = error ?? string.Empty; } /// diff --git a/src/Microsoft.Azure.SignalR.Protocols/GroupMessage.cs b/src/Microsoft.Azure.SignalR.Protocols/GroupMessage.cs index 8a39a3807..da3266cc6 100644 --- a/src/Microsoft.Azure.SignalR.Protocols/GroupMessage.cs +++ b/src/Microsoft.Azure.SignalR.Protocols/GroupMessage.cs @@ -53,7 +53,7 @@ public class LeaveGroupMessage : ExtensibleServiceMessage, IMessageWithTracingId /// /// Gets or sets the group name. /// - public string GroupName { get; set; } + public string? GroupName { get; set; } /// /// Gets or sets the tracing Id @@ -68,7 +68,7 @@ public class LeaveGroupMessage : ExtensibleServiceMessage, IMessageWithTracingId /// The connection Id. /// The group name, from which the connection will leave. /// The tracing Id of the message. - public LeaveGroupMessage(string connectionId, string groupName, ulong? tracingId = null) + public LeaveGroupMessage(string connectionId, string? groupName, ulong? tracingId = null) { ConnectionId = connectionId; GroupName = groupName; @@ -130,7 +130,7 @@ public class UserLeaveGroupMessage : ExtensibleServiceMessage, IMessageWithTraci /// /// Gets or sets the group name. /// - public string GroupName { get; set; } + public string? GroupName { get; set; } /// /// Gets or sets the tracing Id @@ -145,7 +145,7 @@ public class UserLeaveGroupMessage : ExtensibleServiceMessage, IMessageWithTraci /// The user Id. /// The group name, from which the user will leave. /// The tracing Id of the message. - public UserLeaveGroupMessage(string userId, string groupName, ulong? tracingId = null) + public UserLeaveGroupMessage(string userId, string? groupName, ulong? tracingId = null) { UserId = userId; GroupName = groupName; @@ -216,7 +216,7 @@ public class UserLeaveGroupWithAckMessage : ExtensibleServiceMessage, IMessageWi /// /// Gets or sets the group name. /// - public string GroupName { get; set; } + public string? GroupName { get; set; } /// /// Gets or sets the tracing Id @@ -237,7 +237,7 @@ public class UserLeaveGroupWithAckMessage : ExtensibleServiceMessage, IMessageWi /// The group name, from which the user will leave. /// The ack Id. /// The tracing Id of the message. - public UserLeaveGroupWithAckMessage(string userId, string groupName, int ackId, ulong? tracingId = null) + public UserLeaveGroupWithAckMessage(string userId, string? groupName, int ackId, ulong? tracingId = null) { UserId = userId; GroupName = groupName; @@ -312,7 +312,7 @@ public class LeaveGroupWithAckMessage : ExtensibleServiceMessage, IAckableMessag /// /// Gets or sets the group name. /// - public string GroupName { get; set; } + public string? GroupName { get; set; } /// /// Gets or sets the ack id. @@ -332,7 +332,7 @@ public class LeaveGroupWithAckMessage : ExtensibleServiceMessage, IAckableMessag /// The connection Id. /// The group name, from which the connection will leave. /// The tracing Id of the message. - public LeaveGroupWithAckMessage(string connectionId, string groupName, ulong? tracingId = null): this(connectionId, groupName, 0, tracingId) + public LeaveGroupWithAckMessage(string connectionId, string? groupName, ulong? tracingId = null): this(connectionId, groupName, 0, tracingId) { } @@ -343,7 +343,7 @@ public LeaveGroupWithAckMessage(string connectionId, string groupName, ulong? tr /// The group name, from which the connection will leave. /// The ack Id /// The tracing Id of the message. - public LeaveGroupWithAckMessage(string connectionId, string groupName, int ackId, ulong? tracingId = null) + public LeaveGroupWithAckMessage(string connectionId, string? groupName, int ackId, ulong? tracingId = null) { ConnectionId = connectionId; GroupName = groupName; diff --git a/src/Microsoft.Azure.SignalR.Protocols/MessagePackPitfalls.cs b/src/Microsoft.Azure.SignalR.Protocols/MessagePackPitfalls.cs new file mode 100644 index 000000000..a3fb600dd --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Protocols/MessagePackPitfalls.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace System.Diagnostics.CodeAnalysis; + +[AttributeUsage(AttributeTargets.Method, Inherited = false)] +internal sealed class DoesNotReturnAttribute : Attribute +{ +} diff --git a/src/Microsoft.Azure.SignalR.Protocols/Microsoft.Azure.SignalR.Protocols.csproj b/src/Microsoft.Azure.SignalR.Protocols/Microsoft.Azure.SignalR.Protocols.csproj index f1bec5ab8..1e2b96677 100644 --- a/src/Microsoft.Azure.SignalR.Protocols/Microsoft.Azure.SignalR.Protocols.csproj +++ b/src/Microsoft.Azure.SignalR.Protocols/Microsoft.Azure.SignalR.Protocols.csproj @@ -4,6 +4,7 @@ netstandard2.0 Microsoft.Azure.SignalR.Protocol true + enable diff --git a/src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs b/src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs index a4cd8d0f5..02da7ac0d 100644 --- a/src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs +++ b/src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs @@ -71,7 +71,7 @@ public abstract class ServiceMessage /// public virtual ServiceMessage Clone() => (MemberwiseClone() as ServiceMessage)!; - public static byte GeneratePartitionKey(string input) + public static byte GeneratePartitionKey(string? input) { return (byte)((input?.GetHashCode() ?? 0) & 0xFF); } @@ -393,7 +393,7 @@ public class HandshakeResponseMessage : ExtensibleServiceMessage /// /// Gets or sets the optional error message. /// - public string ErrorMessage { get; set; } + public string? ErrorMessage { get; set; } /// /// Gets or sets the id of this connection. @@ -411,7 +411,7 @@ public HandshakeResponseMessage() : this(string.Empty) /// Initializes a new instance of the class. /// /// An optional response error message. A null or empty error message indicates a successful handshake. - public HandshakeResponseMessage(string errorMessage) + public HandshakeResponseMessage(string? errorMessage) { ErrorMessage = errorMessage; } @@ -427,7 +427,7 @@ public class PingMessage : ServiceMessage /// public static PingMessage Instance = new PingMessage(); - public string[] Messages { get; set; } = Array.Empty(); + public string?[] Messages { get; set; } = Array.Empty(); } /// @@ -444,9 +444,9 @@ public class ServiceErrorMessage : ServiceMessage /// Initializes a new instance of the class. /// /// An error message. - public ServiceErrorMessage(string errorMessage) + public ServiceErrorMessage(string? errorMessage) { - ErrorMessage = errorMessage; + ErrorMessage = errorMessage ?? string.Empty; } } @@ -463,7 +463,7 @@ public class ServiceEventMessage : ExtensibleServiceMessage /// /// Gets or sets the id of event object. /// - public string Id { get; set; } + public string? Id { get; set; } /// /// Gets or sets the kind of event. @@ -482,12 +482,12 @@ public class ServiceEventMessage : ExtensibleServiceMessage /// An id of event object. /// A kind of event. /// A message of event. - public ServiceEventMessage(ServiceEventObjectType type, string id, ServiceEventKind kind, string message) + public ServiceEventMessage(ServiceEventObjectType type, string? id, ServiceEventKind kind, string? message) { Type = type; Id = id; Kind = kind; - Message = message; + Message = message ?? string.Empty; } } @@ -526,11 +526,11 @@ public AckMessage(int ackId, int status) : this(ackId, status, string.Empty) /// The ack Id /// The status code /// The ack message - public AckMessage(int ackId, int status, string message) + public AckMessage(int ackId, int status, string? message) { AckId = ackId; Status = status; - Message = message; + Message = message ?? string.Empty; } } diff --git a/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs b/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs index 133a770a3..a258e536a 100644 --- a/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs +++ b/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs @@ -12,8 +12,6 @@ namespace Microsoft.Azure.SignalR.Protocol; -#nullable enable - /// /// Implements the Azure SignalR Service Protocol. /// @@ -700,7 +698,7 @@ private static void WriteErrorCompletionMessage(ref MessagePackWriter writer, Er writer.Write(message.Error); message.WriteExtensionMembers(ref writer); } - + private static void WriteServiceMappingMessage(ref MessagePackWriter writer, ServiceMappingMessage message) { writer.WriteArrayHeader(5); @@ -849,7 +847,7 @@ private static PingMessage CreatePingMessage(ref MessagePackReader reader, int a if (arrayLength > 1) { var length = arrayLength - 1; - var values = new string[length]; + var values = new string?[length]; for (int i = 0; i < length; i++) { values[i] = ReadString(ref reader, "messages[{0}]", i); @@ -862,7 +860,7 @@ private static PingMessage CreatePingMessage(ref MessagePackReader reader, int a private static OpenConnectionMessage CreateOpenConnectionMessage(ref MessagePackReader reader, int arrayLength) { - var connectionId = ReadString(ref reader, "connectionId"); + var connectionId = ReadStringNotNull(ref reader, "connectionId"); var claims = ReadClaims(ref reader); // Backward compatible with old versions @@ -885,7 +883,7 @@ private static OpenConnectionMessage CreateOpenConnectionMessage(ref MessagePack private static CloseConnectionMessage CreateCloseConnectionMessage(ref MessagePackReader reader, int arrayLength) { - var connectionId = ReadString(ref reader, "connectionId"); + var connectionId = ReadStringNotNull(ref reader, "connectionId"); var errorMessage = ReadString(ref reader, "errorMessage"); var headers = arrayLength >= 4 ? ReadHeaders(ref reader) : new Dictionary(); var result = new CloseConnectionMessage(connectionId, errorMessage, headers); @@ -899,7 +897,7 @@ private static CloseConnectionMessage CreateCloseConnectionMessage(ref MessagePa [Obsolete] private static CloseConnectionWithAckMessage CreateCloseConnectionWithAckMessage(ref MessagePackReader reader, int arrayLength) { - var connectionId = ReadString(ref reader, "connectionId"); + var connectionId = ReadStringNotNull(ref reader, "connectionId"); var reason = ReadString(ref reader, "reason"); var ackId = ReadInt32(ref reader, "ackId"); var result = new CloseConnectionWithAckMessage(connectionId, ackId) @@ -934,7 +932,7 @@ private static CloseConnectionsWithAckMessage CreateCloseConnectionsWithAckMessa private static CloseUserConnectionsWithAckMessage CreateCloseUserConnectionsWithAckMessage(ref MessagePackReader reader, int arrayLength) { - var userId = ReadString(ref reader, "userId"); + var userId = ReadStringNotNull(ref reader, "userId"); var reason = ReadString(ref reader, "reason"); var ackId = ReadInt32(ref reader, "ackId"); var excluded = ReadStringArray(ref reader, "excluded"); @@ -953,7 +951,7 @@ private static CloseUserConnectionsWithAckMessage CreateCloseUserConnectionsWith private static CloseGroupConnectionsWithAckMessage CreateCloseGroupConnectionsWithAckMessage(ref MessagePackReader reader, int arrayLength) { - var group = ReadString(ref reader, "group"); + var group = ReadStringNotNull(ref reader, "group"); var reason = ReadString(ref reader, "reason"); var ackId = ReadInt32(ref reader, "ackId"); var excluded = ReadStringArray(ref reader, "excluded"); @@ -972,7 +970,7 @@ private static CloseGroupConnectionsWithAckMessage CreateCloseGroupConnectionsWi private static ConnectionDataMessage CreateConnectionDataMessage(ref MessagePackReader reader, int arrayLength) { - var connectionId = ReadString(ref reader, "connectionId"); + var connectionId = ReadStringNotNull(ref reader, "connectionId"); var payload = ReadBytes(ref reader, "payload"); var result = new ConnectionDataMessage(connectionId, payload); @@ -985,7 +983,7 @@ private static ConnectionDataMessage CreateConnectionDataMessage(ref MessagePack private static ConnectionReconnectMessage CreateConnectionReconnectMessage(ref MessagePackReader reader, int arrayLength) { - var connectionId = ReadString(ref reader, "connectionId"); + var connectionId = ReadStringNotNull(ref reader, "connectionId"); var result = new ConnectionReconnectMessage(connectionId); result.ReadExtensionMembers(ref reader); @@ -1007,7 +1005,7 @@ private static MultiConnectionDataMessage CreateMultiConnectionDataMessage(ref M private static ServiceMessage CreateUserDataMessage(ref MessagePackReader reader, int arrayLength) { - var userId = ReadString(ref reader, "userId"); + var userId = ReadStringNotNull(ref reader, "userId"); var payloads = ReadPayloads(ref reader); var result = new UserDataMessage(userId, payloads); @@ -1046,8 +1044,8 @@ private static BroadcastDataMessage CreateBroadcastDataMessage(ref MessagePackRe private static JoinGroupMessage CreateJoinGroupMessage(ref MessagePackReader reader, int arrayLength) { - var connectionId = ReadString(ref reader, "connectionId"); - var groupName = ReadString(ref reader, "groupName"); + var connectionId = ReadStringNotNull(ref reader, "connectionId"); + var groupName = ReadStringNotNull(ref reader, "groupName"); var result = new JoinGroupMessage(connectionId, groupName); if (arrayLength >= 4) @@ -1059,7 +1057,7 @@ private static JoinGroupMessage CreateJoinGroupMessage(ref MessagePackReader rea private static LeaveGroupMessage CreateLeaveGroupMessage(ref MessagePackReader reader, int arrayLength) { - var connectionId = ReadString(ref reader, "connectionId"); + var connectionId = ReadStringNotNull(ref reader, "connectionId"); var groupName = ReadString(ref reader, "groupName"); var result = new LeaveGroupMessage(connectionId, groupName); @@ -1072,8 +1070,8 @@ private static LeaveGroupMessage CreateLeaveGroupMessage(ref MessagePackReader r private static UserJoinGroupMessage CreateUserJoinGroupMessage(ref MessagePackReader reader, int arrayLength) { - var userId = ReadString(ref reader, "userId"); - var groupName = ReadString(ref reader, "groupName"); + var userId = ReadStringNotNull(ref reader, "userId"); + var groupName = ReadStringNotNull(ref reader, "groupName"); var result = new UserJoinGroupMessage(userId, groupName); if (arrayLength >= 4) @@ -1085,7 +1083,7 @@ private static UserJoinGroupMessage CreateUserJoinGroupMessage(ref MessagePackRe private static UserLeaveGroupMessage CreateUserLeaveGroupMessage(ref MessagePackReader reader, int arrayLength) { - var userId = ReadString(ref reader, "userId"); + var userId = ReadStringNotNull(ref reader, "userId"); var groupName = ReadString(ref reader, "groupName"); var result = new UserLeaveGroupMessage(userId, groupName); @@ -1098,8 +1096,8 @@ private static UserLeaveGroupMessage CreateUserLeaveGroupMessage(ref MessagePack private static UserJoinGroupWithAckMessage CreateUserJoinGroupWithAckMessage(ref MessagePackReader reader, int arrayLength) { - var userId = ReadString(ref reader, "userId"); - var groupName = ReadString(ref reader, "groupName"); + var userId = ReadStringNotNull(ref reader, "userId"); + var groupName = ReadStringNotNull(ref reader, "groupName"); var ackId = ReadInt32(ref reader, "ackId"); var result = new UserJoinGroupWithAckMessage(userId, groupName, ackId); @@ -1109,7 +1107,7 @@ private static UserJoinGroupWithAckMessage CreateUserJoinGroupWithAckMessage(ref private static UserLeaveGroupWithAckMessage CreateUserLeaveGroupWithAckMessage(ref MessagePackReader reader, int arrayLength) { - var userId = ReadString(ref reader, "userId"); + var userId = ReadStringNotNull(ref reader, "userId"); var groupName = ReadString(ref reader, "groupName"); var ackId = ReadInt32(ref reader, "ackId"); @@ -1120,7 +1118,7 @@ private static UserLeaveGroupWithAckMessage CreateUserLeaveGroupWithAckMessage(r private static GroupBroadcastDataMessage CreateGroupBroadcastDataMessage(ref MessagePackReader reader, int arrayLength) { - var groupName = ReadString(ref reader, "groupName"); + var groupName = ReadStringNotNull(ref reader, "groupName"); var excludedList = ReadStringArray(ref reader, "excludedList"); var payloads = ReadPayloads(ref reader); @@ -1172,8 +1170,8 @@ private static ServiceEventMessage CreateServiceEventMessage(ref MessagePackRead private static JoinGroupWithAckMessage CreateJoinGroupWithAckMessage(ref MessagePackReader reader, int arrayLength) { - var connectionId = ReadString(ref reader, "connectionId"); - var groupName = ReadString(ref reader, "groupName"); + var connectionId = ReadStringNotNull(ref reader, "connectionId"); + var groupName = ReadStringNotNull(ref reader, "groupName"); var ackId = ReadInt32(ref reader, "ackId"); var result = new JoinGroupWithAckMessage(connectionId, groupName, ackId); @@ -1186,7 +1184,7 @@ private static JoinGroupWithAckMessage CreateJoinGroupWithAckMessage(ref Message private static LeaveGroupWithAckMessage CreateLeaveGroupWithAckMessage(ref MessagePackReader reader, int arrayLength) { - var connectionId = ReadString(ref reader, "connectionId"); + var connectionId = ReadStringNotNull(ref reader, "connectionId"); var groupName = ReadString(ref reader, "groupName"); var ackId = ReadInt32(ref reader, "ackId"); @@ -1200,8 +1198,8 @@ private static LeaveGroupWithAckMessage CreateLeaveGroupWithAckMessage(ref Messa private static CheckUserInGroupWithAckMessage CreateCheckUserInGroupWithAckMessage(ref MessagePackReader reader, int arrayLength) { - var userId = ReadString(ref reader, "userId"); - var groupName = ReadString(ref reader, "groupName"); + var userId = ReadStringNotNull(ref reader, "userId"); + var groupName = ReadStringNotNull(ref reader, "groupName"); var ackId = ReadInt32(ref reader, "ackId"); var result = new CheckUserInGroupWithAckMessage(userId, groupName, ackId); @@ -1214,7 +1212,7 @@ private static CheckUserInGroupWithAckMessage CreateCheckUserInGroupWithAckMessa private static CheckGroupExistenceWithAckMessage CreateGroupExistenceWithAckMessage(ref MessagePackReader reader, int arrayLength) { - var groupName = ReadString(ref reader, "groupName"); + var groupName = ReadStringNotNull(ref reader, "groupName"); var ackId = ReadInt32(ref reader, "ackId"); var result = new CheckGroupExistenceWithAckMessage(groupName, ackId); @@ -1224,7 +1222,7 @@ private static CheckGroupExistenceWithAckMessage CreateGroupExistenceWithAckMess private static CheckConnectionExistenceWithAckMessage CreateCheckConnectionExistenceWithAckMessage(ref MessagePackReader reader, int arrayLength) { - var connectionId = ReadString(ref reader, "connectionId"); + var connectionId = ReadStringNotNull(ref reader, "connectionId"); var ackId = ReadInt32(ref reader, "ackId"); var result = new CheckConnectionExistenceWithAckMessage(connectionId, ackId); @@ -1234,7 +1232,7 @@ private static CheckConnectionExistenceWithAckMessage CreateCheckConnectionExist private static CheckUserExistenceWithAckMessage CreateCheckUserExistenceWithAckMessage(ref MessagePackReader reader, int arrayLength) { - var userId = ReadString(ref reader, "userId"); + var userId = ReadStringNotNull(ref reader, "userId"); var ackId = ReadInt32(ref reader, "ackId"); var result = new CheckUserExistenceWithAckMessage(userId, ackId); @@ -1258,9 +1256,9 @@ private static AckMessage CreateAckMessage(ref MessagePackReader reader, int arr private static ClientInvocationMessage CreateClientInvocationMessage(ref MessagePackReader reader, int arrayLength) { - var invocationId = ReadString(ref reader, "invocationId"); - var connectionId = ReadString(ref reader, "connectionId"); - var callerServerId = ReadString(ref reader, "callerServerId"); + var invocationId = ReadStringNotNull(ref reader, "invocationId"); + var connectionId = ReadStringNotNull(ref reader, "connectionId"); + var callerServerId = ReadStringNotNull(ref reader, "callerServerId"); var payloads = ReadPayloads(ref reader); var result = new ClientInvocationMessage(invocationId, connectionId, callerServerId, payloads); @@ -1270,9 +1268,9 @@ private static ClientInvocationMessage CreateClientInvocationMessage(ref Message private static ClientCompletionMessage CreateClientCompletionMessage(ref MessagePackReader reader, int arrayLength) { - var invocationId = ReadString(ref reader, "invocationId"); - var connectionId = ReadString(ref reader, "connectionId"); - var callerServerId = ReadString(ref reader, "callerServerId"); + var invocationId = ReadStringNotNull(ref reader, "invocationId"); + var connectionId = ReadStringNotNull(ref reader, "connectionId"); + var callerServerId = ReadStringNotNull(ref reader, "callerServerId"); var protocol = ReadString(ref reader, "protocol"); var payload = ReadBytes(ref reader, "payload"); @@ -1284,9 +1282,9 @@ private static ClientCompletionMessage CreateClientCompletionMessage(ref Message private static ErrorCompletionMessage CreateErrorCompletionMessage(ref MessagePackReader reader, int arrayLength) { - var invocationId = ReadString(ref reader, "invocationId"); - var connectionId = ReadString(ref reader, "connectionId"); - var callerServerId = ReadString(ref reader, "callerServerId"); + var invocationId = ReadStringNotNull(ref reader, "invocationId"); + var connectionId = ReadStringNotNull(ref reader, "connectionId"); + var callerServerId = ReadStringNotNull(ref reader, "callerServerId"); var error = ReadString(ref reader, "error"); var result = new ErrorCompletionMessage(invocationId, connectionId, callerServerId, error); @@ -1297,9 +1295,9 @@ private static ErrorCompletionMessage CreateErrorCompletionMessage(ref MessagePa private static ServiceMappingMessage CreateServiceMappingMessage(ref MessagePackReader reader, int arrayLength) { - var invocationId = ReadString(ref reader, "invocationId"); - var connectionId = ReadString(ref reader, "connectionId"); - var instanceId = ReadString(ref reader, "instanceId"); + var invocationId = ReadStringNotNull(ref reader, "invocationId"); + var connectionId = ReadStringNotNull(ref reader, "connectionId"); + var instanceId = ReadStringNotNull(ref reader, "instanceId"); var result = new ServiceMappingMessage(invocationId, connectionId, instanceId); @@ -1309,7 +1307,7 @@ private static ServiceMappingMessage CreateServiceMappingMessage(ref MessagePack private static ConnectionFlowControlMessage CreateConnectionFlowControlMessage(ref MessagePackReader reader, int arrayLength) { - var connectionId = ReadString(ref reader, "connectionId"); + var connectionId = ReadStringNotNull(ref reader, "connectionId"); var connectionType = ReadInt32(ref reader, "connectionType"); var operation = ReadInt32(ref reader, "operation"); @@ -1368,7 +1366,9 @@ private static IDictionary> ReadPayloads(ref Messag var payloads = new ArrayDictionary>((int)payloadCount, StringComparer.OrdinalIgnoreCase); for (var i = 0; i < payloadCount; i++) { - var key = ReadString(ref reader, "payloads[{0}].key", i); + var keyName = $"payloads[{i}].key"; + + var key = ReadStringNotNull(ref reader, keyName); var value = ReadBytes(ref reader, "payloads[{0}].value", i); payloads.Add(key, value); } @@ -1387,9 +1387,10 @@ private static IDictionary ReadHeaders(ref MessagePackRead var headers = new Dictionary((int)headerCount, StringComparer.OrdinalIgnoreCase); for (var i = 0; i < headerCount; i++) { - var key = ReadString(ref reader, $"headers[{i}].key"); + var keyName = $"headers[{i}].key"; + var key = ReadStringNotNull(ref reader, keyName); var count = ReadArrayLength(ref reader, $"headers[{i}].value.length"); - var stringValues = new string[count]; + var stringValues = new string?[count]; for (var j = 0; j < count; j++) { stringValues[j] = ReadString(ref reader, $"headers[{i}].value[{j}]"); @@ -1412,6 +1413,7 @@ private static bool ReadBoolean(ref MessagePackReader reader, string field) catch (Exception ex) { throw new InvalidDataException($"Reading '{field}' as Boolean failed.", ex); + } } @@ -1425,10 +1427,9 @@ private static int ReadInt32(ref MessagePackReader reader, string field) { throw new InvalidDataException($"Reading '{field}' as Int32 failed.", ex); } - } - private static string ReadString(ref MessagePackReader reader, string field) + private static string? ReadString(ref MessagePackReader reader, string field) { try { @@ -1440,19 +1441,28 @@ private static string ReadString(ref MessagePackReader reader, string field) } } - private static string ReadString(ref MessagePackReader reader, string formatField, int param) + private static string ReadStringNotNull(ref MessagePackReader reader, string field) { + string? result = null; try { - return reader.ReadString(); + result = reader.ReadString(); } catch (Exception ex) { - throw new InvalidDataException($"Reading '{string.Format(formatField, param)}' as String failed.", ex); + throw new InvalidDataException($"Reading '{field}' as String failed.", ex); + } + + if (result == null) + { + throw new InvalidDataException($"Reading '{field}' as Not-Null String failed."); + } + + return result; } - private static string ReadString(ref MessagePackReader reader, string formatField, string param1, int param2) + private static string? ReadString(ref MessagePackReader reader, string formatField, int param) { try { @@ -1460,7 +1470,7 @@ private static string ReadString(ref MessagePackReader reader, string formatFiel } catch (Exception ex) { - throw new InvalidDataException($"Reading '{string.Format(formatField, param1, param2)}' as String failed.", ex); + throw new InvalidDataException($"Reading '{string.Format(formatField, param)}' as String failed.", ex); } } @@ -1472,7 +1482,8 @@ private static string[] ReadStringArray(ref MessagePackReader reader, string fie var array = new string[arrayLength]; for (int i = 0; i < arrayLength; i++) { - array[i] = ReadString(ref reader, "{0}[{1}]", field, i); + var fieldName = $"{field}[{i}]"; + array[i] = ReadStringNotNull(ref reader, fieldName); } return array; diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MessagePackReaderExtensions.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MessagePackReaderExtensions.cs index d4b6c8c4f..fa3198d0b 100644 --- a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MessagePackReaderExtensions.cs +++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MessagePackReaderExtensions.cs @@ -24,7 +24,7 @@ public static int ReadInt32(ref this MessagePackReader reader, string field) } } - public static string ReadString(ref this MessagePackReader reader, string field) + public static string? ReadString(ref this MessagePackReader reader, string field) { try { @@ -87,7 +87,7 @@ public static int ReadArrayLength(ref this MessagePackReader reader, string fiel var map = new Dictionary(); for (var i = 0; i < propertyCount; i++) { - var key = reader.ReadString(); + var key = reader.ReadString().ThrowWhenNull(); var value = reader.ReadObject(field); map[key] = value; } @@ -115,4 +115,8 @@ public static void SkipHeader(ref this MessagePackReader reader) } } + public static string ThrowWhenNull(this string? input) + { + return input ?? throw new ArgumentNullException(nameof(input)); + } } diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/MessagePackPitfalls.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/MessagePackPitfalls.cs new file mode 100644 index 000000000..a3fb600dd --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/MessagePackPitfalls.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace System.Diagnostics.CodeAnalysis; + +[AttributeUsage(AttributeTargets.Method, Inherited = false)] +internal sealed class DoesNotReturnAttribute : Attribute +{ +} diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Microsoft.Azure.SignalR.Serverless.Protocols.csproj b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Microsoft.Azure.SignalR.Serverless.Protocols.csproj index e2a9d768a..d6f53bea8 100644 --- a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Microsoft.Azure.SignalR.Serverless.Protocols.csproj +++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Microsoft.Azure.SignalR.Serverless.Protocols.csproj @@ -4,6 +4,7 @@ .NET Standard SDK for Azure SignalR Service serverless protocol. netstandard2.0 Microsoft.Azure.SignalR.Serverless.Protocols + enable diff --git a/src/submodules/MessagePack-CSharp b/src/submodules/MessagePack-CSharp index 38f095bf0..d3d435b96 160000 --- a/src/submodules/MessagePack-CSharp +++ b/src/submodules/MessagePack-CSharp @@ -1 +1 @@ -Subproject commit 38f095bf0d9b72e2f9e8484282937f433bf3dfcb +Subproject commit d3d435b96cf09b9b0afc313bdea4257de6481c9f