diff --git a/build/dependencies.props b/build/dependencies.props
index 2134695c4..451dfddd2 100644
--- a/build/dependencies.props
+++ b/build/dependencies.props
@@ -13,7 +13,7 @@
1.0.0
2.1.0
4.5.1
- 4.5.4
+ 4.5.5
6.0.0
1.11.4
2.1.0
diff --git a/src/Microsoft.Azure.SignalR.Management/WebsocketsHubLifetimeManager.cs b/src/Microsoft.Azure.SignalR.Management/WebsocketsHubLifetimeManager.cs
index ec20550e1..ee5983391 100644
--- a/src/Microsoft.Azure.SignalR.Management/WebsocketsHubLifetimeManager.cs
+++ b/src/Microsoft.Azure.SignalR.Management/WebsocketsHubLifetimeManager.cs
@@ -1,6 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
-
+#nullable enable
using System;
using System.Threading;
using System.Threading.Tasks;
diff --git a/src/Microsoft.Azure.SignalR.Protocols/ConnectionMessage.cs b/src/Microsoft.Azure.SignalR.Protocols/ConnectionMessage.cs
index 297170049..1e2dd9bff 100644
--- a/src/Microsoft.Azure.SignalR.Protocols/ConnectionMessage.cs
+++ b/src/Microsoft.Azure.SignalR.Protocols/ConnectionMessage.cs
@@ -54,7 +54,7 @@ public OpenConnectionMessage(string connectionId, Claim[]? claims)
/// An array of associated with the connection.
/// A associated with the connection.
/// Query string associated with the connection.
- public OpenConnectionMessage(string connectionId, Claim[]? claims, IDictionary headers, string queryString)
+ public OpenConnectionMessage(string connectionId, Claim[]? claims, IDictionary headers, string? queryString)
: base(connectionId)
{
Claims = claims ?? [];
@@ -75,7 +75,7 @@ public OpenConnectionMessage(string connectionId, Claim[]? claims, IDictionary
/// Gets or sets the associated query string.
///
- public string QueryString { get; set; }
+ public string? QueryString { get; set; }
///
/// Gets or sets the protocol for new connection.
@@ -99,7 +99,7 @@ public class CloseConnectionMessage : ConnectionMessage, IMessageWithTracingId
/// The connection Id.
/// Optional error message.
/// A associated with the connection.
- public CloseConnectionMessage(string connectionId, string errorMessage, IDictionary? headers = null) : base(connectionId)
+ public CloseConnectionMessage(string connectionId, string? errorMessage, IDictionary? headers = null) : base(connectionId)
{
ErrorMessage = errorMessage ?? "";
Headers = headers ?? new Dictionary();
@@ -230,7 +230,7 @@ public class ClientCompletionMessage : ServiceCompletionMessage, IHasProtocol
/// The protocol of the connection.
/// The payload of the completion result.
/// The tracing Id of the message.
- public ClientCompletionMessage(string invocationId, string connectionId, string callerServerId, string protocol, ReadOnlyMemory payload, ulong? tracingId = null)
+ public ClientCompletionMessage(string invocationId, string connectionId, string callerServerId, string? protocol, ReadOnlyMemory payload, ulong? tracingId = null)
: base(invocationId, connectionId, callerServerId, tracingId)
{
Protocol = protocol;
@@ -261,10 +261,10 @@ public class ErrorCompletionMessage : ServiceCompletionMessage
/// The serverId that wrap the completion result.
/// The error information about invacation failure.
/// The tracing Id of the message.
- public ErrorCompletionMessage(string invocationId, string connectionId, string callerServerId, string error, ulong? tracingId = null)
+ public ErrorCompletionMessage(string invocationId, string connectionId, string callerServerId, string? error, ulong? tracingId = null)
: base(invocationId, connectionId, callerServerId, tracingId)
{
- Error = error;
+ Error = error ?? string.Empty;
}
///
diff --git a/src/Microsoft.Azure.SignalR.Protocols/GroupMessage.cs b/src/Microsoft.Azure.SignalR.Protocols/GroupMessage.cs
index 8a39a3807..da3266cc6 100644
--- a/src/Microsoft.Azure.SignalR.Protocols/GroupMessage.cs
+++ b/src/Microsoft.Azure.SignalR.Protocols/GroupMessage.cs
@@ -53,7 +53,7 @@ public class LeaveGroupMessage : ExtensibleServiceMessage, IMessageWithTracingId
///
/// Gets or sets the group name.
///
- public string GroupName { get; set; }
+ public string? GroupName { get; set; }
///
/// Gets or sets the tracing Id
@@ -68,7 +68,7 @@ public class LeaveGroupMessage : ExtensibleServiceMessage, IMessageWithTracingId
/// The connection Id.
/// The group name, from which the connection will leave.
/// The tracing Id of the message.
- public LeaveGroupMessage(string connectionId, string groupName, ulong? tracingId = null)
+ public LeaveGroupMessage(string connectionId, string? groupName, ulong? tracingId = null)
{
ConnectionId = connectionId;
GroupName = groupName;
@@ -130,7 +130,7 @@ public class UserLeaveGroupMessage : ExtensibleServiceMessage, IMessageWithTraci
///
/// Gets or sets the group name.
///
- public string GroupName { get; set; }
+ public string? GroupName { get; set; }
///
/// Gets or sets the tracing Id
@@ -145,7 +145,7 @@ public class UserLeaveGroupMessage : ExtensibleServiceMessage, IMessageWithTraci
/// The user Id.
/// The group name, from which the user will leave.
/// The tracing Id of the message.
- public UserLeaveGroupMessage(string userId, string groupName, ulong? tracingId = null)
+ public UserLeaveGroupMessage(string userId, string? groupName, ulong? tracingId = null)
{
UserId = userId;
GroupName = groupName;
@@ -216,7 +216,7 @@ public class UserLeaveGroupWithAckMessage : ExtensibleServiceMessage, IMessageWi
///
/// Gets or sets the group name.
///
- public string GroupName { get; set; }
+ public string? GroupName { get; set; }
///
/// Gets or sets the tracing Id
@@ -237,7 +237,7 @@ public class UserLeaveGroupWithAckMessage : ExtensibleServiceMessage, IMessageWi
/// The group name, from which the user will leave.
/// The ack Id.
/// The tracing Id of the message.
- public UserLeaveGroupWithAckMessage(string userId, string groupName, int ackId, ulong? tracingId = null)
+ public UserLeaveGroupWithAckMessage(string userId, string? groupName, int ackId, ulong? tracingId = null)
{
UserId = userId;
GroupName = groupName;
@@ -312,7 +312,7 @@ public class LeaveGroupWithAckMessage : ExtensibleServiceMessage, IAckableMessag
///
/// Gets or sets the group name.
///
- public string GroupName { get; set; }
+ public string? GroupName { get; set; }
///
/// Gets or sets the ack id.
@@ -332,7 +332,7 @@ public class LeaveGroupWithAckMessage : ExtensibleServiceMessage, IAckableMessag
/// The connection Id.
/// The group name, from which the connection will leave.
/// The tracing Id of the message.
- public LeaveGroupWithAckMessage(string connectionId, string groupName, ulong? tracingId = null): this(connectionId, groupName, 0, tracingId)
+ public LeaveGroupWithAckMessage(string connectionId, string? groupName, ulong? tracingId = null): this(connectionId, groupName, 0, tracingId)
{
}
@@ -343,7 +343,7 @@ public LeaveGroupWithAckMessage(string connectionId, string groupName, ulong? tr
/// The group name, from which the connection will leave.
/// The ack Id
/// The tracing Id of the message.
- public LeaveGroupWithAckMessage(string connectionId, string groupName, int ackId, ulong? tracingId = null)
+ public LeaveGroupWithAckMessage(string connectionId, string? groupName, int ackId, ulong? tracingId = null)
{
ConnectionId = connectionId;
GroupName = groupName;
diff --git a/src/Microsoft.Azure.SignalR.Protocols/MessagePackPitfalls.cs b/src/Microsoft.Azure.SignalR.Protocols/MessagePackPitfalls.cs
new file mode 100644
index 000000000..a3fb600dd
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Protocols/MessagePackPitfalls.cs
@@ -0,0 +1,13 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace System.Diagnostics.CodeAnalysis;
+
+[AttributeUsage(AttributeTargets.Method, Inherited = false)]
+internal sealed class DoesNotReturnAttribute : Attribute
+{
+}
diff --git a/src/Microsoft.Azure.SignalR.Protocols/Microsoft.Azure.SignalR.Protocols.csproj b/src/Microsoft.Azure.SignalR.Protocols/Microsoft.Azure.SignalR.Protocols.csproj
index f1bec5ab8..1e2b96677 100644
--- a/src/Microsoft.Azure.SignalR.Protocols/Microsoft.Azure.SignalR.Protocols.csproj
+++ b/src/Microsoft.Azure.SignalR.Protocols/Microsoft.Azure.SignalR.Protocols.csproj
@@ -4,6 +4,7 @@
netstandard2.0
Microsoft.Azure.SignalR.Protocol
true
+ enable
diff --git a/src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs b/src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs
index a4cd8d0f5..02da7ac0d 100644
--- a/src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs
+++ b/src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs
@@ -71,7 +71,7 @@ public abstract class ServiceMessage
///
public virtual ServiceMessage Clone() => (MemberwiseClone() as ServiceMessage)!;
- public static byte GeneratePartitionKey(string input)
+ public static byte GeneratePartitionKey(string? input)
{
return (byte)((input?.GetHashCode() ?? 0) & 0xFF);
}
@@ -393,7 +393,7 @@ public class HandshakeResponseMessage : ExtensibleServiceMessage
///
/// Gets or sets the optional error message.
///
- public string ErrorMessage { get; set; }
+ public string? ErrorMessage { get; set; }
///
/// Gets or sets the id of this connection.
@@ -411,7 +411,7 @@ public HandshakeResponseMessage() : this(string.Empty)
/// Initializes a new instance of the class.
///
/// An optional response error message. A null or empty error message indicates a successful handshake.
- public HandshakeResponseMessage(string errorMessage)
+ public HandshakeResponseMessage(string? errorMessage)
{
ErrorMessage = errorMessage;
}
@@ -427,7 +427,7 @@ public class PingMessage : ServiceMessage
///
public static PingMessage Instance = new PingMessage();
- public string[] Messages { get; set; } = Array.Empty();
+ public string?[] Messages { get; set; } = Array.Empty();
}
///
@@ -444,9 +444,9 @@ public class ServiceErrorMessage : ServiceMessage
/// Initializes a new instance of the class.
///
/// An error message.
- public ServiceErrorMessage(string errorMessage)
+ public ServiceErrorMessage(string? errorMessage)
{
- ErrorMessage = errorMessage;
+ ErrorMessage = errorMessage ?? string.Empty;
}
}
@@ -463,7 +463,7 @@ public class ServiceEventMessage : ExtensibleServiceMessage
///
/// Gets or sets the id of event object.
///
- public string Id { get; set; }
+ public string? Id { get; set; }
///
/// Gets or sets the kind of event.
@@ -482,12 +482,12 @@ public class ServiceEventMessage : ExtensibleServiceMessage
/// An id of event object.
/// A kind of event.
/// A message of event.
- public ServiceEventMessage(ServiceEventObjectType type, string id, ServiceEventKind kind, string message)
+ public ServiceEventMessage(ServiceEventObjectType type, string? id, ServiceEventKind kind, string? message)
{
Type = type;
Id = id;
Kind = kind;
- Message = message;
+ Message = message ?? string.Empty;
}
}
@@ -526,11 +526,11 @@ public AckMessage(int ackId, int status) : this(ackId, status, string.Empty)
/// The ack Id
/// The status code
/// The ack message
- public AckMessage(int ackId, int status, string message)
+ public AckMessage(int ackId, int status, string? message)
{
AckId = ackId;
Status = status;
- Message = message;
+ Message = message ?? string.Empty;
}
}
diff --git a/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs b/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs
index 133a770a3..a258e536a 100644
--- a/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs
+++ b/src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs
@@ -12,8 +12,6 @@
namespace Microsoft.Azure.SignalR.Protocol;
-#nullable enable
-
///
/// Implements the Azure SignalR Service Protocol.
///
@@ -700,7 +698,7 @@ private static void WriteErrorCompletionMessage(ref MessagePackWriter writer, Er
writer.Write(message.Error);
message.WriteExtensionMembers(ref writer);
}
-
+
private static void WriteServiceMappingMessage(ref MessagePackWriter writer, ServiceMappingMessage message)
{
writer.WriteArrayHeader(5);
@@ -849,7 +847,7 @@ private static PingMessage CreatePingMessage(ref MessagePackReader reader, int a
if (arrayLength > 1)
{
var length = arrayLength - 1;
- var values = new string[length];
+ var values = new string?[length];
for (int i = 0; i < length; i++)
{
values[i] = ReadString(ref reader, "messages[{0}]", i);
@@ -862,7 +860,7 @@ private static PingMessage CreatePingMessage(ref MessagePackReader reader, int a
private static OpenConnectionMessage CreateOpenConnectionMessage(ref MessagePackReader reader, int arrayLength)
{
- var connectionId = ReadString(ref reader, "connectionId");
+ var connectionId = ReadStringNotNull(ref reader, "connectionId");
var claims = ReadClaims(ref reader);
// Backward compatible with old versions
@@ -885,7 +883,7 @@ private static OpenConnectionMessage CreateOpenConnectionMessage(ref MessagePack
private static CloseConnectionMessage CreateCloseConnectionMessage(ref MessagePackReader reader, int arrayLength)
{
- var connectionId = ReadString(ref reader, "connectionId");
+ var connectionId = ReadStringNotNull(ref reader, "connectionId");
var errorMessage = ReadString(ref reader, "errorMessage");
var headers = arrayLength >= 4 ? ReadHeaders(ref reader) : new Dictionary();
var result = new CloseConnectionMessage(connectionId, errorMessage, headers);
@@ -899,7 +897,7 @@ private static CloseConnectionMessage CreateCloseConnectionMessage(ref MessagePa
[Obsolete]
private static CloseConnectionWithAckMessage CreateCloseConnectionWithAckMessage(ref MessagePackReader reader, int arrayLength)
{
- var connectionId = ReadString(ref reader, "connectionId");
+ var connectionId = ReadStringNotNull(ref reader, "connectionId");
var reason = ReadString(ref reader, "reason");
var ackId = ReadInt32(ref reader, "ackId");
var result = new CloseConnectionWithAckMessage(connectionId, ackId)
@@ -934,7 +932,7 @@ private static CloseConnectionsWithAckMessage CreateCloseConnectionsWithAckMessa
private static CloseUserConnectionsWithAckMessage CreateCloseUserConnectionsWithAckMessage(ref MessagePackReader reader, int arrayLength)
{
- var userId = ReadString(ref reader, "userId");
+ var userId = ReadStringNotNull(ref reader, "userId");
var reason = ReadString(ref reader, "reason");
var ackId = ReadInt32(ref reader, "ackId");
var excluded = ReadStringArray(ref reader, "excluded");
@@ -953,7 +951,7 @@ private static CloseUserConnectionsWithAckMessage CreateCloseUserConnectionsWith
private static CloseGroupConnectionsWithAckMessage CreateCloseGroupConnectionsWithAckMessage(ref MessagePackReader reader, int arrayLength)
{
- var group = ReadString(ref reader, "group");
+ var group = ReadStringNotNull(ref reader, "group");
var reason = ReadString(ref reader, "reason");
var ackId = ReadInt32(ref reader, "ackId");
var excluded = ReadStringArray(ref reader, "excluded");
@@ -972,7 +970,7 @@ private static CloseGroupConnectionsWithAckMessage CreateCloseGroupConnectionsWi
private static ConnectionDataMessage CreateConnectionDataMessage(ref MessagePackReader reader, int arrayLength)
{
- var connectionId = ReadString(ref reader, "connectionId");
+ var connectionId = ReadStringNotNull(ref reader, "connectionId");
var payload = ReadBytes(ref reader, "payload");
var result = new ConnectionDataMessage(connectionId, payload);
@@ -985,7 +983,7 @@ private static ConnectionDataMessage CreateConnectionDataMessage(ref MessagePack
private static ConnectionReconnectMessage CreateConnectionReconnectMessage(ref MessagePackReader reader, int arrayLength)
{
- var connectionId = ReadString(ref reader, "connectionId");
+ var connectionId = ReadStringNotNull(ref reader, "connectionId");
var result = new ConnectionReconnectMessage(connectionId);
result.ReadExtensionMembers(ref reader);
@@ -1007,7 +1005,7 @@ private static MultiConnectionDataMessage CreateMultiConnectionDataMessage(ref M
private static ServiceMessage CreateUserDataMessage(ref MessagePackReader reader, int arrayLength)
{
- var userId = ReadString(ref reader, "userId");
+ var userId = ReadStringNotNull(ref reader, "userId");
var payloads = ReadPayloads(ref reader);
var result = new UserDataMessage(userId, payloads);
@@ -1046,8 +1044,8 @@ private static BroadcastDataMessage CreateBroadcastDataMessage(ref MessagePackRe
private static JoinGroupMessage CreateJoinGroupMessage(ref MessagePackReader reader, int arrayLength)
{
- var connectionId = ReadString(ref reader, "connectionId");
- var groupName = ReadString(ref reader, "groupName");
+ var connectionId = ReadStringNotNull(ref reader, "connectionId");
+ var groupName = ReadStringNotNull(ref reader, "groupName");
var result = new JoinGroupMessage(connectionId, groupName);
if (arrayLength >= 4)
@@ -1059,7 +1057,7 @@ private static JoinGroupMessage CreateJoinGroupMessage(ref MessagePackReader rea
private static LeaveGroupMessage CreateLeaveGroupMessage(ref MessagePackReader reader, int arrayLength)
{
- var connectionId = ReadString(ref reader, "connectionId");
+ var connectionId = ReadStringNotNull(ref reader, "connectionId");
var groupName = ReadString(ref reader, "groupName");
var result = new LeaveGroupMessage(connectionId, groupName);
@@ -1072,8 +1070,8 @@ private static LeaveGroupMessage CreateLeaveGroupMessage(ref MessagePackReader r
private static UserJoinGroupMessage CreateUserJoinGroupMessage(ref MessagePackReader reader, int arrayLength)
{
- var userId = ReadString(ref reader, "userId");
- var groupName = ReadString(ref reader, "groupName");
+ var userId = ReadStringNotNull(ref reader, "userId");
+ var groupName = ReadStringNotNull(ref reader, "groupName");
var result = new UserJoinGroupMessage(userId, groupName);
if (arrayLength >= 4)
@@ -1085,7 +1083,7 @@ private static UserJoinGroupMessage CreateUserJoinGroupMessage(ref MessagePackRe
private static UserLeaveGroupMessage CreateUserLeaveGroupMessage(ref MessagePackReader reader, int arrayLength)
{
- var userId = ReadString(ref reader, "userId");
+ var userId = ReadStringNotNull(ref reader, "userId");
var groupName = ReadString(ref reader, "groupName");
var result = new UserLeaveGroupMessage(userId, groupName);
@@ -1098,8 +1096,8 @@ private static UserLeaveGroupMessage CreateUserLeaveGroupMessage(ref MessagePack
private static UserJoinGroupWithAckMessage CreateUserJoinGroupWithAckMessage(ref MessagePackReader reader, int arrayLength)
{
- var userId = ReadString(ref reader, "userId");
- var groupName = ReadString(ref reader, "groupName");
+ var userId = ReadStringNotNull(ref reader, "userId");
+ var groupName = ReadStringNotNull(ref reader, "groupName");
var ackId = ReadInt32(ref reader, "ackId");
var result = new UserJoinGroupWithAckMessage(userId, groupName, ackId);
@@ -1109,7 +1107,7 @@ private static UserJoinGroupWithAckMessage CreateUserJoinGroupWithAckMessage(ref
private static UserLeaveGroupWithAckMessage CreateUserLeaveGroupWithAckMessage(ref MessagePackReader reader, int arrayLength)
{
- var userId = ReadString(ref reader, "userId");
+ var userId = ReadStringNotNull(ref reader, "userId");
var groupName = ReadString(ref reader, "groupName");
var ackId = ReadInt32(ref reader, "ackId");
@@ -1120,7 +1118,7 @@ private static UserLeaveGroupWithAckMessage CreateUserLeaveGroupWithAckMessage(r
private static GroupBroadcastDataMessage CreateGroupBroadcastDataMessage(ref MessagePackReader reader, int arrayLength)
{
- var groupName = ReadString(ref reader, "groupName");
+ var groupName = ReadStringNotNull(ref reader, "groupName");
var excludedList = ReadStringArray(ref reader, "excludedList");
var payloads = ReadPayloads(ref reader);
@@ -1172,8 +1170,8 @@ private static ServiceEventMessage CreateServiceEventMessage(ref MessagePackRead
private static JoinGroupWithAckMessage CreateJoinGroupWithAckMessage(ref MessagePackReader reader, int arrayLength)
{
- var connectionId = ReadString(ref reader, "connectionId");
- var groupName = ReadString(ref reader, "groupName");
+ var connectionId = ReadStringNotNull(ref reader, "connectionId");
+ var groupName = ReadStringNotNull(ref reader, "groupName");
var ackId = ReadInt32(ref reader, "ackId");
var result = new JoinGroupWithAckMessage(connectionId, groupName, ackId);
@@ -1186,7 +1184,7 @@ private static JoinGroupWithAckMessage CreateJoinGroupWithAckMessage(ref Message
private static LeaveGroupWithAckMessage CreateLeaveGroupWithAckMessage(ref MessagePackReader reader, int arrayLength)
{
- var connectionId = ReadString(ref reader, "connectionId");
+ var connectionId = ReadStringNotNull(ref reader, "connectionId");
var groupName = ReadString(ref reader, "groupName");
var ackId = ReadInt32(ref reader, "ackId");
@@ -1200,8 +1198,8 @@ private static LeaveGroupWithAckMessage CreateLeaveGroupWithAckMessage(ref Messa
private static CheckUserInGroupWithAckMessage CreateCheckUserInGroupWithAckMessage(ref MessagePackReader reader, int arrayLength)
{
- var userId = ReadString(ref reader, "userId");
- var groupName = ReadString(ref reader, "groupName");
+ var userId = ReadStringNotNull(ref reader, "userId");
+ var groupName = ReadStringNotNull(ref reader, "groupName");
var ackId = ReadInt32(ref reader, "ackId");
var result = new CheckUserInGroupWithAckMessage(userId, groupName, ackId);
@@ -1214,7 +1212,7 @@ private static CheckUserInGroupWithAckMessage CreateCheckUserInGroupWithAckMessa
private static CheckGroupExistenceWithAckMessage CreateGroupExistenceWithAckMessage(ref MessagePackReader reader, int arrayLength)
{
- var groupName = ReadString(ref reader, "groupName");
+ var groupName = ReadStringNotNull(ref reader, "groupName");
var ackId = ReadInt32(ref reader, "ackId");
var result = new CheckGroupExistenceWithAckMessage(groupName, ackId);
@@ -1224,7 +1222,7 @@ private static CheckGroupExistenceWithAckMessage CreateGroupExistenceWithAckMess
private static CheckConnectionExistenceWithAckMessage CreateCheckConnectionExistenceWithAckMessage(ref MessagePackReader reader, int arrayLength)
{
- var connectionId = ReadString(ref reader, "connectionId");
+ var connectionId = ReadStringNotNull(ref reader, "connectionId");
var ackId = ReadInt32(ref reader, "ackId");
var result = new CheckConnectionExistenceWithAckMessage(connectionId, ackId);
@@ -1234,7 +1232,7 @@ private static CheckConnectionExistenceWithAckMessage CreateCheckConnectionExist
private static CheckUserExistenceWithAckMessage CreateCheckUserExistenceWithAckMessage(ref MessagePackReader reader, int arrayLength)
{
- var userId = ReadString(ref reader, "userId");
+ var userId = ReadStringNotNull(ref reader, "userId");
var ackId = ReadInt32(ref reader, "ackId");
var result = new CheckUserExistenceWithAckMessage(userId, ackId);
@@ -1258,9 +1256,9 @@ private static AckMessage CreateAckMessage(ref MessagePackReader reader, int arr
private static ClientInvocationMessage CreateClientInvocationMessage(ref MessagePackReader reader, int arrayLength)
{
- var invocationId = ReadString(ref reader, "invocationId");
- var connectionId = ReadString(ref reader, "connectionId");
- var callerServerId = ReadString(ref reader, "callerServerId");
+ var invocationId = ReadStringNotNull(ref reader, "invocationId");
+ var connectionId = ReadStringNotNull(ref reader, "connectionId");
+ var callerServerId = ReadStringNotNull(ref reader, "callerServerId");
var payloads = ReadPayloads(ref reader);
var result = new ClientInvocationMessage(invocationId, connectionId, callerServerId, payloads);
@@ -1270,9 +1268,9 @@ private static ClientInvocationMessage CreateClientInvocationMessage(ref Message
private static ClientCompletionMessage CreateClientCompletionMessage(ref MessagePackReader reader, int arrayLength)
{
- var invocationId = ReadString(ref reader, "invocationId");
- var connectionId = ReadString(ref reader, "connectionId");
- var callerServerId = ReadString(ref reader, "callerServerId");
+ var invocationId = ReadStringNotNull(ref reader, "invocationId");
+ var connectionId = ReadStringNotNull(ref reader, "connectionId");
+ var callerServerId = ReadStringNotNull(ref reader, "callerServerId");
var protocol = ReadString(ref reader, "protocol");
var payload = ReadBytes(ref reader, "payload");
@@ -1284,9 +1282,9 @@ private static ClientCompletionMessage CreateClientCompletionMessage(ref Message
private static ErrorCompletionMessage CreateErrorCompletionMessage(ref MessagePackReader reader, int arrayLength)
{
- var invocationId = ReadString(ref reader, "invocationId");
- var connectionId = ReadString(ref reader, "connectionId");
- var callerServerId = ReadString(ref reader, "callerServerId");
+ var invocationId = ReadStringNotNull(ref reader, "invocationId");
+ var connectionId = ReadStringNotNull(ref reader, "connectionId");
+ var callerServerId = ReadStringNotNull(ref reader, "callerServerId");
var error = ReadString(ref reader, "error");
var result = new ErrorCompletionMessage(invocationId, connectionId, callerServerId, error);
@@ -1297,9 +1295,9 @@ private static ErrorCompletionMessage CreateErrorCompletionMessage(ref MessagePa
private static ServiceMappingMessage CreateServiceMappingMessage(ref MessagePackReader reader, int arrayLength)
{
- var invocationId = ReadString(ref reader, "invocationId");
- var connectionId = ReadString(ref reader, "connectionId");
- var instanceId = ReadString(ref reader, "instanceId");
+ var invocationId = ReadStringNotNull(ref reader, "invocationId");
+ var connectionId = ReadStringNotNull(ref reader, "connectionId");
+ var instanceId = ReadStringNotNull(ref reader, "instanceId");
var result = new ServiceMappingMessage(invocationId, connectionId, instanceId);
@@ -1309,7 +1307,7 @@ private static ServiceMappingMessage CreateServiceMappingMessage(ref MessagePack
private static ConnectionFlowControlMessage CreateConnectionFlowControlMessage(ref MessagePackReader reader, int arrayLength)
{
- var connectionId = ReadString(ref reader, "connectionId");
+ var connectionId = ReadStringNotNull(ref reader, "connectionId");
var connectionType = ReadInt32(ref reader, "connectionType");
var operation = ReadInt32(ref reader, "operation");
@@ -1368,7 +1366,9 @@ private static IDictionary> ReadPayloads(ref Messag
var payloads = new ArrayDictionary>((int)payloadCount, StringComparer.OrdinalIgnoreCase);
for (var i = 0; i < payloadCount; i++)
{
- var key = ReadString(ref reader, "payloads[{0}].key", i);
+ var keyName = $"payloads[{i}].key";
+
+ var key = ReadStringNotNull(ref reader, keyName);
var value = ReadBytes(ref reader, "payloads[{0}].value", i);
payloads.Add(key, value);
}
@@ -1387,9 +1387,10 @@ private static IDictionary ReadHeaders(ref MessagePackRead
var headers = new Dictionary((int)headerCount, StringComparer.OrdinalIgnoreCase);
for (var i = 0; i < headerCount; i++)
{
- var key = ReadString(ref reader, $"headers[{i}].key");
+ var keyName = $"headers[{i}].key";
+ var key = ReadStringNotNull(ref reader, keyName);
var count = ReadArrayLength(ref reader, $"headers[{i}].value.length");
- var stringValues = new string[count];
+ var stringValues = new string?[count];
for (var j = 0; j < count; j++)
{
stringValues[j] = ReadString(ref reader, $"headers[{i}].value[{j}]");
@@ -1412,6 +1413,7 @@ private static bool ReadBoolean(ref MessagePackReader reader, string field)
catch (Exception ex)
{
throw new InvalidDataException($"Reading '{field}' as Boolean failed.", ex);
+
}
}
@@ -1425,10 +1427,9 @@ private static int ReadInt32(ref MessagePackReader reader, string field)
{
throw new InvalidDataException($"Reading '{field}' as Int32 failed.", ex);
}
-
}
- private static string ReadString(ref MessagePackReader reader, string field)
+ private static string? ReadString(ref MessagePackReader reader, string field)
{
try
{
@@ -1440,19 +1441,28 @@ private static string ReadString(ref MessagePackReader reader, string field)
}
}
- private static string ReadString(ref MessagePackReader reader, string formatField, int param)
+ private static string ReadStringNotNull(ref MessagePackReader reader, string field)
{
+ string? result = null;
try
{
- return reader.ReadString();
+ result = reader.ReadString();
}
catch (Exception ex)
{
- throw new InvalidDataException($"Reading '{string.Format(formatField, param)}' as String failed.", ex);
+ throw new InvalidDataException($"Reading '{field}' as String failed.", ex);
+
}
+
+ if (result == null)
+ {
+ throw new InvalidDataException($"Reading '{field}' as Not-Null String failed.");
+ }
+
+ return result;
}
- private static string ReadString(ref MessagePackReader reader, string formatField, string param1, int param2)
+ private static string? ReadString(ref MessagePackReader reader, string formatField, int param)
{
try
{
@@ -1460,7 +1470,7 @@ private static string ReadString(ref MessagePackReader reader, string formatFiel
}
catch (Exception ex)
{
- throw new InvalidDataException($"Reading '{string.Format(formatField, param1, param2)}' as String failed.", ex);
+ throw new InvalidDataException($"Reading '{string.Format(formatField, param)}' as String failed.", ex);
}
}
@@ -1472,7 +1482,8 @@ private static string[] ReadStringArray(ref MessagePackReader reader, string fie
var array = new string[arrayLength];
for (int i = 0; i < arrayLength; i++)
{
- array[i] = ReadString(ref reader, "{0}[{1}]", field, i);
+ var fieldName = $"{field}[{i}]";
+ array[i] = ReadStringNotNull(ref reader, fieldName);
}
return array;
diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MessagePackReaderExtensions.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MessagePackReaderExtensions.cs
index d4b6c8c4f..fa3198d0b 100644
--- a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MessagePackReaderExtensions.cs
+++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MessagePackReaderExtensions.cs
@@ -24,7 +24,7 @@ public static int ReadInt32(ref this MessagePackReader reader, string field)
}
}
- public static string ReadString(ref this MessagePackReader reader, string field)
+ public static string? ReadString(ref this MessagePackReader reader, string field)
{
try
{
@@ -87,7 +87,7 @@ public static int ReadArrayLength(ref this MessagePackReader reader, string fiel
var map = new Dictionary();
for (var i = 0; i < propertyCount; i++)
{
- var key = reader.ReadString();
+ var key = reader.ReadString().ThrowWhenNull();
var value = reader.ReadObject(field);
map[key] = value;
}
@@ -115,4 +115,8 @@ public static void SkipHeader(ref this MessagePackReader reader)
}
}
+ public static string ThrowWhenNull(this string? input)
+ {
+ return input ?? throw new ArgumentNullException(nameof(input));
+ }
}
diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/MessagePackPitfalls.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/MessagePackPitfalls.cs
new file mode 100644
index 000000000..a3fb600dd
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/MessagePackPitfalls.cs
@@ -0,0 +1,13 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace System.Diagnostics.CodeAnalysis;
+
+[AttributeUsage(AttributeTargets.Method, Inherited = false)]
+internal sealed class DoesNotReturnAttribute : Attribute
+{
+}
diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Microsoft.Azure.SignalR.Serverless.Protocols.csproj b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Microsoft.Azure.SignalR.Serverless.Protocols.csproj
index e2a9d768a..d6f53bea8 100644
--- a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Microsoft.Azure.SignalR.Serverless.Protocols.csproj
+++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Microsoft.Azure.SignalR.Serverless.Protocols.csproj
@@ -4,6 +4,7 @@
.NET Standard SDK for Azure SignalR Service serverless protocol.
netstandard2.0
Microsoft.Azure.SignalR.Serverless.Protocols
+ enable
diff --git a/src/submodules/MessagePack-CSharp b/src/submodules/MessagePack-CSharp
index 38f095bf0..d3d435b96 160000
--- a/src/submodules/MessagePack-CSharp
+++ b/src/submodules/MessagePack-CSharp
@@ -1 +1 @@
-Subproject commit 38f095bf0d9b72e2f9e8484282937f433bf3dfcb
+Subproject commit d3d435b96cf09b9b0afc313bdea4257de6481c9f