From e7e8d0d7eb01aa492acad9b777fcd6aa339698d5 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Wed, 27 Nov 2024 11:54:33 +0100 Subject: [PATCH] [NRBF] Fixes and fuzzing improvements (#110194) * Simplify array handling to fix issues with jagged and abstract array types Jagged arrays in the payload can contain cycles. In that scenario, no value is correct for `ArrayRecord.FlattenedLength`, and `ArrayRecord.GetArray` does not have enough context to know how to handle the cycles. To address these issues, jagged array handling is simplified so that calling code can handle the cycles in the most appropriate way for the application. Single-dimension arrays can be represented in the payload using abstract types such as `IComparable[]` where the concrete element type is not known. When the concrete element type is known, `ArrayRecord.GetArray` could return either `SZArrayRecord` or `SZArrayRecord`; without a concrete type, we need to return something that represents the element abstractly. 1. `ArrayRecord.FlattenedLength` is removed from the API 2. `ArrayRecord.GetArray` now returns `ArrayRecord[]` for jagged arrays instead of trying to populate them 3. `ArrayRecord.GetArray` now returns `SZArrayRecord` for single-dimension arrays instead of either `SZArrayRecord` or `SZArrayRecord` * extend the Fuzzer to consume all possible data exposed by the NrbfDecoder --- .../Fuzzers/NrbfDecoderFuzzer.cs | 269 +++++++-- .../ref/System.Formats.Nrbf.cs | 1 - .../System.Formats.Nrbf/src/PACKAGE.md | 2 +- .../System/Formats/Nrbf/AllowedRecordType.cs | 3 +- .../src/System/Formats/Nrbf/ArrayRecord.cs | 91 ++- .../Nrbf/ArrayRectangularPrimitiveRecord.cs | 83 +++ .../Formats/Nrbf/ArraySingleObjectRecord.cs | 26 +- .../Nrbf/ArraySinglePrimitiveRecord.cs | 75 ++- .../Formats/Nrbf/ArraySingleStringRecord.cs | 2 +- .../System/Formats/Nrbf/BinaryArrayRecord.cs | 309 ----------- .../System/Formats/Nrbf/ClassWithIdRecord.cs | 48 +- .../System/Formats/Nrbf/JaggedArrayRecord.cs | 65 +++ .../src/System/Formats/Nrbf/MemberTypeInfo.cs | 43 +- .../src/System/Formats/Nrbf/NrbfDecoder.cs | 173 +++++- .../src/System/Formats/Nrbf/RecordMap.cs | 13 +- .../Formats/Nrbf/RectangularArrayRecord.cs | 243 +-------- ...OfClassesRecord.cs => SZArrayOfRecords.cs} | 22 +- .../SystemClassWithMembersAndTypesRecord.cs | 137 +++-- .../Nrbf/Utils/BinaryReaderExtensions.cs | 2 +- .../tests/ArrayOfSerializationRecordsTests.cs | 516 ++++++++++++++++++ .../tests/ArraySinglePrimitiveRecordTests.cs | 85 ++- .../System.Formats.Nrbf/tests/AttackTests.cs | 47 +- .../tests/InvalidInputTests.cs | 30 + .../tests/JaggedArraysTests.cs | 152 ++++-- .../tests/ReadAnythingTests.cs | 65 ++- .../tests/ReadExactTypesTests.cs | 27 +- .../System.Formats.Nrbf/tests/ReadTests.cs | 34 ++ .../tests/RectangularArraysTests.cs | 21 +- .../tests/TypeMatchTests.cs | 2 +- .../Deserializer/ArrayRecordDeserializer.cs | 17 +- .../BinaryFormat/Deserializer/Deserializer.cs | 2 +- .../Common/MultidimensionalArrayTests.cs | 78 +++ .../BinaryFormattedObjectTests.cs | 30 +- .../FormattedObject/HashTableTests.cs | 6 +- .../FormattedObject/ListTests.cs | 14 +- 35 files changed, 1855 insertions(+), 878 deletions(-) create mode 100644 src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRectangularPrimitiveRecord.cs delete mode 100644 src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs create mode 100644 src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/JaggedArrayRecord.cs rename src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/{ArrayOfClassesRecord.cs => SZArrayOfRecords.cs} (69%) create mode 100644 src/libraries/System.Formats.Nrbf/tests/ArrayOfSerializationRecordsTests.cs diff --git a/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/NrbfDecoderFuzzer.cs b/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/NrbfDecoderFuzzer.cs index f217649927b15..6ae41e94782ac 100644 --- a/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/NrbfDecoderFuzzer.cs +++ b/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/NrbfDecoderFuzzer.cs @@ -3,6 +3,8 @@ using System.Buffers; using System.Formats.Nrbf; +using System.Numerics; +using System.Runtime.CompilerServices; using System.Runtime.Serialization; using System.Text; @@ -38,64 +40,17 @@ private static void Test(Span testSpan, Stream stream) { if (NrbfDecoder.StartsWithPayloadHeader(testSpan)) { + HashSet visited = new(); + Queue queue = new(); try { SerializationRecord record = NrbfDecoder.Decode(stream, out IReadOnlyDictionary recordMap); - switch (record.RecordType) + + Assert.Equal(true, recordMap.ContainsKey(record.Id)); // make sure the loop below includes it + foreach (SerializationRecord fromMap in recordMap.Values) { - case SerializationRecordType.ArraySingleObject: - SZArrayRecord arrayObj = (SZArrayRecord)record; - object?[] objArray = arrayObj.GetArray(); - Assert.Equal(arrayObj.Length, objArray.Length); - Assert.Equal(1, arrayObj.Rank); - break; - case SerializationRecordType.ArraySingleString: - SZArrayRecord arrayString = (SZArrayRecord)record; - string?[] array = arrayString.GetArray(); - Assert.Equal(arrayString.Length, array.Length); - Assert.Equal(1, arrayString.Rank); - Assert.Equal(true, arrayString.TypeNameMatches(typeof(string[]))); - break; - case SerializationRecordType.ArraySinglePrimitive: - case SerializationRecordType.BinaryArray: - ArrayRecord arrayBinary = (ArrayRecord)record; - Assert.NotNull(arrayBinary.TypeName); - break; - case SerializationRecordType.BinaryObjectString: - _ = ((PrimitiveTypeRecord)record).Value; - break; - case SerializationRecordType.ClassWithId: - case SerializationRecordType.ClassWithMembersAndTypes: - case SerializationRecordType.SystemClassWithMembersAndTypes: - ClassRecord classRecord = (ClassRecord)record; - Assert.NotNull(classRecord.TypeName); - - foreach (string name in classRecord.MemberNames) - { - Assert.Equal(true, classRecord.HasMember(name)); - } - break; - case SerializationRecordType.MemberPrimitiveTyped: - PrimitiveTypeRecord primitiveType = (PrimitiveTypeRecord)record; - Assert.NotNull(primitiveType.Value); - break; - case SerializationRecordType.MemberReference: - Assert.NotNull(record.TypeName); - break; - case SerializationRecordType.BinaryLibrary: - Assert.Equal(false, record.Id.Equals(default)); - break; - case SerializationRecordType.ObjectNull: - case SerializationRecordType.ObjectNullMultiple: - case SerializationRecordType.ObjectNullMultiple256: - Assert.Equal(default, record.Id); - break; - case SerializationRecordType.MessageEnd: - case SerializationRecordType.SerializedStreamHeader: - // case SerializationRecordType.ClassWithMembers: will cause NotSupportedException - // case SerializationRecordType.SystemClassWithMembers: will cause NotSupportedException - default: - throw new Exception("Unexpected RecordType"); + visited.Add(fromMap.Id); + queue.Enqueue(fromMap); } } catch (SerializationException) { /* Reading from the stream encountered invalid NRBF data.*/ } @@ -103,6 +58,9 @@ private static void Test(Span testSpan, Stream stream) catch (DecoderFallbackException) { /* Reading from the stream encountered an invalid UTF8 sequence. */ } catch (EndOfStreamException) { /* The end of the stream was reached before reading SerializationRecordType.MessageEnd record. */ } catch (IOException) { /* An I/O error occurred. */ } + + // Lets consume it outside of the try/catch block to not swallow any exceptions by accident. + Consume(visited, queue); } else { @@ -117,6 +75,209 @@ private static void Test(Span testSpan, Stream stream) } } + private static void Consume(HashSet visited, Queue queue) + { + while (queue.Count > 0) + { + SerializationRecord serializationRecord = queue.Dequeue(); + + if (serializationRecord is PrimitiveTypeRecord primitiveTypeRecord) + { + ConsumePrimitiveValue(primitiveTypeRecord.Value); + } + else if (serializationRecord is ClassRecord classRecord) + { + foreach (string memberName in classRecord.MemberNames) + { + ConsumePrimitiveValue(memberName); + + Assert.Equal(true, classRecord.HasMember(memberName)); + + object? rawValue; + + try + { + rawValue = classRecord.GetRawValue(memberName); + } + catch (SerializationException ex) when (ex.Message == "Invalid member reference.") + { + // It was a reference to a non-existing record, just continue. + continue; + } + + if (rawValue is not null) + { + if (rawValue is SerializationRecord nestedRecord) + { + TryEnqueue(nestedRecord); + } + else + { + ConsumePrimitiveValue(rawValue); + } + } + } + } + else if (serializationRecord is ArrayRecord arrayRecord) + { + Type? type; + + try + { + // THIS IS VERY BAD IDEA FOR ANY KIND OF PRODUCT CODE!! + // IT'S USED ONLY FOR THE PURPOSE OF TESTING, DO NOT COPY IT. + type = Type.GetType(arrayRecord.TypeName.AssemblyQualifiedName, throwOnError: false); + if (type is null) + { + continue; + } + } + catch (Exception) // throwOnError passed to GetType does not prevent from all kinds of exceptions + { + // It was some type made up by the Fuzzer. + // Since it's currently impossible to get the array without providing the type, + // we just bail here (in the future we may add an enumerator to ArrayRecord). + continue; + } + + Array? array; + try + { + array = arrayRecord.GetArray(type); + } + catch (SerializationException ex) when (ex.Message == "Invalid member reference.") + { + // It contained a reference to a non-existing record, just continue. + continue; + } + + ReadOnlySpan lengths = arrayRecord.Lengths; + long totalElementsCount = 1; + for (int i = 0; i < arrayRecord.Rank; i++) + { + Assert.Equal(lengths[i], array.GetLength(i)); + totalElementsCount *= lengths[i]; + } + + // This array contains indices that are used to get values of multi-dimensional array. + // At the beginning, all values are set to 0, so we start from the first element. + int[] indices = new int[arrayRecord.Rank]; + + long flatIndex = 0; + for (; flatIndex < totalElementsCount; flatIndex++) + { + object? rawValue = array.GetValue(indices); + if (rawValue is not null) + { + if (rawValue is SerializationRecord record) + { + TryEnqueue(record); + } + else + { + ConsumePrimitiveValue(rawValue); + } + } + + // The loop below is responsible for incrementing the multi-dimensional indices. + // It finds the dimension and then performs an increment. + int dimension = indices.Length - 1; + while (dimension >= 0) + { + indices[dimension]++; + if (indices[dimension] < lengths[dimension]) + { + break; + } + indices[dimension] = 0; + dimension--; + } + } + + // We track the flat index to ensure that we have enumerated over all elements. + Assert.Equal(totalElementsCount, flatIndex); + } + else + { + // The map may currently contain it (it may change in the future) + Assert.Equal(SerializationRecordType.BinaryLibrary, serializationRecord.RecordType); + } + } + + void TryEnqueue(SerializationRecord record) + { + if (visited.Add(record.Id)) // avoid unbounded recursion + { + queue.Enqueue(record); + } + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void ConsumePrimitiveValue(object value) + { + if (value is string text) + Assert.Equal(text, text.ToString()); // we want to touch all elements to see if memory is not corrupted + else if (value is bool boolean) + Assert.Equal(true, Unsafe.BitCast(boolean) is 1 or 0); // other values are illegal!! + else if (value is sbyte @sbyte) + TestNumber(@sbyte); + else if (value is byte @byte) + TestNumber(@byte); + else if (value is char character) + TestNumber(character); + else if (value is short @short) + TestNumber(@short); + else if (value is ushort @ushort) + TestNumber(@ushort); + else if (value is int integer) + TestNumber(integer); + else if (value is uint @uint) + TestNumber(@uint); + else if (value is long @long) + TestNumber(@long); + else if (value is ulong @ulong) + TestNumber(@ulong); + else if (value is float @float) + { + if (!float.IsNaN(@float) && !float.IsInfinity(@float)) + { + TestNumber(@float); + } + } + else if (value is double @double) + { + if (!double.IsNaN(@double) && !double.IsInfinity(@double)) + { + TestNumber(@double); + } + } + else if (value is decimal @decimal) + TestNumber(@decimal); + else if (value is nint @nint) + TestNumber(@nint); + else if (value is nuint @nuint) + TestNumber(@nuint); + else if (value is DateTime datetime) + Assert.Equal(true, datetime >= DateTime.MinValue && datetime <= DateTime.MaxValue); + else if (value is TimeSpan timeSpan) + Assert.Equal(true, timeSpan >= TimeSpan.MinValue && timeSpan <= TimeSpan.MaxValue); + else + throw new InvalidOperationException(); + + static void TestNumber(T value) where T : IComparable, IMinMaxValue + { + if (value.CompareTo(T.MinValue) < 0) + { + throw new Exception($"Expected {value} to be more or equal {T.MinValue}, {value.CompareTo(T.MinValue)}."); + } + if (value.CompareTo(T.MaxValue) > 0) + { + throw new Exception($"Expected {value} to be less or equal {T.MaxValue}, {value.CompareTo(T.MaxValue)}."); + } + } + } + private sealed class NonSeekableStream : MemoryStream { public NonSeekableStream(byte[] buffer) : base(buffer) { } diff --git a/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs b/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs index f82bbb96732c9..292a5eb1038d5 100644 --- a/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs +++ b/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs @@ -9,7 +9,6 @@ namespace System.Formats.Nrbf public abstract partial class ArrayRecord : System.Formats.Nrbf.SerializationRecord { internal ArrayRecord() { } - public virtual long FlattenedLength { get { throw null; } } public override System.Formats.Nrbf.SerializationRecordId Id { get { throw null; } } public abstract System.ReadOnlySpan Lengths { get; } public int Rank { get { throw null; } } diff --git a/src/libraries/System.Formats.Nrbf/src/PACKAGE.md b/src/libraries/System.Formats.Nrbf/src/PACKAGE.md index c301459358838..23e5ac389d3d1 100644 --- a/src/libraries/System.Formats.Nrbf/src/PACKAGE.md +++ b/src/libraries/System.Formats.Nrbf/src/PACKAGE.md @@ -54,7 +54,7 @@ There are more than a dozen different serialization [record types](https://learn - `PrimitiveTypeRecord` derives from the non-generic [PrimitiveTypeRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.primitivetyperecord), which also exposes a [Value](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.primitivetyperecord.value) property. But on the base class, the value is returned as `object` (which introduces boxing for value types). - [ClassRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.classrecord): describes all `class` and `struct` besides the aforementioned primitive types. - [ArrayRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.arrayrecord): describes all array records, including jagged and multi-dimensional arrays. -- [`SZArrayRecord`](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.szarrayrecord-1): describes single-dimensional, zero-indexed array records, where `T` can be either a primitive type or a `ClassRecord`. +- [`SZArrayRecord`](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.szarrayrecord-1): describes single-dimensional, zero-indexed array records, where `T` can be either a primitive type or a `SerializationRecord`. ```csharp SerializationRecord rootObject = NrbfDecoder.Decode(payload); // payload is a Stream diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs index 063a243078206..60623ac0dbde3 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs @@ -28,12 +28,13 @@ internal enum AllowedRecordTypes : uint ArraySingleString = 1 << SerializationRecordType.ArraySingleString, Nulls = ObjectNull | ObjectNullMultiple256 | ObjectNullMultiple, + Arrays = ArraySingleObject | ArraySinglePrimitive | ArraySingleString | BinaryArray, /// /// Any .NET object (a primitive, a reference type, a reference or single null). /// AnyObject = MemberPrimitiveTyped - | ArraySingleObject | ArraySinglePrimitive | ArraySingleString | BinaryArray + | Arrays | ClassWithId | ClassWithMembersAndTypes | SystemClassWithMembersAndTypes | BinaryObjectString | MemberReference diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs index 237b7b72a2719..c18208668225f 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs @@ -4,6 +4,9 @@ using System.Diagnostics.CodeAnalysis; using System.Reflection.Metadata; using System.Formats.Nrbf.Utils; +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.Serialization; namespace System.Formats.Nrbf; @@ -27,12 +30,6 @@ private protected ArrayRecord(ArrayInfo arrayInfo) /// A buffer of integers that represent the number of elements in every dimension. public abstract ReadOnlySpan Lengths { get; } - /// - /// When overridden in a derived class, gets the total number of all elements in every dimension. - /// - /// A number that represent the total number of all elements in every dimension. - public virtual long FlattenedLength => ArrayInfo.FlattenedLength; - /// /// Gets the rank of the array. /// @@ -118,4 +115,86 @@ private void HandleNext(object value, NextInfo info, int size) } internal abstract (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType(); + + internal static void Populate(List source, Array destination, int[] lengths, AllowedRecordTypes allowedRecordTypes, bool allowNulls) + { + int[] indices = new int[lengths.Length]; + nuint numElementsWritten = 0; // only for debugging; not used in release builds + + foreach (SerializationRecord record in source) + { + object? value = GetActualValue(record, allowedRecordTypes, out int incrementCount); + if (value is not null) + { + // null is a default element for all array of reference types, so we don't call SetValue for nulls. + destination.SetValue(value, indices); + Debug.Assert(incrementCount == 1, "IncrementCount other than 1 is allowed only for null records."); + } + else if (!allowNulls) + { + ThrowHelper.ThrowArrayContainedNulls(); + } + + while (incrementCount > 0) + { + incrementCount--; + numElementsWritten++; + int dimension = indices.Length - 1; + while (dimension >= 0) + { + indices[dimension]++; + if (indices[dimension] < lengths[dimension]) + { + break; + } + indices[dimension] = 0; + dimension--; + } + + if (dimension < 0) + { + break; + } + } + } + + Debug.Assert(numElementsWritten == (uint)source.Count, "We should have traversed the entirety of the source records collection."); + Debug.Assert(numElementsWritten == (ulong)destination.LongLength, "We should have traversed the entirety of the destination array."); + } + + private static object? GetActualValue(SerializationRecord record, AllowedRecordTypes allowedRecordTypes, out int repeatCount) + { + repeatCount = 1; + + if (record is NullsRecord nullsRecord) + { + repeatCount = nullsRecord.NullCount; + return null; + } + else if (record.RecordType == SerializationRecordType.MemberReference) + { + record = ((MemberReferenceRecord)record).GetReferencedRecord(); + } + + if (allowedRecordTypes == AllowedRecordTypes.BinaryObjectString) + { + if (record is not BinaryObjectStringRecord stringRecord) + { + throw new SerializationException(SR.Serialization_InvalidReference); + } + + return stringRecord.Value; + } + else if (allowedRecordTypes == AllowedRecordTypes.Arrays) + { + if (record is not ArrayRecord arrayRecord) + { + throw new SerializationException(SR.Serialization_InvalidReference); + } + + return arrayRecord; + } + + return record; + } } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRectangularPrimitiveRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRectangularPrimitiveRecord.cs new file mode 100644 index 0000000000000..39c66c5f2af0d --- /dev/null +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRectangularPrimitiveRecord.cs @@ -0,0 +1,83 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Formats.Nrbf.Utils; +using System.Linq; +using System.Reflection.Metadata; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading.Tasks; + +namespace System.Formats.Nrbf +{ + internal sealed class ArrayRectangularPrimitiveRecord : ArrayRecord where T : unmanaged + { + private readonly int[] _lengths; + private readonly IReadOnlyList _values; + private TypeName? _typeName; + + internal ArrayRectangularPrimitiveRecord(ArrayInfo arrayInfo, int[] lengths, IReadOnlyList values) : base(arrayInfo) + { + _lengths = lengths; + _values = values; + ValuesToRead = 0; // there is nothing to read anymore + } + + public override ReadOnlySpan Lengths => _lengths; + + public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray; + + public override TypeName TypeName + => _typeName ??= TypeNameHelpers.GetPrimitiveTypeName(TypeNameHelpers.GetPrimitiveType()).MakeArrayTypeName(Rank); + + internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() => throw new InvalidOperationException(); + + private protected override void AddValue(object value) => throw new InvalidOperationException(); + + [RequiresDynamicCode("May call Array.CreateInstance().")] + private protected override Array Deserialize(Type arrayType, bool allowNulls) + { + Array result = +#if NET9_0_OR_GREATER + Array.CreateInstanceFromArrayType(arrayType, _lengths); +#else + Array.CreateInstance(typeof(T), _lengths); +#endif + int[] indices = new int[_lengths.Length]; + nuint numElementsWritten = 0; // only for debugging; not used in release builds + + for (int i = 0; i < _values.Count; i++) + { + result.SetValue(_values[i], indices); + numElementsWritten++; + + int dimension = indices.Length - 1; + while (dimension >= 0) + { + indices[dimension]++; + if (indices[dimension] < Lengths[dimension]) + { + break; + } + indices[dimension] = 0; + dimension--; + } + + if (dimension < 0) + { + break; + } + } + + Debug.Assert(numElementsWritten == (uint)_values.Count, "We should have traversed the entirety of the source values collection."); + Debug.Assert(numElementsWritten == (ulong)result.LongLength, "We should have traversed the entirety of the destination array."); + + return result; + } + } +} diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs index d0276ff3782e3..2c402af7c35ab 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs @@ -15,9 +15,9 @@ namespace System.Formats.Nrbf; /// /// ArraySingleObject records are described in [MS-NRBF] 2.4.3.2. /// -internal sealed class ArraySingleObjectRecord : SZArrayRecord +internal sealed class ArraySingleObjectRecord : SZArrayRecord { - private ArraySingleObjectRecord(ArrayInfo arrayInfo) : base(arrayInfo) => Records = []; + internal ArraySingleObjectRecord(ArrayInfo arrayInfo) : base(arrayInfo) => Records = []; public override SerializationRecordType RecordType => SerializationRecordType.ArraySingleObject; @@ -27,25 +27,26 @@ public override TypeName TypeName private List Records { get; } /// - public override object?[] GetArray(bool allowNulls = true) - => (object?[])(allowNulls ? _arrayNullsAllowed ??= ToArray(true) : _arrayNullsNotAllowed ??= ToArray(false)); + public override SerializationRecord?[] GetArray(bool allowNulls = true) + => (SerializationRecord?[])(allowNulls ? _arrayNullsAllowed ??= ToArray(true) : _arrayNullsNotAllowed ??= ToArray(false)); - private object?[] ToArray(bool allowNulls) + private SerializationRecord?[] ToArray(bool allowNulls) { - object?[] values = new object?[Length]; + SerializationRecord?[] values = new SerializationRecord?[Length]; int valueIndex = 0; for (int recordIndex = 0; recordIndex < Records.Count; recordIndex++) { SerializationRecord record = Records[recordIndex]; - int nullCount = record is NullsRecord nullsRecord ? nullsRecord.NullCount : 0; - if (nullCount == 0) + if (record is MemberReferenceRecord referenceRecord) { - // "new object[] { }" is special cased because it allows for storing reference to itself. - values[valueIndex++] = record is MemberReferenceRecord referenceRecord && referenceRecord.Reference.Equals(Id) - ? values // a reference to self, and a way to get StackOverflow exception ;) - : record.GetValue(); + record = referenceRecord.GetReferencedRecord(); + } + + if (record is not NullsRecord nullsRecord) + { + values[valueIndex++] = record; continue; } @@ -54,6 +55,7 @@ public override TypeName TypeName ThrowHelper.ThrowArrayContainedNulls(); } + int nullCount = nullsRecord.NullCount; do { values[valueIndex++] = null; diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs index a13507b97015a..a28359d9bb13d 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs @@ -47,6 +47,11 @@ public override T[] GetArray(bool allowNulls = true) internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int count) { + if (count == 0) + { + return Array.Empty(); // Empty arrays are allowed. + } + // For decimals, the input is provided as strings, so we can't compute the required size up-front. if (typeof(T) == typeof(decimal)) { @@ -71,18 +76,15 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c // allocations to be proportional to the amount of data present in the input stream, // which is a sufficient defense against DoS. - long requiredBytes = count; - if (typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan)) - { - // We can't assume DateTime as represented by the runtime is 8 bytes. - // The only assumption we can make is that it's 8 bytes on the wire. - requiredBytes *= 8; - } - else if (typeof(T) != typeof(char)) - { - requiredBytes *= Unsafe.SizeOf(); - } + // We can't assume DateTime as represented by the runtime is 8 bytes. + // The only assumption we can make is that it's 8 bytes on the wire. + int sizeOfT = typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan) + ? 8 + : typeof(T) != typeof(char) + ? Unsafe.SizeOf() + : 1; + long requiredBytes = (long)count * sizeOfT; bool? isDataAvailable = reader.IsDataAvailable(requiredBytes); if (!isDataAvailable.HasValue) { @@ -110,26 +112,49 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c // It's safe to pre-allocate, as we have ensured there is enough bytes in the stream. T[] result = new T[count]; - Span resultAsBytes = MemoryMarshal.AsBytes(result); -#if NET - reader.BaseStream.ReadExactly(resultAsBytes); + + // MemoryMarshal.AsBytes can fail for inputs that need more than int.MaxValue bytes. + // To avoid OverflowException, we read the data in chunks. + int MaxChunkLength = +#if !DEBUG + int.MaxValue / sizeOfT; #else - byte[] bytes = ArrayPool.Shared.Rent((int)Math.Min(requiredBytes, 256_000)); + // Let's use a different value for non-release builds to ensure this code path + // is covered with tests without the need of decoding enormous payloads. + 8_000; +#endif - while (!resultAsBytes.IsEmpty) +#if !NET + byte[] rented = ArrayPool.Shared.Rent((int)Math.Min(requiredBytes, 256_000)); +#endif + + Span valuesToRead = result.AsSpan(); + while (!valuesToRead.IsEmpty) { - int bytesRead = reader.Read(bytes, 0, Math.Min(resultAsBytes.Length, bytes.Length)); - if (bytesRead <= 0) + int sliceSize = Math.Min(valuesToRead.Length, MaxChunkLength); + + Span resultAsBytes = MemoryMarshal.AsBytes(valuesToRead.Slice(0, sliceSize)); +#if NET + reader.BaseStream.ReadExactly(resultAsBytes); +#else + while (!resultAsBytes.IsEmpty) { - ArrayPool.Shared.Return(bytes); - ThrowHelper.ThrowEndOfStreamException(); - } + int bytesRead = reader.Read(rented, 0, Math.Min(resultAsBytes.Length, rented.Length)); + if (bytesRead <= 0) + { + ArrayPool.Shared.Return(rented); + ThrowHelper.ThrowEndOfStreamException(); + } - bytes.AsSpan(0, bytesRead).CopyTo(resultAsBytes); - resultAsBytes = resultAsBytes.Slice(bytesRead); + rented.AsSpan(0, bytesRead).CopyTo(resultAsBytes); + resultAsBytes = resultAsBytes.Slice(bytesRead); + } +#endif + valuesToRead = valuesToRead.Slice(sliceSize); } - ArrayPool.Shared.Return(bytes); +#if !NET + ArrayPool.Shared.Return(rented); #endif if (!BitConverter.IsLittleEndian) @@ -176,7 +201,7 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c { // See DontCastBytesToBooleans test to see what could go wrong. bool[] booleans = (bool[])(object)result; - resultAsBytes = MemoryMarshal.AsBytes(result); + Span resultAsBytes = MemoryMarshal.AsBytes(result); for (int i = 0; i < booleans.Length; i++) { // We don't use the bool array to get the value, as an optimizing compiler or JIT could elide this. diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs index 42b9eadd97bd5..38884aadc5469 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs @@ -17,7 +17,7 @@ namespace System.Formats.Nrbf; /// internal sealed class ArraySingleStringRecord : SZArrayRecord { - private ArraySingleStringRecord(ArrayInfo arrayInfo) : base(arrayInfo) => Records = []; + internal ArraySingleStringRecord(ArrayInfo arrayInfo) : base(arrayInfo) => Records = []; public override SerializationRecordType RecordType => SerializationRecordType.ArraySingleString; diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs deleted file mode 100644 index 41b1f73f03550..0000000000000 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs +++ /dev/null @@ -1,309 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; -using System.IO; -using System.Reflection.Metadata; -using System.Formats.Nrbf.Utils; -using System.Diagnostics; - -namespace System.Formats.Nrbf; - -/// -/// Represents an array other than single dimensional array of primitive types or . -/// -/// -/// BinaryArray records are described in [MS-NRBF] 2.4.3.1. -/// -internal sealed class BinaryArrayRecord : ArrayRecord -{ - private static HashSet PrimitiveTypes { get; } = - [ - typeof(bool), typeof(char), typeof(byte), typeof(sbyte), - typeof(short), typeof(ushort), typeof(int), typeof(uint), - typeof(long), typeof(ulong), typeof(IntPtr), typeof(UIntPtr), - typeof(float), typeof(double), typeof(decimal), typeof(DateTime), - typeof(TimeSpan), typeof(string), typeof(object) - ]; - - private TypeName? _typeName; - private long _totalElementsCount; - - private BinaryArrayRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo) - : base(arrayInfo) - { - MemberTypeInfo = memberTypeInfo; - Values = []; - // We need to parse all elements of the jagged array to obtain total elements count. - _totalElementsCount = -1; - } - - public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray; - - /// - public override ReadOnlySpan Lengths => new int[1] { Length }; - - /// - public override long FlattenedLength - { - get - { - if (_totalElementsCount < 0) - { - _totalElementsCount = IsJagged - ? GetJaggedArrayFlattenedLength(this) - : ArrayInfo.FlattenedLength; - } - - return _totalElementsCount; - } - } - - public override TypeName TypeName - => _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo); - - private int Length => ArrayInfo.GetSZArrayLength(); - - private MemberTypeInfo MemberTypeInfo { get; } - - private List Values { get; } - - [RequiresDynamicCode("May call Array.CreateInstance() and Type.MakeArrayType().")] - private protected override Array Deserialize(Type arrayType, bool allowNulls) - { - // We can not deserialize non-primitive types. - // This method returns arrays of ClassRecord for arrays of complex types. - Type elementType = MapElementType(arrayType, out bool isClassRecord); - Type actualElementType = arrayType.GetElementType()!; - Array array = -#if NET9_0_OR_GREATER - isClassRecord - ? Array.CreateInstance(elementType, Length) - : Array.CreateInstanceFromArrayType(arrayType, Length); -#else - Array.CreateInstance(elementType, Length); -#endif - - int resultIndex = 0; - foreach (object value in Values) - { - object item = value is MemberReferenceRecord referenceRecord - ? referenceRecord.GetReferencedRecord() - : value; - - if (item is not SerializationRecord record) - { - array.SetValue(item, resultIndex++); - continue; - } - - switch (record.RecordType) - { - case SerializationRecordType.BinaryArray: - case SerializationRecordType.ArraySinglePrimitive: - case SerializationRecordType.ArraySingleObject: - case SerializationRecordType.ArraySingleString: - - // Recursion depth is bounded by the depth of arrayType, which is - // a trustworthy Type instance. Don't need to worry about stack overflow. - - ArrayRecord nestedArrayRecord = (ArrayRecord)record; - Array nestedArray = nestedArrayRecord.GetArray(actualElementType, allowNulls); - array.SetValue(nestedArray, resultIndex++); - break; - case SerializationRecordType.ObjectNull: - case SerializationRecordType.ObjectNullMultiple256: - case SerializationRecordType.ObjectNullMultiple: - if (!allowNulls) - { - ThrowHelper.ThrowArrayContainedNulls(); - } - - int nullCount = ((NullsRecord)item).NullCount; - Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount."); - do - { - array.SetValue(null, resultIndex++); - nullCount--; - } - while (nullCount > 0); - break; - default: - array.SetValue(record.GetValue(), resultIndex++); - break; - } - } - - Debug.Assert(resultIndex == array.Length, "We should have traversed the entirety of the newly created array."); - - return array; - } - - internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, PayloadOptions options) - { - SerializationRecordId objectId = SerializationRecordId.Decode(reader); - BinaryArrayType arrayType = reader.ReadArrayType(); - int rank = reader.ReadInt32(); - - bool isRectangular = arrayType is BinaryArrayType.Rectangular; - - // It is an arbitrary limit in the current CoreCLR type loader. - // Don't change this value without reviewing the loop a few lines below. - const int MaxSupportedArrayRank = 32; - - if (rank < 1 || rank > MaxSupportedArrayRank - || (rank != 1 && !isRectangular) - || (rank == 1 && isRectangular)) - { - ThrowHelper.ThrowInvalidValue(rank); - } - - int[] lengths = new int[rank]; // adversary-controlled, but acceptable since upper limit of 32 - long totalElementCount = 1; // to avoid integer overflow during the multiplication below - for (int i = 0; i < lengths.Length; i++) - { - lengths[i] = ArrayInfo.ParseValidArrayLength(reader); - totalElementCount *= lengths[i]; - - // n.b. This forbids "new T[Array.MaxLength, Array.MaxLength, Array.MaxLength, ..., 0]" - // but allows "new T[0, Array.MaxLength, Array.MaxLength, Array.MaxLength, ...]". But - // that's the same behavior that newarr and Array.CreateInstance exhibit, so at least - // we're consistent. - - if (totalElementCount > ArrayInfo.MaxArrayLength) - { - ThrowHelper.ThrowInvalidValue(lengths[i]); // max array size exceeded - } - } - - // Per BinaryReaderExtensions.ReadArrayType, we do not support nonzero offsets, so - // we don't need to read the NRBF stream 'LowerBounds' field here. - - MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, 1, options, recordMap); - ArrayInfo arrayInfo = new(objectId, totalElementCount, arrayType, rank); - - if (isRectangular) - { - return RectangularArrayRecord.Create(reader, arrayInfo, memberTypeInfo, lengths); - } - - return memberTypeInfo.ShouldBeRepresentedAsArrayOfClassRecords() - ? new ArrayOfClassesRecord(arrayInfo, memberTypeInfo) - : new BinaryArrayRecord(arrayInfo, memberTypeInfo); - } - - private static long GetJaggedArrayFlattenedLength(BinaryArrayRecord jaggedArrayRecord) - { - long result = 0; - Queue? jaggedArrayRecords = null; - - do - { - if (jaggedArrayRecords is not null) - { - jaggedArrayRecord = jaggedArrayRecords.Dequeue(); - } - - Debug.Assert(jaggedArrayRecord.IsJagged); - - // In theory somebody could create a payload that would represent - // a very nested array with total elements count > long.MaxValue. - // That is why this method is using checked arithmetic. - result = checked(result + jaggedArrayRecord.Length); // count the arrays themselves - - foreach (object value in jaggedArrayRecord.Values) - { - if (value is not SerializationRecord record) - { - continue; - } - - if (record.RecordType == SerializationRecordType.MemberReference) - { - record = ((MemberReferenceRecord)record).GetReferencedRecord(); - } - - switch (record.RecordType) - { - case SerializationRecordType.ArraySinglePrimitive: - case SerializationRecordType.ArraySingleObject: - case SerializationRecordType.ArraySingleString: - case SerializationRecordType.BinaryArray: - ArrayRecord nestedArrayRecord = (ArrayRecord)record; - if (nestedArrayRecord.IsJagged) - { - (jaggedArrayRecords ??= new()).Enqueue((BinaryArrayRecord)nestedArrayRecord); - } - else - { - // Don't call nestedArrayRecord.FlattenedLength to avoid any potential recursion, - // just call nestedArrayRecord.ArrayInfo.FlattenedLength that returns pre-computed value. - result = checked(result + nestedArrayRecord.ArrayInfo.FlattenedLength); - } - break; - default: - break; - } - } - } - while (jaggedArrayRecords is not null && jaggedArrayRecords.Count > 0); - - return result; - } - - private protected override void AddValue(object value) => Values.Add(value); - - internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() - { - (AllowedRecordTypes allowed, PrimitiveType primitiveType) = MemberTypeInfo.GetNextAllowedRecordType(0); - - if (allowed != AllowedRecordTypes.None) - { - // It's an array, it can also contain multiple nulls - return (allowed | AllowedRecordTypes.Nulls, primitiveType); - } - - return (allowed, primitiveType); - } - - /// - /// Complex types must not be instantiated, but represented as ClassRecord. - /// For arrays of primitive types like int, string and object this method returns the element type. - /// For array of complex types, it returns ClassRecord. - /// It takes arrays of arrays into account: - /// - int[][] => int[] - /// - MyClass[][][] => ClassRecord[][] - /// - [RequiresDynamicCode("May call Type.MakeArrayType().")] - private static Type MapElementType(Type arrayType, out bool isClassRecord) - { - Type elementType = arrayType; - int arrayNestingDepth = 0; - - // Loop iteration counts are bound by the nesting depth of arrayType, - // which is a trustworthy input. No DoS concerns. - - while (elementType.IsArray) - { - elementType = elementType.GetElementType()!; - arrayNestingDepth++; - } - - if (PrimitiveTypes.Contains(elementType) || (Nullable.GetUnderlyingType(elementType) is Type nullable && PrimitiveTypes.Contains(nullable))) - { - isClassRecord = false; - return arrayNestingDepth == 1 ? elementType : arrayType.GetElementType()!; - } - - // Complex types are never instantiated, but represented as ClassRecord - isClassRecord = true; - Type complexType = typeof(ClassRecord); - for (int i = 1; i < arrayNestingDepth; i++) - { - complexType = complexType.MakeArrayType(); - } - - return complexType; - } -} diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithIdRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithIdRecord.cs index c643d3ce8c846..2762be167b111 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithIdRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithIdRecord.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Formats.Nrbf.Utils; using System.IO; using System.Runtime.Serialization; @@ -27,16 +28,57 @@ private ClassWithIdRecord(SerializationRecordId id, ClassRecord metadataClass) : internal ClassRecord MetadataClass { get; } - internal static ClassWithIdRecord Decode( + internal static SerializationRecord Decode( BinaryReader reader, RecordMap recordMap) { SerializationRecordId id = SerializationRecordId.Decode(reader); SerializationRecordId metadataId = SerializationRecordId.Decode(reader); - ClassRecord referencedRecord = recordMap.GetRecord(metadataId); + SerializationRecord metadataRecord = recordMap.GetRecord(metadataId); + if (metadataRecord is ClassRecord referencedClassRecord) + { + return new ClassWithIdRecord(id, referencedClassRecord); + } + else if (metadataRecord is PrimitiveTypeRecord primitiveTypeRecord + && !primitiveTypeRecord.Id.Equals(default) // such records always have Id provided + && metadataRecord is not BinaryObjectStringRecord) // it does not apply to BinaryObjectStringRecord + { + // BinaryFormatter represents primitive types as MemberPrimitiveTypedRecord + // only for arrays of objects. For other arrays, like arrays of some abstraction + // (example: new IComparable[] { int.MaxValue }), it uses SystemClassWithMembersAndTypes. + // SystemClassWithMembersAndTypes.Decode handles that by returning MemberPrimitiveTypedRecord. + // But arrays of such types typically have only one SystemClassWithMembersAndTypes record with + // all the member information and multiple ClassWithIdRecord records that just reuse that information. + return primitiveTypeRecord switch + { + MemberPrimitiveTypedRecord => Create(reader.ReadBoolean()), + MemberPrimitiveTypedRecord => Create(reader.ReadByte()), + MemberPrimitiveTypedRecord => Create(reader.ReadSByte()), + MemberPrimitiveTypedRecord => Create(reader.ParseChar()), + MemberPrimitiveTypedRecord => Create(reader.ReadInt16()), + MemberPrimitiveTypedRecord => Create(reader.ReadUInt16()), + MemberPrimitiveTypedRecord => Create(reader.ReadInt32()), + MemberPrimitiveTypedRecord => Create(reader.ReadUInt32()), + MemberPrimitiveTypedRecord => Create(reader.ReadInt64()), + MemberPrimitiveTypedRecord => Create(reader.ReadUInt64()), + MemberPrimitiveTypedRecord => Create(reader.ReadSingle()), + MemberPrimitiveTypedRecord => Create(reader.ReadDouble()), + MemberPrimitiveTypedRecord => Create(new IntPtr(reader.ReadInt64())), + MemberPrimitiveTypedRecord => Create(new UIntPtr(reader.ReadUInt64())), + MemberPrimitiveTypedRecord => Create(new TimeSpan(reader.ReadInt64())), + MemberPrimitiveTypedRecord => SystemClassWithMembersAndTypesRecord.DecodeDateTime(reader, id), + MemberPrimitiveTypedRecord => SystemClassWithMembersAndTypesRecord.DecodeDecimal(reader, id), + _ => throw new InvalidOperationException() + }; + } + else + { + throw new SerializationException(SR.Serialization_InvalidReference); + } - return new ClassWithIdRecord(id, referencedRecord); + SerializationRecord Create(T value) where T : unmanaged + => new MemberPrimitiveTypedRecord(value, id); } internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetNextAllowedRecordType() diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/JaggedArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/JaggedArrayRecord.cs new file mode 100644 index 0000000000000..6ac97ef40675d --- /dev/null +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/JaggedArrayRecord.cs @@ -0,0 +1,65 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Reflection.Metadata; +using System.Formats.Nrbf.Utils; +using System.Diagnostics; +using System.Runtime.Serialization; + +namespace System.Formats.Nrbf; + +/// +/// Represents an array of arrays. +/// +/// +/// BinaryArray records are described in [MS-NRBF] 2.4.3.1. +/// +internal sealed class JaggedArrayRecord : ArrayRecord +{ + private readonly MemberTypeInfo _memberTypeInfo; + private readonly int[] _lengths; + private readonly List _records; + private readonly AllowedRecordTypes _allowedRecordTypes; + private TypeName? _typeName; + + internal JaggedArrayRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo, int[] lengths) + : base(arrayInfo) + { + _memberTypeInfo = memberTypeInfo; + _lengths = lengths; + _records = []; + _allowedRecordTypes = memberTypeInfo.GetNextAllowedRecordType(0).allowed; + + Debug.Assert(TypeName.GetElementType().IsArray, "Jagged arrays are required."); + } + + public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray; + + public override ReadOnlySpan Lengths => _lengths; + + public override TypeName TypeName => _typeName ??= _memberTypeInfo.GetArrayTypeName(ArrayInfo); + + [RequiresDynamicCode("May call Array.CreateInstance().")] + private protected override Array Deserialize(Type arrayType, bool allowNulls) + { + // This method returns arrays of ArrayRecords. + Array array = _lengths.Length switch + { + 1 => new ArrayRecord[_lengths[0]], + 2 => new ArrayRecord[_lengths[0], _lengths[1]], + _ => Array.CreateInstance(typeof(ArrayRecord), _lengths) + }; + + Populate(_records, array, _lengths, AllowedRecordTypes.Arrays, allowNulls); + + return array; + } + + private protected override void AddValue(object value) => _records.Add((SerializationRecord)value); + + internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() + => (_allowedRecordTypes, default); +} diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs index 57e47a02eec68..84c1073b0ef67 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs @@ -86,10 +86,12 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt // Every class can be a null or a reference and a ClassWithId const AllowedRecordTypes Classes = AllowedRecordTypes.ClassWithId | AllowedRecordTypes.ObjectNull | AllowedRecordTypes.MemberReference - | AllowedRecordTypes.MemberPrimitiveTyped | AllowedRecordTypes.BinaryLibrary; // Classes may be preceded with a library record (System too!) // but System Classes can be expressed only by System records - const AllowedRecordTypes SystemClass = Classes | AllowedRecordTypes.SystemClassWithMembersAndTypes; + const AllowedRecordTypes SystemClass = Classes | AllowedRecordTypes.SystemClassWithMembersAndTypes + // All primitive types can be stored by using one of the interfaces they implement. + // Example: `new IEnumerable[1] { "hello" }` or `new IComparable[1] { int.MaxValue }`. + | AllowedRecordTypes.BinaryObjectString | AllowedRecordTypes.MemberPrimitiveTyped; const AllowedRecordTypes NonSystemClass = Classes | AllowedRecordTypes.ClassWithMembersAndTypes; return binaryType switch @@ -106,43 +108,6 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt }; } - internal bool ShouldBeRepresentedAsArrayOfClassRecords() - { - // This library tries to minimize the number of concepts the users need to learn to use it. - // Since SZArrays are most common, it provides an SZArrayRecord abstraction. - // Every other array (jagged, multi-dimensional etc) is represented using ArrayRecord. - // The goal of this method is to determine whether given array can be represented as SZArrayRecord. - - (BinaryType binaryType, object? additionalInfo) = Infos[0]; - - if (binaryType == BinaryType.Class) - { - // An array of arrays can not be represented as SZArrayRecord. - return !((ClassTypeInfo)additionalInfo!).TypeName.IsArray; - } - else if (binaryType == BinaryType.SystemClass) - { - TypeName typeName = (TypeName)additionalInfo!; - - // An array of arrays can not be represented as SZArrayRecord. - if (typeName.IsArray) - { - return false; - } - - if (!typeName.IsConstructedGenericType) - { - return true; - } - - // Can't use SZArrayRecord for Nullable[] - // as it consists of MemberPrimitiveTypedRecord and NullsRecord - return typeName.GetGenericTypeDefinition().FullName != typeof(Nullable<>).FullName; - } - - return false; - } - internal TypeName GetArrayTypeName(ArrayInfo arrayInfo) { (BinaryType binaryType, object? additionalInfo) = Infos[0]; diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs index 02411578fb7bf..76089c07ee0ce 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs @@ -9,6 +9,7 @@ using System.Text; using System.Runtime.Serialization; using System.Runtime.InteropServices; +using System.Reflection.Metadata; namespace System.Formats.Nrbf; @@ -223,7 +224,7 @@ private static SerializationRecord DecodeNext(BinaryReader reader, RecordMap rec SerializationRecordType.ArraySingleObject => ArraySingleObjectRecord.Decode(reader), SerializationRecordType.ArraySinglePrimitive => DecodeArraySinglePrimitiveRecord(reader), SerializationRecordType.ArraySingleString => ArraySingleStringRecord.Decode(reader), - SerializationRecordType.BinaryArray => BinaryArrayRecord.Decode(reader, recordMap, options), + SerializationRecordType.BinaryArray => DecodeBinaryArrayRecord(reader, recordMap, options), SerializationRecordType.BinaryLibrary => BinaryLibraryRecord.Decode(reader, options), SerializationRecordType.BinaryObjectString => BinaryObjectStringRecord.Decode(reader), SerializationRecordType.ClassWithId => ClassWithIdRecord.Decode(reader, recordMap), @@ -269,11 +270,16 @@ private static SerializationRecord DecodeMemberPrimitiveTypedRecord(BinaryReader }; } - private static SerializationRecord DecodeArraySinglePrimitiveRecord(BinaryReader reader) + private static ArrayRecord DecodeArraySinglePrimitiveRecord(BinaryReader reader) { ArrayInfo info = ArrayInfo.Decode(reader); PrimitiveType primitiveType = reader.ReadPrimitiveType(); + return DecodeArraySinglePrimitiveRecord(reader, info, primitiveType); + } + + private static ArrayRecord DecodeArraySinglePrimitiveRecord(BinaryReader reader, ArrayInfo info, PrimitiveType primitiveType) + { return primitiveType switch { PrimitiveType.Boolean => Decode(info, reader), @@ -294,10 +300,171 @@ private static SerializationRecord DecodeArraySinglePrimitiveRecord(BinaryReader _ => throw new InvalidOperationException() }; - static SerializationRecord Decode(ArrayInfo info, BinaryReader reader) where T : unmanaged + static ArrayRecord Decode(ArrayInfo info, BinaryReader reader) where T : unmanaged => new ArraySinglePrimitiveRecord(info, ArraySinglePrimitiveRecord.DecodePrimitiveTypes(reader, info.GetSZArrayLength())); } + private static ArrayRecord DecodeArrayRectangularPrimitiveRecord(PrimitiveType primitiveType, ArrayInfo info, int[] lengths, BinaryReader reader) + { + return primitiveType switch + { + PrimitiveType.Boolean => Decode(info, lengths, reader), + PrimitiveType.Byte => Decode(info, lengths, reader), + PrimitiveType.SByte => Decode(info, lengths, reader), + PrimitiveType.Char => Decode(info, lengths, reader), + PrimitiveType.Int16 => Decode(info, lengths, reader), + PrimitiveType.UInt16 => Decode(info, lengths, reader), + PrimitiveType.Int32 => Decode(info, lengths, reader), + PrimitiveType.UInt32 => Decode(info, lengths, reader), + PrimitiveType.Int64 => Decode(info, lengths, reader), + PrimitiveType.UInt64 => Decode(info, lengths, reader), + PrimitiveType.Single => Decode(info, lengths, reader), + PrimitiveType.Double => Decode(info, lengths, reader), + PrimitiveType.Decimal => Decode(info, lengths, reader), + PrimitiveType.DateTime => Decode(info, lengths, reader), + PrimitiveType.TimeSpan => Decode(info, lengths, reader), + _ => throw new InvalidOperationException() + }; + + static ArrayRecord Decode(ArrayInfo info, int[] lengths, BinaryReader reader) where T : unmanaged + { + // We limit the length of multi-dimensional array to max length of SZArray. + // Because of that, it's possible to re-use the same decoding logic for both MD and SZ arrays. + IReadOnlyList values = ArraySinglePrimitiveRecord.DecodePrimitiveTypes(reader, info.GetSZArrayLength()); + return new ArrayRectangularPrimitiveRecord(info, lengths, values); + } + } + + private static ArrayRecord DecodeBinaryArrayRecord(BinaryReader reader, RecordMap recordMap, PayloadOptions options) + { + SerializationRecordId objectId = SerializationRecordId.Decode(reader); + BinaryArrayType arrayType = reader.ReadArrayType(); + int rank = reader.ReadInt32(); + + bool isRectangular = arrayType is BinaryArrayType.Rectangular; + + // It is an arbitrary limit in the current CoreCLR type loader. + // Don't change this value without reviewing the loop a few lines below. + const int MaxSupportedArrayRank = 32; + + if (rank < 1 || rank > MaxSupportedArrayRank + || (rank != 1 && !isRectangular) + || (rank == 1 && isRectangular)) + { + ThrowHelper.ThrowInvalidValue(rank); + } + + int[] lengths = new int[rank]; // adversary-controlled, but acceptable since upper limit of 32 + long totalElementCount = 1; // to avoid integer overflow during the multiplication below + for (int i = 0; i < lengths.Length; i++) + { + lengths[i] = ArrayInfo.ParseValidArrayLength(reader); + totalElementCount *= lengths[i]; + + // n.b. This forbids "new T[Array.MaxLength, Array.MaxLength, Array.MaxLength, ..., 0]" + // but allows "new T[0, Array.MaxLength, Array.MaxLength, Array.MaxLength, ...]". But + // that's the same behavior that newarr and Array.CreateInstance exhibit, so at least + // we're consistent. + + if (totalElementCount > ArrayInfo.MaxArrayLength) + { + ThrowHelper.ThrowInvalidValue(lengths[i]); // max array size exceeded + } + } + + // Per BinaryReaderExtensions.ReadArrayType, we do not support nonzero offsets, so + // we don't need to read the NRBF stream 'LowerBounds' field here. + + MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, 1, options, recordMap); + ArrayInfo arrayInfo = new(objectId, totalElementCount, arrayType, rank); + + (BinaryType binaryType, object? additionalInfo) = memberTypeInfo.Infos[0]; + if (arrayType == BinaryArrayType.Rectangular) + { + if (binaryType == BinaryType.Primitive) + { + return DecodeArrayRectangularPrimitiveRecord((PrimitiveType)additionalInfo!, arrayInfo, lengths, reader); + } + else if (binaryType == BinaryType.String) + { + return new RectangularArrayRecord(typeof(string), arrayInfo, memberTypeInfo, lengths); + } + else if (binaryType == BinaryType.Object) + { + return new RectangularArrayRecord(typeof(SerializationRecord), arrayInfo, memberTypeInfo, lengths); + } + else if (binaryType is BinaryType.SystemClass or BinaryType.Class) + { + TypeName typeName = binaryType == BinaryType.SystemClass ? (TypeName)additionalInfo! : ((ClassTypeInfo)additionalInfo!).TypeName; + // BinaryArrayType.Rectangular can be also a jagged array. + return typeName.IsArray + ? new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths) + : new RectangularArrayRecord(typeof(SerializationRecord), arrayInfo, memberTypeInfo, lengths); + } + else if (binaryType is BinaryType.PrimitiveArray or BinaryType.StringArray or BinaryType.ObjectArray) + { + // A multi-dimensional array of single dimensional arrays. Example: int[][,] + return new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths); + } + } + else if (arrayType == BinaryArrayType.Single) + { + if (binaryType is BinaryType.SystemClass or BinaryType.Class) + { + TypeName typeName = binaryType == BinaryType.SystemClass ? (TypeName)additionalInfo! : ((ClassTypeInfo)additionalInfo!).TypeName; + // BinaryArrayType.Single that describes an array is just a jagged array. + return typeName.IsArray + ? new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths) + : new SZArrayOfRecords(arrayInfo, memberTypeInfo); + } + else if (binaryType == BinaryType.String) + { + // BinaryArrayRecord can represent string[] (but BF always uses ArraySingleStringRecord for that). + return new ArraySingleStringRecord(arrayInfo); + } + else if (binaryType == BinaryType.Primitive) + { + // BinaryArrayRecord can represent Primitive[] (but BF always uses ArraySinglePrimitiveRecord for that). + return DecodeArraySinglePrimitiveRecord(reader, arrayInfo, (PrimitiveType)additionalInfo!); + } + else if (binaryType == BinaryType.Object) + { + // BinaryArrayRecord can represent object[] (but BF always uses ArraySingleObjectRecord for that). + return new ArraySingleObjectRecord(arrayInfo); + } + else if (binaryType is BinaryType.ObjectArray or BinaryType.StringArray or BinaryType.PrimitiveArray) + { + // It's a Jagged array that does not use BinaryArrayType.Jagged. + return new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths); + } + } + else if (arrayType == BinaryArrayType.Jagged) + { + if (binaryType is BinaryType.ObjectArray or BinaryType.StringArray or BinaryType.PrimitiveArray) + { + // It's a Jagged array that does not use BinaryArrayType.Jagged. + return new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths); + } + else if (binaryType == BinaryType.SystemClass && ((TypeName)additionalInfo!).IsArray) + { + // BinaryType.SystemClass can be used to describe arrays of system class records. + // Example: new Exception[] { new Exception("test") }; + return new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths); + } + else if (binaryType == BinaryType.Class && ((ClassTypeInfo)additionalInfo!).TypeName.IsArray) + { + // BinaryType.Class can be used to describe arrays of class records. + // Example: new MyCustomType[] { new MyCustomType(0) }; + return new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths); + } + + // It's invalid, the element type must be an array. + throw new SerializationException(SR.Format(SR.Serialization_InvalidValue, binaryType)); + } + + throw new InvalidOperationException(); + } + /// /// This method is responsible for pushing only the FIRST read info /// of the NESTED record into the . diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs index eafcbf93249c5..dd5862c7b2b86 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs @@ -61,18 +61,7 @@ internal void Add(SerializationRecord record) } } - internal SerializationRecord GetRootRecord(SerializedStreamHeaderRecord header) - { - SerializationRecord rootRecord = GetRecord(header.RootId); - - if (rootRecord is SystemClassWithMembersAndTypesRecord systemClass) - { - // update the record map, so it's visible also to those who access it via Id - _map[header.RootId] = rootRecord = systemClass.TryToMapToUserFriendly(); - } - - return rootRecord; - } + internal SerializationRecord GetRootRecord(SerializedStreamHeaderRecord header) => GetRecord(header.RootId); internal SerializationRecord GetRecord(SerializationRecordId recordId) => _map.TryGetValue(recordId, out SerializationRecord? record) diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs index f64dde36163d6..f10bc3f51efda 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs @@ -9,24 +9,29 @@ using System.Runtime.InteropServices; using System.Formats.Nrbf.Utils; using System.Diagnostics; +using System.Runtime.Serialization; namespace System.Formats.Nrbf; internal sealed class RectangularArrayRecord : ArrayRecord { + private readonly Type _elementType; private readonly int[] _lengths; - private readonly List _values; + private readonly List _records; + private readonly AllowedRecordTypes _allowedRecordTypes; + private readonly MemberTypeInfo _memberTypeInfo; private TypeName? _typeName; - private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo, - MemberTypeInfo memberTypeInfo, int[] lengths, bool canPreAllocate) : base(arrayInfo) + internal RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo, int[] lengths) : base(arrayInfo) { - ElementType = elementType; - MemberTypeInfo = memberTypeInfo; + _elementType = elementType; _lengths = lengths; + _memberTypeInfo = memberTypeInfo; + _records = new List(Math.Min(4, arrayInfo.GetSZArrayLength())); + _allowedRecordTypes = memberTypeInfo.GetNextAllowedRecordType(0).allowed; - // ArrayInfo.GetSZArrayLength ensures to return a value <= Array.MaxLength - _values = new List(canPreAllocate ? arrayInfo.GetSZArrayLength() : Math.Min(4, arrayInfo.GetSZArrayLength())); + Debug.Assert(elementType == typeof(string) || elementType == typeof(SerializationRecord)); + Debug.Assert(!TypeName.GetElementType().IsArray, "Use JaggedArrayRecord instead."); } public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray; @@ -34,230 +39,32 @@ private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo, public override ReadOnlySpan Lengths => _lengths.AsSpan(); public override TypeName TypeName - => _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo); - - private Type ElementType { get; } - - private MemberTypeInfo MemberTypeInfo { get; } + => _typeName ??= _memberTypeInfo.GetArrayTypeName(ArrayInfo); [RequiresDynamicCode("May call Array.CreateInstance() and Type.MakeArrayType().")] private protected override Array Deserialize(Type arrayType, bool allowNulls) { - // We can not deserialize non-primitive types. - // This method returns arrays of ClassRecord for arrays of complex types. + bool storeStrings = _elementType == typeof(string); + + // We can not deserialize non-string types. + // This method returns arrays of SerializationRecord for arrays of complex types. Array result = #if NET9_0_OR_GREATER - ElementType == typeof(ClassRecord) - ? Array.CreateInstance(ElementType, _lengths) - : Array.CreateInstanceFromArrayType(arrayType, _lengths); + storeStrings + ? Array.CreateInstanceFromArrayType(arrayType, _lengths) + : Array.CreateInstance(_elementType, _lengths); #else - Array.CreateInstance(ElementType, _lengths); + Array.CreateInstance(_elementType, _lengths); #endif -#if !NET8_0_OR_GREATER - int[] indices = new int[_lengths.Length]; - nuint numElementsWritten = 0; // only for debugging; not used in release builds - - foreach (object value in _values) - { - result.SetValue(GetActualValue(value), indices); - numElementsWritten++; - - int dimension = indices.Length - 1; - while (dimension >= 0) - { - indices[dimension]++; - if (indices[dimension] < Lengths[dimension]) - { - break; - } - indices[dimension] = 0; - dimension--; - } - - if (dimension < 0) - { - break; - } - } - - Debug.Assert(numElementsWritten == (uint)_values.Count, "We should have traversed the entirety of the source values collection."); - Debug.Assert(numElementsWritten == (ulong)result.LongLength, "We should have traversed the entirety of the destination array."); + AllowedRecordTypes allowedRecordTypes = storeStrings ? AllowedRecordTypes.BinaryObjectString : AllowedRecordTypes.AnyObject; + Populate(_records, result, _lengths, allowedRecordTypes, allowNulls); return result; -#else - // Idea from Array.CoreCLR that maps an array of int indices into - // an internal flat index. - if (ElementType.IsValueType) - { - if (ElementType == typeof(bool)) CopyTo(_values, result); - else if (ElementType == typeof(byte)) CopyTo(_values, result); - else if (ElementType == typeof(sbyte)) CopyTo(_values, result); - else if (ElementType == typeof(short)) CopyTo(_values, result); - else if (ElementType == typeof(ushort)) CopyTo(_values, result); - else if (ElementType == typeof(char)) CopyTo(_values, result); - else if (ElementType == typeof(int)) CopyTo(_values, result); - else if (ElementType == typeof(float)) CopyTo(_values, result); - else if (ElementType == typeof(long)) CopyTo(_values, result); - else if (ElementType == typeof(ulong)) CopyTo(_values, result); - else if (ElementType == typeof(double)) CopyTo(_values, result); - else if (ElementType == typeof(TimeSpan)) CopyTo(_values, result); - else if (ElementType == typeof(DateTime)) CopyTo(_values, result); - else if (ElementType == typeof(decimal)) CopyTo(_values, result); - else throw new InvalidOperationException(); - } - else - { - CopyTo(_values, result); - } - - return result; - - static void CopyTo(List list, Array array) - { - ref byte arrayDataRef = ref MemoryMarshal.GetArrayDataReference(array); - ref T firstElementRef = ref Unsafe.As(ref arrayDataRef); - nuint flattenedIndex = 0; - foreach (object value in list) - { - ref T targetElement = ref Unsafe.Add(ref firstElementRef, flattenedIndex); - targetElement = (T)GetActualValue(value)!; - flattenedIndex++; - } - - Debug.Assert(flattenedIndex == (ulong)array.LongLength, "We should have traversed the entirety of the array."); - } -#endif } - private protected override void AddValue(object value) => _values.Add(value); + private protected override void AddValue(object value) => _records.Add((SerializationRecord)value); internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() - { - (AllowedRecordTypes allowed, PrimitiveType primitiveType) = MemberTypeInfo.GetNextAllowedRecordType(0); - - if (allowed != AllowedRecordTypes.None) - { - // It's an array, it can also contain multiple nulls - return (allowed | AllowedRecordTypes.Nulls, primitiveType); - } - - return (allowed, primitiveType); - } - - internal static RectangularArrayRecord Create(BinaryReader reader, ArrayInfo arrayInfo, - MemberTypeInfo memberTypeInfo, int[] lengths) - { - BinaryType binaryType = memberTypeInfo.Infos[0].BinaryType; - Type elementType = binaryType switch - { - BinaryType.Primitive => MapPrimitive((PrimitiveType)memberTypeInfo.Infos[0].AdditionalInfo!), - BinaryType.PrimitiveArray => MapPrimitiveArray((PrimitiveType)memberTypeInfo.Infos[0].AdditionalInfo!), - BinaryType.String => typeof(string), - BinaryType.Object => typeof(object), - _ => typeof(ClassRecord) - }; - - bool canPreAllocate = false; - if (binaryType == BinaryType.Primitive) - { - int sizeOfSingleValue = (PrimitiveType)memberTypeInfo.Infos[0].AdditionalInfo! switch - { - PrimitiveType.Boolean => sizeof(bool), - PrimitiveType.Byte => sizeof(byte), - PrimitiveType.SByte => sizeof(sbyte), - PrimitiveType.Char => sizeof(byte), // it's UTF8 (see comment below) - PrimitiveType.Int16 => sizeof(short), - PrimitiveType.UInt16 => sizeof(ushort), - PrimitiveType.Int32 => sizeof(int), - PrimitiveType.UInt32 => sizeof(uint), - PrimitiveType.Single => sizeof(float), - PrimitiveType.Int64 => sizeof(long), - PrimitiveType.UInt64 => sizeof(ulong), - PrimitiveType.Double => sizeof(double), - PrimitiveType.TimeSpan => sizeof(ulong), - PrimitiveType.DateTime => sizeof(ulong), - PrimitiveType.Decimal => -1, // represented as variable-length string - _ => throw new InvalidOperationException() - }; - - if (sizeOfSingleValue > 0) - { - // NRBF encodes rectangular char[,,,...] by converting each standalone UTF-16 code point into - // its UTF-8 encoding. This means that surrogate code points (including adjacent surrogate - // pairs) occurring within a char[,,,...] cannot be encoded by NRBF. BinaryReader will detect - // that they're ill-formed and reject them on read. - // - // Per the comment in ArraySinglePrimitiveRecord.DecodePrimitiveTypes, we'll assume best-case - // encoding where 1 UTF-16 char encodes as a single UTF-8 byte, even though this might lead - // to encountering an EOF if we realize later that we actually need to read more bytes in - // order to fully populate the char[,,,...] array. Any such allocation is still linearly - // proportional to the length of the incoming payload, so it's not a DoS vector. - // The multiplication below is guaranteed not to overflow because FlattenedLength is bounded - // to <= Array.MaxLength (see BinaryArrayRecord.Decode) and sizeOfSingleValue is at most 8. - Debug.Assert(arrayInfo.FlattenedLength >= 0 && arrayInfo.FlattenedLength <= long.MaxValue / sizeOfSingleValue); - - long size = arrayInfo.FlattenedLength * sizeOfSingleValue; - bool? isDataAvailable = reader.IsDataAvailable(size); - if (isDataAvailable.HasValue) - { - if (!isDataAvailable.Value) - { - ThrowHelper.ThrowEndOfStreamException(); - } - - canPreAllocate = true; - } - } - } - - return new RectangularArrayRecord(elementType, arrayInfo, memberTypeInfo, lengths, canPreAllocate); - } - - private static Type MapPrimitive(PrimitiveType primitiveType) - => primitiveType switch - { - PrimitiveType.Boolean => typeof(bool), - PrimitiveType.Byte => typeof(byte), - PrimitiveType.Char => typeof(char), - PrimitiveType.Decimal => typeof(decimal), - PrimitiveType.Double => typeof(double), - PrimitiveType.Int16 => typeof(short), - PrimitiveType.Int32 => typeof(int), - PrimitiveType.Int64 => typeof(long), - PrimitiveType.SByte => typeof(sbyte), - PrimitiveType.Single => typeof(float), - PrimitiveType.TimeSpan => typeof(TimeSpan), - PrimitiveType.DateTime => typeof(DateTime), - PrimitiveType.UInt16 => typeof(ushort), - PrimitiveType.UInt32 => typeof(uint), - PrimitiveType.UInt64 => typeof(ulong), - _ => throw new InvalidOperationException() - }; - - private static Type MapPrimitiveArray(PrimitiveType primitiveType) - => primitiveType switch - { - PrimitiveType.Boolean => typeof(bool[]), - PrimitiveType.Byte => typeof(byte[]), - PrimitiveType.Char => typeof(char[]), - PrimitiveType.Decimal => typeof(decimal[]), - PrimitiveType.Double => typeof(double[]), - PrimitiveType.Int16 => typeof(short[]), - PrimitiveType.Int32 => typeof(int[]), - PrimitiveType.Int64 => typeof(long[]), - PrimitiveType.SByte => typeof(sbyte[]), - PrimitiveType.Single => typeof(float[]), - PrimitiveType.TimeSpan => typeof(TimeSpan[]), - PrimitiveType.DateTime => typeof(DateTime[]), - PrimitiveType.UInt16 => typeof(ushort[]), - PrimitiveType.UInt32 => typeof(uint[]), - PrimitiveType.UInt64 => typeof(ulong[]), - _ => throw new InvalidOperationException() - }; - - private static object? GetActualValue(object value) - => value is SerializationRecord serializationRecord - ? serializationRecord.GetValue() - : value; // it must be a primitive type + => (_allowedRecordTypes, default); } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayOfClassesRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SZArrayOfRecords.cs similarity index 69% rename from src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayOfClassesRecord.cs rename to src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SZArrayOfRecords.cs index f345292c693a6..b77a4a57a2a34 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayOfClassesRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SZArrayOfRecords.cs @@ -8,11 +8,15 @@ namespace System.Formats.Nrbf; -internal sealed class ArrayOfClassesRecord : SZArrayRecord +// This library tries to minimize the number of concepts the users need to learn to use it. +// Since SZArrays are most common, it provides an SZArrayRecord abstraction. +// Every other array (jagged, multi-dimensional etc) is represented using ArrayRecord. +// The goal of this class is to let the users use SZArrayRecord abstraction. +internal sealed class SZArrayOfRecords : SZArrayRecord { private TypeName? _typeName; - internal ArrayOfClassesRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo) + internal SZArrayOfRecords(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo) : base(arrayInfo) { MemberTypeInfo = memberTypeInfo; @@ -29,12 +33,12 @@ public override TypeName TypeName => _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo); /// - public override ClassRecord?[] GetArray(bool allowNulls = true) - => (ClassRecord?[])(allowNulls ? _arrayNullsAllowed ??= ToArray(true) : _arrayNullsNotAllowed ??= ToArray(false)); + public override SerializationRecord?[] GetArray(bool allowNulls = true) + => (SerializationRecord?[])(allowNulls ? _arrayNullsAllowed ??= ToArray(true) : _arrayNullsNotAllowed ??= ToArray(false)); - private ClassRecord?[] ToArray(bool allowNulls) + private SerializationRecord?[] ToArray(bool allowNulls) { - ClassRecord?[] result = new ClassRecord?[Length]; + SerializationRecord?[] result = new SerializationRecord?[Length]; int resultIndex = 0; foreach (SerializationRecord record in Records) @@ -43,9 +47,9 @@ public override TypeName TypeName ? referenceRecord.GetReferencedRecord() : record; - if (actual is ClassRecord classRecord) + if (actual is not NullsRecord nullsRecord) { - result[resultIndex++] = classRecord; + result[resultIndex++] = actual; } else { @@ -54,7 +58,7 @@ public override TypeName TypeName ThrowHelper.ThrowArrayContainedNulls(); } - int nullCount = ((NullsRecord)actual).NullCount; + int nullCount = nullsRecord.NullCount; Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount."); do { diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SystemClassWithMembersAndTypesRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SystemClassWithMembersAndTypesRecord.cs index 2a5f7b945ce85..0c5193cd92272 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SystemClassWithMembersAndTypesRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SystemClassWithMembersAndTypesRecord.cs @@ -3,6 +3,7 @@ using System.IO; using System.Formats.Nrbf.Utils; +using System.Reflection.Metadata; namespace System.Formats.Nrbf; @@ -21,84 +22,100 @@ private SystemClassWithMembersAndTypesRecord(ClassInfo classInfo, MemberTypeInfo public override SerializationRecordType RecordType => SerializationRecordType.SystemClassWithMembersAndTypes; - internal static SystemClassWithMembersAndTypesRecord Decode(BinaryReader reader, RecordMap recordMap, PayloadOptions options) + internal static SerializationRecord Decode(BinaryReader reader, RecordMap recordMap, PayloadOptions options) { ClassInfo classInfo = ClassInfo.Decode(reader); MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, classInfo.MemberNames.Count, options, recordMap); // the only difference with ClassWithMembersAndTypesRecord is that we don't read library id here classInfo.LoadTypeName(options); - return new SystemClassWithMembersAndTypesRecord(classInfo, memberTypeInfo); - } + TypeName typeName = classInfo.TypeName; - internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetNextAllowedRecordType() - => MemberTypeInfo.GetNextAllowedRecordType(MemberValues.Count); + // BinaryFormatter represents primitive types as MemberPrimitiveTypedRecord + // only for arrays of objects. For other arrays, like arrays of some abstraction + // (example: new IComparable[] { int.MaxValue }), it uses SystemClassWithMembersAndTypes. + // The same goes for root records that turn out to be primitive types. + // We want to have the behavior unified, so we map such records to + // PrimitiveTypeRecord so the users don't need to learn the BF internals + // to get a single primitive value. + // We need to be as strict as possible, as we don't want to map anything else by accident. + // That is why the code below is VERY defensive. - // For the root records that turn out to be primitive types, we map them to - // PrimitiveTypeRecord so the users don't need to learn the BF internals - // to get a single primitive value! - internal SerializationRecord TryToMapToUserFriendly() - { - if (!TypeName.IsSimple) + if (!classInfo.TypeName.IsSimple || classInfo.MemberNames.Count == 0 || memberTypeInfo.Infos[0].BinaryType != BinaryType.Primitive) { - return this; + return new SystemClassWithMembersAndTypesRecord(classInfo, memberTypeInfo); } - - if (MemberValues.Count == 1) + else if (classInfo.MemberNames.Count == 1) { - if (HasMember("m_value")) + PrimitiveType primitiveType = (PrimitiveType)memberTypeInfo.Infos[0].AdditionalInfo!; + // Get the member name without allocating on the heap. + Collections.Generic.Dictionary.Enumerator structEnumerator = classInfo.MemberNames.GetEnumerator(); + _ = structEnumerator.MoveNext(); + string memberName = structEnumerator.Current.Key; + // Everything needs to match: primitive type, type name name and member name. + return (primitiveType, typeName.FullName, memberName) switch { - return MemberValues[0] switch - { - // there can be a value match, but no TypeName match - bool value when TypeNameMatches(typeof(bool)) => Create(value), - byte value when TypeNameMatches(typeof(byte)) => Create(value), - sbyte value when TypeNameMatches(typeof(sbyte)) => Create(value), - char value when TypeNameMatches(typeof(char)) => Create(value), - short value when TypeNameMatches(typeof(short)) => Create(value), - ushort value when TypeNameMatches(typeof(ushort)) => Create(value), - int value when TypeNameMatches(typeof(int)) => Create(value), - uint value when TypeNameMatches(typeof(uint)) => Create(value), - long value when TypeNameMatches(typeof(long)) => Create(value), - ulong value when TypeNameMatches(typeof(ulong)) => Create(value), - float value when TypeNameMatches(typeof(float)) => Create(value), - double value when TypeNameMatches(typeof(double)) => Create(value), - _ => this - }; - } - else if (HasMember("value")) - { - return MemberValues[0] switch - { - // there can be a value match, but no TypeName match - long value when TypeNameMatches(typeof(IntPtr)) => Create(new IntPtr(value)), - ulong value when TypeNameMatches(typeof(UIntPtr)) => Create(new UIntPtr(value)), - _ => this - }; - } - else if (HasMember("_ticks") && GetRawValue("_ticks") is long ticks && TypeNameMatches(typeof(TimeSpan))) - { - return Create(new TimeSpan(ticks)); - } + (PrimitiveType.Boolean, "System.Boolean", "m_value") => Create(reader.ReadBoolean()), + (PrimitiveType.Byte, "System.Byte", "m_value") => Create(reader.ReadByte()), + (PrimitiveType.SByte, "System.SByte", "m_value") => Create(reader.ReadSByte()), + (PrimitiveType.Char, "System.Char", "m_value") => Create(reader.ParseChar()), + (PrimitiveType.Int16, "System.Int16", "m_value") => Create(reader.ReadInt16()), + (PrimitiveType.UInt16, "System.UInt16", "m_value") => Create(reader.ReadUInt16()), + (PrimitiveType.Int32, "System.Int32", "m_value") => Create(reader.ReadInt32()), + (PrimitiveType.UInt32, "System.UInt32", "m_value") => Create(reader.ReadUInt32()), + (PrimitiveType.Int64, "System.Int64", "m_value") => Create(reader.ReadInt64()), + (PrimitiveType.Int64, "System.IntPtr", "value") => Create(new IntPtr(reader.ReadInt64())), + (PrimitiveType.Int64, "System.TimeSpan", "_ticks") => Create(new TimeSpan(reader.ReadInt64())), + (PrimitiveType.UInt64, "System.UInt64", "m_value") => Create(reader.ReadUInt64()), + (PrimitiveType.UInt64, "System.UIntPtr", "value") => Create(new UIntPtr(reader.ReadUInt64())), + (PrimitiveType.Single, "System.Single", "m_value") => Create(reader.ReadSingle()), + (PrimitiveType.Double, "System.Double", "m_value") => Create(reader.ReadDouble()), + _ => new SystemClassWithMembersAndTypesRecord(classInfo, memberTypeInfo) + }; } - else if (MemberValues.Count == 2 - && HasMember("ticks") && HasMember("dateData") - && GetRawValue("ticks") is long && GetRawValue("dateData") is ulong dateData - && TypeNameMatches(typeof(DateTime))) + else if (classInfo.MemberNames.Count == 2 && typeName.FullName == "System.DateTime" + && HasMember("ticks", 0, PrimitiveType.Int64) + && HasMember("dateData", 1, PrimitiveType.UInt64)) { - return Create(Utils.BinaryReaderExtensions.CreateDateTimeFromData(dateData)); + return DecodeDateTime(reader, classInfo.Id); } - else if (MemberValues.Count == 4 - && HasMember("lo") && HasMember("mid") && HasMember("hi") && HasMember("flags") - && GetRawValue("lo") is int lo && GetRawValue("mid") is int mid - && GetRawValue("hi") is int hi && GetRawValue("flags") is int flags - && TypeNameMatches(typeof(decimal))) + else if (classInfo.MemberNames.Count == 4 && typeName.FullName == "System.Decimal" + && HasMember("flags", 0, PrimitiveType.Int32) + && HasMember("hi", 1, PrimitiveType.Int32) + && HasMember("lo", 2, PrimitiveType.Int32) + && HasMember("mid", 3, PrimitiveType.Int32)) { - return Create(new decimal([lo, mid, hi, flags])); + return DecodeDecimal(reader, classInfo.Id); } - return this; + return new SystemClassWithMembersAndTypesRecord(classInfo, memberTypeInfo); SerializationRecord Create(T value) where T : unmanaged - => new MemberPrimitiveTypedRecord(value, Id); + => new MemberPrimitiveTypedRecord(value, classInfo.Id); + + bool HasMember(string name, int order, PrimitiveType primitiveType) + => classInfo.MemberNames.TryGetValue(name, out int memberOrder) + && memberOrder == order + && ((PrimitiveType)memberTypeInfo.Infos[order].AdditionalInfo!) == primitiveType; + } + + internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetNextAllowedRecordType() + => MemberTypeInfo.GetNextAllowedRecordType(MemberValues.Count); + + internal static MemberPrimitiveTypedRecord DecodeDateTime(BinaryReader reader, SerializationRecordId id) + { + _ = reader.ReadInt64(); // ticks are not used, but they need to be read as they go first in the payload + ulong dateData = reader.ReadUInt64(); + + return new MemberPrimitiveTypedRecord(BinaryReaderExtensions.CreateDateTimeFromData(dateData), id); + } + + internal static MemberPrimitiveTypedRecord DecodeDecimal(BinaryReader reader, SerializationRecordId id) + { + int flags = reader.ReadInt32(); + int hi = reader.ReadInt32(); + int lo = reader.ReadInt32(); + int mid = reader.ReadInt32(); + + return new MemberPrimitiveTypedRecord(new decimal([lo, mid, hi, flags]), id); } } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/BinaryReaderExtensions.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/BinaryReaderExtensions.cs index d5baa09dbd8fc..8bb3ac3a1107b 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/BinaryReaderExtensions.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/BinaryReaderExtensions.cs @@ -33,7 +33,7 @@ internal static BinaryArrayType ReadArrayType(this BinaryReader reader) { // To simplify the behavior and security review of the BinaryArrayRecord type, we // do not support reading non-zero-offset arrays. If this should change in the - // future, the BinaryArrayRecord.Decode method and supporting infrastructure + // future, the NrbfDecoder.DecodeBinaryArrayRecord method and supporting infrastructure // will need re-review. byte arrayType = reader.ReadByte(); diff --git a/src/libraries/System.Formats.Nrbf/tests/ArrayOfSerializationRecordsTests.cs b/src/libraries/System.Formats.Nrbf/tests/ArrayOfSerializationRecordsTests.cs new file mode 100644 index 0000000000000..18e39a5fd68e1 --- /dev/null +++ b/src/libraries/System.Formats.Nrbf/tests/ArrayOfSerializationRecordsTests.cs @@ -0,0 +1,516 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.Serialization; +using Microsoft.DotNet.XUnitExtensions; +using Xunit; + +namespace System.Formats.Nrbf.Tests +{ + public class ArrayOfSerializationRecordsTests : ReadTests + { + public enum ElementType + { + Object, + NonGeneric, + Generic + } + + [Serializable] + public class CustomClassThatImplementsIEnumerable : IEnumerable + { + public int Field; + + public IEnumerator GetEnumerator() => Array.Empty().GetEnumerator(); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsStringRecord_SZ(ElementType elementType) + { + const string Text = "hello"; + Array input = elementType switch + { + ElementType.Object => new object[] { Text }, + ElementType.NonGeneric => new IEnumerable[] { Text }, + ElementType.Generic => new IEnumerable[] { Text }, + _ => throw new InvalidOperationException() + }; + + SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[] output = arrayRecord.GetArray(); + + Verify(input, arrayRecord, output, recordMap); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output.Single(); + Assert.Equal(Text, stringRecord.Value); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsStringRecord_MD(ElementType elementType) + { + const string Text = "hello"; + Array input = elementType switch + { + ElementType.Object => new object[1, 1], + ElementType.NonGeneric => new IEnumerable[1, 1], + ElementType.Generic => new IEnumerable[1, 1], + _ => throw new InvalidOperationException() + }; + input.SetValue(Text, 0, 0); + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[,] output = (SerializationRecord[,])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output[0, 0]; + Assert.Equal(Text, stringRecord.Value); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsStringRecord_Jagged(ElementType elementType) + { + const string Text = "hello"; + Array input = elementType switch + { + ElementType.Object => new object[1][] { [Text] }, + ElementType.NonGeneric => new IEnumerable[1][] { [Text] }, + ElementType.Generic => new IEnumerable[1][] { [Text] }, + _ => throw new InvalidOperationException() + }; + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + + SZArrayRecord contained = (SZArrayRecord)output.Single(); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)contained.GetArray().Single(); + Assert.Equal(Text, stringRecord.Value); + } + + [ConditionalTheory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsMemberPrimitiveTypedRecord_SZ(ElementType elementType) + { + if (elementType != ElementType.Object && !IsPatched) + { + throw new SkipTestException("Current machine has not been patched with the most recent BinaryFormatter fix."); + } + + const int Integer = 123; + Array input = elementType switch + { + ElementType.Object => new object[] { Integer }, + ElementType.NonGeneric => new IComparable[] { Integer }, + ElementType.Generic => new IComparable[] { Integer }, + _ => throw new InvalidOperationException() + }; + + SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[] output = arrayRecord.GetArray(); + + Verify(input, arrayRecord, output, recordMap); + PrimitiveTypeRecord intRecord = (PrimitiveTypeRecord)output.Single(); + Assert.Equal(Integer, intRecord.Value); + } + + [ConditionalTheory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsMemberPrimitiveTypedRecord_MD(ElementType elementType) + { + if (elementType != ElementType.Object && !IsPatched) + { + throw new SkipTestException("Current machine has not been patched with the most recent BinaryFormatter fix."); + } + + const int Integer = 123; + Array input = elementType switch + { + ElementType.Object => new object[1, 1], + ElementType.NonGeneric => new IComparable[1, 1], + ElementType.Generic => new IComparable[1, 1], + _ => throw new InvalidOperationException() + }; + input.SetValue(Integer, 0, 0); + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[,] output = (SerializationRecord[,])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + PrimitiveTypeRecord intRecord = (PrimitiveTypeRecord)output[0, 0]; + Assert.Equal(Integer, intRecord.Value); + } + + [ConditionalTheory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsMemberPrimitiveTypedRecord_Jagged(ElementType elementType) + { + if (elementType != ElementType.Object && !IsPatched) + { + throw new SkipTestException("Current machine has not been patched with the most recent BinaryFormatter fix."); + } + + const int Integer = 123; + Array input = elementType switch + { + ElementType.Object => new object[1][] { [Integer] }, + ElementType.NonGeneric => new IComparable[1][] { [Integer] }, + ElementType.Generic => new IComparable[1][] { [Integer] }, + _ => throw new InvalidOperationException() + }; + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + SZArrayRecord contained = (SZArrayRecord)output.Single(); + PrimitiveTypeRecord intRecord = (PrimitiveTypeRecord)contained.GetArray().Single(); + Assert.Equal(Integer, intRecord.Value); + } + + public static IEnumerable NullAndArrayPermutations() + { + foreach (ElementType elementType in Enum.GetValues(typeof(ElementType))) + { + yield return new object[] { elementType, 1 }; // ObjectNullRecord + yield return new object[] { elementType, 200 }; // ObjectNullMultiple256Record + yield return new object[] { elementType, 1_000 }; // ObjectNullMultipleRecord + } + } + + [Theory] + [MemberData(nameof(NullAndArrayPermutations))] + public void CanReadArrayThatContainsNullRecords_SZ(ElementType elementType, int nullCount) + { + const string Text = "notNull"; + Array input = elementType switch + { + ElementType.Object => new object[nullCount + 1], + ElementType.NonGeneric => new IEnumerable[nullCount + 1], + ElementType.Generic => new IEnumerable[nullCount + 1], + _ => throw new InvalidOperationException() + }; + input.SetValue(Text, nullCount); + + SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord?[] output = arrayRecord.GetArray(); + + Verify(input, arrayRecord, output, recordMap); + Assert.All(output.Take(nullCount), Assert.Null); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output[nullCount]; + Assert.Equal(Text, stringRecord.Value); + } + + [Theory] + [MemberData(nameof(NullAndArrayPermutations))] + public void CanReadArrayThatContainsNullRecords_MD(ElementType elementType, int nullCount) + { + const string Text = "notNull"; + Array input = elementType switch + { + ElementType.Object => new object[1, nullCount + 1], + ElementType.NonGeneric => new IEnumerable[1, nullCount + 1], + ElementType.Generic => new IEnumerable[1, nullCount + 1], + _ => throw new InvalidOperationException() + }; + input.SetValue(Text, 0, nullCount); + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[,] output = (SerializationRecord[,])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + for (int i = 0; i < nullCount; i++) + { + Assert.Null(output[0, i]); + } + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output[0, nullCount]; + Assert.Equal(Text, stringRecord.Value); + } + + [Theory] + [MemberData(nameof(NullAndArrayPermutations))] + public void CanReadArrayThatContainsNullRecords_Jagged(ElementType elementType, int nullCount) + { + const string Text = "notNull"; + Array input = elementType switch + { + ElementType.Object => new object[1][] { new object[nullCount + 1] }, + ElementType.NonGeneric => new IEnumerable[1][] { new IEnumerable[nullCount + 1] }, + ElementType.Generic => new IEnumerable[1][] { new IEnumerable[nullCount + 1] }, + _ => throw new InvalidOperationException() + }; + ((Array)input.GetValue(0)).SetValue(Text, nullCount); + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + SZArrayRecord contained = (SZArrayRecord)output.Single(); + Assert.All(contained.GetArray().Take(nullCount), Assert.Null); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)contained.GetArray()[nullCount]; + Assert.Equal(Text, stringRecord.Value); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsArrayRecord_SZ(ElementType elementType) + { + int[] intArray = [1, 2, 3]; + Array input = elementType switch + { + ElementType.Object => new object[] { intArray }, + ElementType.NonGeneric => new IEnumerable[] { intArray }, + ElementType.Generic => new IEnumerable[] { intArray }, + _ => throw new InvalidOperationException() + }; + + SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[] output = arrayRecord.GetArray(); + + Verify(input, arrayRecord, output, recordMap); + SZArrayRecord intArrayRecord = (SZArrayRecord)output.Single(); + Assert.Equal(intArray, intArrayRecord.GetArray()); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsArrayRecord_MD(ElementType elementType) + { + int[] intArray = [1, 2, 3]; + Array input = elementType switch + { + ElementType.Object => new object[1, 1], + ElementType.NonGeneric => new IEnumerable[1, 1], + ElementType.Generic => new IEnumerable[1, 1], + _ => throw new InvalidOperationException() + }; + input.SetValue(intArray, 0, 0); + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[,] output = (SerializationRecord[,])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + SZArrayRecord intArrayRecord = (SZArrayRecord)output[0, 0]; + Assert.Equal(intArray, intArrayRecord.GetArray()); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsArrayRecord_Jagged(ElementType elementType) + { + int[] intArray = [1, 2, 3]; + Array input = elementType switch + { + ElementType.Object => new object[1][] { [intArray] }, + ElementType.NonGeneric => new IEnumerable[1][] { [intArray] }, + ElementType.Generic => new IEnumerable[1][] { [intArray] }, + _ => throw new InvalidOperationException() + }; + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + SZArrayRecord contained = (SZArrayRecord)output.Single(); + SZArrayRecord intArrayRecord = (SZArrayRecord)contained.GetArray().Single(); + Assert.Equal(intArray, intArrayRecord.GetArray()); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + public void CanReadArrayThatContainsAllRecordTypes_SZ(ElementType elementType) + { + const string Text = "hello"; + int[] intArray = [1, 2, 3]; + CustomClassThatImplementsIEnumerable classThatImplementsIEnumerable = new() { Field = 456 }; + Array input = elementType switch + { + ElementType.Object => new object[] + { + Text, // BinaryObjectStringRecord + intArray, // ArraySinglePrimitiveRecord + classThatImplementsIEnumerable, // ClassWithMembersAndTypesRecord, + null // ObjectNullRecord + }, + ElementType.NonGeneric => new IEnumerable[] { Text, intArray, classThatImplementsIEnumerable, null }, + _ => throw new InvalidOperationException() + }; + + SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[] output = arrayRecord.GetArray(); + + Verify(input, arrayRecord, output, recordMap); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output[0]; + Assert.Equal(Text, stringRecord.Value); + SZArrayRecord intArrayRecord = (SZArrayRecord)output[1]; + Assert.Equal(intArray, intArrayRecord.GetArray()); + ClassRecord classRecord = (ClassRecord)output[2]; + Assert.Equal(classThatImplementsIEnumerable.Field, classRecord.GetInt32(nameof(CustomClassThatImplementsIEnumerable.Field))); + Assert.Null(output[3]); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + public void CanReadArrayThatContainsAllRecordTypes_MD(ElementType elementType) + { + const string Text = "hello"; + int[] intArray = [1, 2, 3]; + CustomClassThatImplementsIEnumerable classThatImplementsIEnumerable = new() { Field = 456 }; + + Array input = elementType switch + { + ElementType.Object => new object[1, 4], + ElementType.NonGeneric => new IEnumerable[1, 4], + _ => throw new InvalidOperationException() + }; + input.SetValue(Text, 0, 0); + input.SetValue(intArray, 0, 1); + input.SetValue(classThatImplementsIEnumerable, 0, 2); + input.SetValue(null, 0, 3); + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[,] output = (SerializationRecord[,])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output[0, 0]; + Assert.Equal(Text, stringRecord.Value); + SZArrayRecord intArrayRecord = (SZArrayRecord)output[0, 1]; + Assert.Equal(intArray, intArrayRecord.GetArray()); + ClassRecord classRecord = (ClassRecord)output[0, 2]; + Assert.Equal(classThatImplementsIEnumerable.Field, classRecord.GetInt32(nameof(CustomClassThatImplementsIEnumerable.Field))); + Assert.Null(output[0, 3]); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + public void CanReadArrayThatContainsAllRecordTypes_Jagged(ElementType elementType) + { + const string Text = "hello"; + int[] intArray = [1, 2, 3]; + CustomClassThatImplementsIEnumerable classThatImplementsIEnumerable = new() { Field = 456 }; + + Array input = elementType switch + { + ElementType.Object => new object[1][] { [Text, intArray, classThatImplementsIEnumerable, null] }, + ElementType.NonGeneric => new IEnumerable[1][] { [Text, intArray, classThatImplementsIEnumerable, null] }, + _ => throw new InvalidOperationException() + }; + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + SZArrayRecord contained = (SZArrayRecord)output.Single(); + SerializationRecord[] records = contained.GetArray(); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)records[0]; + Assert.Equal(Text, stringRecord.Value); + SZArrayRecord intArrayRecord = (SZArrayRecord)records[1]; + Assert.Equal(intArray, intArrayRecord.GetArray()); + ClassRecord classRecord = (ClassRecord)records[2]; + Assert.Equal(classThatImplementsIEnumerable.Field, classRecord.GetInt32(nameof(CustomClassThatImplementsIEnumerable.Field))); + Assert.Null(records[3]); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + public void CanReadArrayThatContainsAllRecordTypes_Jagged_MD(ElementType elementType) + { + const string Text = "hello"; + int[] intArray = [1, 2, 3]; + CustomClassThatImplementsIEnumerable classThatImplementsIEnumerable = new() { Field = 456 }; + + Array input = elementType switch + { + ElementType.Object => new object[1, 1][,], + ElementType.NonGeneric => new IEnumerable[1, 1][,], + _ => throw new InvalidOperationException() + }; + Array contained = elementType switch + { + ElementType.Object => new object[2, 2], + ElementType.NonGeneric => new IEnumerable[2, 2], + _ => throw new InvalidOperationException() + }; + contained.SetValue(Text, 0, 0); + contained.SetValue(intArray, 0, 1); + contained.SetValue(classThatImplementsIEnumerable, 1, 0); + input.SetValue(contained, 0, 0); + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + ArrayRecord[,] output = (ArrayRecord[,])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + SerializationRecord[,] records = (SerializationRecord[,])output[0, 0].GetArray(contained.GetType()); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)records[0, 0]; + Assert.Equal(Text, stringRecord.Value); + SZArrayRecord intArrayRecord = (SZArrayRecord)records[0, 1]; + Assert.Equal(intArray, intArrayRecord.GetArray()); + ClassRecord classRecord = (ClassRecord)records[1, 0]; + Assert.Equal(classThatImplementsIEnumerable.Field, classRecord.GetInt32(nameof(CustomClassThatImplementsIEnumerable.Field))); + Assert.Null(records[1, 1]); + } + + [Fact] + public void TypeMismatch() + { + // An array of strings that contains non-string. + byte[] bytes = Convert.FromBase64String("AAEAAAD/////AQAAAAAAAAAHAQAAAAICAAAAAQAAAAEAAAABCQEAAAAL"); + + ArrayRecord arrRecord = (ArrayRecord)NrbfDecoder.Decode(new MemoryStream(bytes)); + + Assert.Throws(() => arrRecord.GetArray(typeof(string[,]))); + } + + private static void Verify(Array input, ArrayRecord arrayRecord, Array output, + IReadOnlyDictionary recordMap) + { + Assert.Equal(input.Rank, arrayRecord.Rank); + Assert.Equal(input.Rank, output.Rank); + + for (int i = 0; i < input.Rank; i++) + { + Assert.Equal(input.GetLength(i), arrayRecord.Lengths[i]); + Assert.Equal(input.GetLength(i), output.GetLength(i)); + } + + foreach (object? recordOrNull in output) + { + if (recordOrNull is SerializationRecord record && !record.Id.Equals(default)) + { + // An array of abstractions always uses SystemClassWithMembersAndTypesRecord to represent primitive values. + // This requires some non-trivial mapping and we need to ensure that it's reflected not only in what + // has been stored in the array, but also in the record map. + Assert.Same(record, recordMap[record.Id]); + } + } + } + } +} diff --git a/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs b/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs index 49d523088a89f..7ef801808e4e9 100644 --- a/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs @@ -5,6 +5,7 @@ using System.IO; using System.Runtime.Serialization; using System.Text; +using Microsoft.DotNet.XUnitExtensions; using Xunit; namespace System.Formats.Nrbf.Tests; @@ -71,63 +72,63 @@ public void DontCastBytesToDateTimes() Assert.Throws(() => NrbfDecoder.Decode(stream)); } - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_Bool(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_Byte(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_SByte(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_Char(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_Int16(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_UInt16(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_Int32(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_UInt32(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_Int64(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_UInt64(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_Single(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_Double(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_TimeSpan(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_DateTime(int size, bool canSeek) => Test(size, canSeek); - private void Test(int size, bool canSeek) + private void Test(int size, bool canSeek) where T : IComparable { Random constSeed = new Random(27644437); T[] input = new T[size]; @@ -136,17 +137,69 @@ private void Test(int size, bool canSeek) input[i] = GenerateValue(constSeed); } + TestSZArrayOfT(input, size, canSeek); + TestSZArrayOfIComparable(input, size, canSeek); + } + + private void TestSZArrayOfT(T[] input, int size, bool canSeek) + { MemoryStream stream = Serialize(input); stream = canSeek ? stream : new NonSeekableStream(stream.ToArray()); SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(stream); Assert.Equal(size, arrayRecord.Length); - Assert.Equal(size, arrayRecord.FlattenedLength); T?[] output = arrayRecord.GetArray(); Assert.Equal(input, output); Assert.Same(output, arrayRecord.GetArray()); } + private void TestSZArrayOfIComparable(T[] input, int size, bool canSeek) where T : IComparable + { + if (!IsPatched) + { + throw new SkipTestException("Current machine has not been patched with the most recent BinaryFormatter fix."); + } + + // Arrays of abstractions that store primitive values (example: new IComparable[1] { int.MaxValue }) + // are represented by BinaryFormatter with a single SystemClassWithMembersAndTypesRecord + // and multiple ClassWithIdRecord that re-use the information from the system record. + // This requires some non-trivial mapping and this test is very important as it covers that code path. + IComparable[] comparables = new IComparable[size]; + for (int i = 0; i < input.Length; i++) + { + comparables[i] = input[i]; + } + + TestArrayOfSerializationRecords(input, comparables, canSeek); + } + + private void TestSZArrayOfObjects(T[] input, int size, bool canSeek) + { + // Arrays of objects that store primitive values (example: new object[1] { int.MaxValue }) + // are represented by BinaryFormatter with MemberPrimitiveTypedRecord instances. + object[] objects = new object[size]; + for (int i = 0; i < input.Length; i++) + { + objects[i] = input[i]; + } + + TestArrayOfSerializationRecords(input, objects, canSeek); + } + + private void TestArrayOfSerializationRecords(T[] values, object input, bool canSeek) + { + MemoryStream stream = Serialize(input); + + stream = canSeek ? stream : new NonSeekableStream(stream.ToArray()); + SZArrayRecord arrayRecordOfPrimitiveRecords = (SZArrayRecord)NrbfDecoder.Decode(stream); + SerializationRecord[] arrayOfPrimitiveRecords = arrayRecordOfPrimitiveRecords.GetArray(); + for (int i = 0; i < values.Length; i++) + { + Assert.Equal(values[i], ((PrimitiveTypeRecord)arrayOfPrimitiveRecords[i]).Value); + Assert.Equal(values[i], ((PrimitiveTypeRecord)arrayOfPrimitiveRecords[i]).Value); + } + } + private static T GenerateValue(Random random) { if (typeof(T) == typeof(byte)) diff --git a/src/libraries/System.Formats.Nrbf/tests/AttackTests.cs b/src/libraries/System.Formats.Nrbf/tests/AttackTests.cs index fe780d94698df..3a81e3f131c82 100644 --- a/src/libraries/System.Formats.Nrbf/tests/AttackTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/AttackTests.cs @@ -50,18 +50,51 @@ public void CyclicReferencesInSystemClassesDoNotCauseStackOverflow() } [Fact] - public void CyclicReferencesInArraysOfObjectsDoNotCauseStackOverflow() + public void CyclicReferencesInSZArraysOfObjectsDoNotCauseStackOverflow() { object[] input = new object[2]; input[0] = "not an array"; input[1] = input; ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); - object?[] output = ((SZArrayRecord)arrayRecord).GetArray(); + SerializationRecord?[] output = ((SZArrayRecord)arrayRecord).GetArray(); - Assert.Equal(input[0], output[0]); + Assert.Equal(input[0], ((PrimitiveTypeRecord)output[0]).Value); Assert.Same(input, input[1]); - Assert.Same(output, output[1]); + Assert.Same(arrayRecord, output[1]); + } + + [Fact] + public void CyclicReferencesInMDArraysOfObjectsDoNotCauseStackOverflow() + { + object[,] input = new object[2, 2]; + input[0, 0] = "not an array"; + input[1, 1] = input; + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); + SerializationRecord?[,] output = (SerializationRecord?[,])arrayRecord.GetArray(typeof(object[,])); + + Assert.Equal(input[0, 0], ((PrimitiveTypeRecord)output[0, 0]).Value); + Assert.Same(input, input[1, 1]); + Assert.Same(arrayRecord, output[1, 1]); + } + + [Fact] + public void CyclicReferencesInJaggedArraysOfObjectsDoNotCauseStackOverflow() + { + object[][] input = new object[1][]; + input[0] = new object[2]; + input[0][0] = "not an array"; + input[0][1] = input; + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(typeof(object[][])); + SZArrayRecord row = (SZArrayRecord)output.Single(); + SerializationRecord[] contained = row.GetArray(); + + Assert.Equal(input[0][0], ((PrimitiveTypeRecord)contained[0]).Value); + Assert.Same(input, input[0][1]); + Assert.Same(arrayRecord, contained[1]); } [Serializable] @@ -81,8 +114,8 @@ public void CyclicClassReferencesInArraysOfObjectsDoNotCauseStackOverflow() ClassRecord classRecord = NrbfDecoder.DecodeClassRecord(Serialize(input)); Assert.Equal(input.Name, classRecord.GetString(nameof(WithCyclicReferenceInArrayOfObjects.Name))); - SZArrayRecord arrayRecord = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(WithCyclicReferenceInArrayOfObjects.ArrayWithReferenceToSelf))!; - object?[] array = arrayRecord.GetArray(); + SZArrayRecord arrayRecord = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(WithCyclicReferenceInArrayOfObjects.ArrayWithReferenceToSelf))!; + SerializationRecord?[] array = arrayRecord.GetArray(); Assert.Same(classRecord, array.Single()); } @@ -103,7 +136,7 @@ public void CyclicClassReferencesInArraysOfTDoNotCauseStackOverflow() ClassRecord classRecord = NrbfDecoder.DecodeClassRecord(Serialize(input)); Assert.Equal(input.Name, classRecord.GetString(nameof(WithCyclicReferenceInArrayOfT.Name))); - SZArrayRecord classRecords = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(WithCyclicReferenceInArrayOfT.ArrayWithReferenceToSelf))!; + SZArrayRecord classRecords = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(WithCyclicReferenceInArrayOfT.ArrayWithReferenceToSelf))!; Assert.Same(classRecord, classRecords.GetArray().Single()); } diff --git a/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs b/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs index 6acb44d03697d..2d78954d64909 100644 --- a/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs @@ -355,6 +355,36 @@ public void ThrowsForInvalidPositiveArrayRank(int rank, byte arrayType) Assert.Throws(() => NrbfDecoder.Decode(stream)); } + public static IEnumerable AllPrimitiveTypes() + { + foreach (PrimitiveType primitiveType in Enum.GetValues(typeof(PrimitiveType))) + { + yield return new object[] { (byte)primitiveType }; + } + } + + [Theory] + [MemberData(nameof(AllPrimitiveTypes))] + public void ThrowsForInvalidPrimitiveTypeForBinaryArrayRecords(byte primitiveType) + { + using MemoryStream stream = new(); + BinaryWriter writer = new(stream, Encoding.UTF8); + + WriteSerializedStreamHeader(writer); + + writer.Write((byte)SerializationRecordType.BinaryArray); + writer.Write(1); // object Id + writer.Write((byte)BinaryArrayType.Jagged); + writer.Write(1); // rank! + writer.Write(1); // length + writer.Write((byte)BinaryType.Primitive); // A jagged array must consist of other arrays, not primitive values + writer.Write(primitiveType); + writer.Write((byte)SerializationRecordType.MessageEnd); + + stream.Position = 0; + Assert.Throws(() => NrbfDecoder.Decode(stream)); + } + [Theory] [InlineData(SerializationRecordType.ClassWithMembersAndTypes)] [InlineData(SerializationRecordType.SystemClassWithMembersAndTypes)] diff --git a/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs b/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs index 8bb844ff76a58..f02128ab08f99 100644 --- a/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs @@ -24,40 +24,43 @@ public void CanReadJaggedArraysOfPrimitiveTypes_2D(bool useReferences) var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); - Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength); + ArrayRecord?[] output = (ArrayRecord?[])arrayRecord.GetArray(input.GetType()); + for (int i = 0; i < input.Length; i++) + { + Assert.Equal(input[i], ((SZArrayRecord)output[i]).GetArray()); + if (useReferences) + { + Assert.Same(((SZArrayRecord)output[0]).GetArray(), ((SZArrayRecord)output[i]).GetArray()); + } + } } [Theory] [InlineData(1)] // SerializationRecordType.ObjectNull [InlineData(200)] // SerializationRecordType.ObjectNullMultiple256 [InlineData(10_000)] // SerializationRecordType.ObjectNullMultiple - public void FlattenedLengthIncludesNullArrays(int nullCount) + public void NullRecordsOfAllKindsAreHandledProperly(int nullCount) { int[][] input = new int[nullCount][]; var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); - Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(nullCount, arrayRecord.FlattenedLength); + ArrayRecord?[] output = (ArrayRecord?[])arrayRecord.GetArray(input.GetType()); + Assert.All(output, Assert.Null); } [Fact] public void ItIsPossibleToHaveBinaryArrayRecordsHaveAnElementTypeOfArrayWithoutBeingMarkedAsJagged() { int[][][] input = new int[3][][]; - long totalElementsCount = 0; for (int i = 0; i < input.Length; i++) { input[i] = new int[4][]; - totalElementsCount++; // count the arrays themselves for (int j = 0; j < input[i].Length; j++) { input[i][j] = [i, j, 0, 1, 2]; - totalElementsCount += input[i][j].Length; - totalElementsCount++; // count the arrays themselves } } @@ -75,57 +78,105 @@ public void ItIsPossibleToHaveBinaryArrayRecordsHaveAnElementTypeOfArrayWithoutB var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(new MemoryStream(serialized)); Verify(input, arrayRecord); - Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(3 + 3 * 4 + 3 * 4 * 5, totalElementsCount); - Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); + ArrayRecord?[] output = (ArrayRecord?[])arrayRecord.GetArray(input.GetType()); + for (int i = 0; i < input.Length; i++) + { + ArrayRecord[] firstLevel = (ArrayRecord[])output[i].GetArray(typeof(int[][])); + + for (int j = 0; j < input[i].Length; j++) + { + Assert.Equal(input[i][j], (int[])firstLevel[j].GetArray(typeof(int[]))); + } + } } [Fact] - public void CanReadJaggedArraysOfPrimitiveTypes_3D() + public void CanReadSZJaggedArrayOfMDArrays() { - int[][][] input = new int[7][][]; - long totalElementsCount = 0; + int[][,] input = new int[7][,]; for (int i = 0; i < input.Length; i++) { - totalElementsCount++; // count the arrays themselves - input[i] = new int[1][]; - totalElementsCount++; // count the arrays themselves - input[i][0] = [i, i, i]; - totalElementsCount += input[i][0].Length; + input[i] = new int[3, 3]; + + for (int j = 0; j < input[i].GetLength(0); j++) + { + for (int k = 0; k < input[i].GetLength(1); k++) + { + input[i][j, k] = i * j * k; + } + } } var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); - Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(1, arrayRecord.Rank); - Assert.Equal(7 + 7 * 1 + 7 * 1 * 3, totalElementsCount); - Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); + for (int i = 0; i < input.Length; i++) + { + Assert.Equal(input[i], output[i].GetArray(typeof(int[,]))); + } } [Fact] - public void CanReadJaggedArrayOfRectangularArrays() + public void CanReadMDJaggedArrayOfSZArrays() { - int[][,] input = new int[7][,]; - for (int i = 0; i < input.Length; i++) - { - input[i] = new int[3,3]; + int[,][] input = new int[2,2][]; + input[0, 0] = [1, 2, 3]; - for (int j = 0; j < input[i].GetLength(0); j++) + var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); + + Verify(input, arrayRecord); + ArrayRecord[,] output = (ArrayRecord[,])arrayRecord.GetArray(input.GetType()); + Assert.Equal(input[0, 0], output[0, 0].GetArray(typeof(int[]))); + Assert.Null(output[0, 1]); + Assert.Null(output[1, 0]); + Assert.Null(output[1, 1]); + } + + [Fact] + public void MultiDimensionalArrayOfMultiDimensionalArrays_Integers() + => MultiDimensionalArrayOfMultiDimensionalArrays(static (x, y) => x * y); + + [Fact] + public void MultiDimensionalArrayOfMultiDimensionalArrays_Doubles() + => MultiDimensionalArrayOfMultiDimensionalArrays(static (x, y) => x * y / 10); + + [Fact] + public void MultiDimensionalArrayOfMultiDimensionalArrays_Strings() + => MultiDimensionalArrayOfMultiDimensionalArrays(static (x, y) => $"{x},{y}"); + + static void MultiDimensionalArrayOfMultiDimensionalArrays(Func valueFactory) + { + T[,][,] input = new T[2, 2][,]; + for (int i = 0; i < input.GetLength(0); i++) + { + for (int j = 0; j < input.GetLength(1); j++) { - for (int k = 0; k < input[i].GetLength(1); k++) + T[,] contained = new T[i + 1, j + 1]; + for (int k = 0; k < contained.GetLength(0); k++) { - input[i][j, k] = i * j * k; + for (int l = 0; l < contained.GetLength(1); l++) + { + contained[k, l] = valueFactory(k, l); + } } + + input[i, j] = contained; } } var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); - Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(1, arrayRecord.Rank); - Assert.Equal(input.Length + input.Length * 3 * 3, arrayRecord.FlattenedLength); + + ArrayRecord[,] output = (ArrayRecord[,])arrayRecord.GetArray(input.GetType()); + for (int i = 0; i < input.GetLength(0); i++) + { + for (int j = 0; j < input.GetLength(1); j++) + { + Assert.Equal(input[i, j], output[i, j].GetArray(typeof(T[,]))); + } + } } [Fact] @@ -140,8 +191,11 @@ public void CanReadJaggedArraysOfStrings() var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); - Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); + for (int i = 0; i < input.Length; i++) + { + Assert.Equal(input[i], ((SZArrayRecord)output[i]).GetArray()); + } } [Fact] @@ -156,8 +210,16 @@ public void CanReadJaggedArraysOfObjects() var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); - Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); + + for (int i = 0; i < input.Length; i++) + { + SerializationRecord[] row = (SerializationRecord[])output[i].GetArray(typeof(object[])); + for (int j = 0; j < input[i].Length; j++) + { + Assert.Equal(input[i][j], ((PrimitiveTypeRecord)row[j]).Value); + } + } } [Serializable] @@ -170,32 +232,28 @@ public class ComplexType public void CanReadJaggedArraysOfComplexTypes() { ComplexType[][] input = new ComplexType[3][]; - long totalElementsCount = 0; for (int i = 0; i < input.Length; i++) { input[i] = Enumerable.Range(0, i + 1).Select(j => new ComplexType { SomeField = j }).ToArray(); - totalElementsCount += input[i].Length; - totalElementsCount++; // count the arrays themselves } var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); - Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); - var output = (ClassRecord?[][])arrayRecord.GetArray(input.GetType()); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); for (int i = 0; i < input.Length; i++) { + SerializationRecord[] row = ((SZArrayRecord)output[i]).GetArray(); for (int j = 0; j < input[i].Length; j++) { - Assert.Equal(input[i][j].SomeField, output[i][j]!.GetInt32(nameof(ComplexType.SomeField))); + Assert.Equal(input[i][j].SomeField, ((ClassRecord)row[j]!).GetInt32(nameof(ComplexType.SomeField))); } } } private static void Verify(Array input, ArrayRecord arrayRecord) { - Assert.Equal(1, arrayRecord.Lengths.Length); - Assert.Equal(input.Length, arrayRecord.Lengths[0]); + Assert.Equal(input.Rank, arrayRecord.Rank); Assert.True(arrayRecord.TypeName.GetElementType().IsArray); // true only for Jagged arrays Assert.Equal(input.GetType().FullName, arrayRecord.TypeName.FullName); Assert.Equal(input.GetType().GetAssemblyNameIncludingTypeForwards(), arrayRecord.TypeName.AssemblyName!.FullName); diff --git a/src/libraries/System.Formats.Nrbf/tests/ReadAnythingTests.cs b/src/libraries/System.Formats.Nrbf/tests/ReadAnythingTests.cs index 66cbab818dbba..3dd4d212ca280 100644 --- a/src/libraries/System.Formats.Nrbf/tests/ReadAnythingTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/ReadAnythingTests.cs @@ -32,8 +32,8 @@ public void UserCanReadAnyValidInputAndCheckTypesUsingStronglyTypedTypeInstances ClassRecord comparerRecord = dictionaryRecord.GetClassRecord(nameof(input.Comparer))!; Assert.True(comparerRecord.TypeNameMatches(input.Comparer.GetType())); - SZArrayRecord arrayRecord = (SZArrayRecord)dictionaryRecord.GetSerializationRecord("KeyValuePairs")!; - ClassRecord[] keyValuePairs = arrayRecord.GetArray()!; + SZArrayRecord arrayRecord = (SZArrayRecord)dictionaryRecord.GetSerializationRecord("KeyValuePairs")!; + ClassRecord[] keyValuePairs = arrayRecord.GetArray().OfType().ToArray(); Assert.True(keyValuePairs[0].TypeNameMatches(typeof(KeyValuePair))); ClassRecord exceptionPair = Find(keyValuePairs, "exception"); @@ -225,8 +225,8 @@ public void UserCanReadEveryPossibleSerializationRecord(object input) case ClassRecord record when record.TypeNameMatches(typeof(Exception)): Assert.Equal(((Exception)input).Message, record.GetString("Message")); break; - case SZArrayRecord record when record.TypeNameMatches(typeof(Exception[])): - Assert.Equal(((Exception[])input)[0].Message, record.GetArray()[0]!.GetString("Message")); + case SZArrayRecord record when record.TypeNameMatches(typeof(Exception[])): + Assert.Equal(((Exception[])input)[0].Message, ((ClassRecord)record.GetArray()[0]!).GetString("Message")); break; case ClassRecord record when record.TypeNameMatches(typeof(JsonException)): Assert.Equal(((JsonException)input).Message, record.GetString("Message")); @@ -241,7 +241,11 @@ public void UserCanReadEveryPossibleSerializationRecord(object input) Assert.Empty(record.MemberNames); break; case ArrayRecord arrayRecord when arrayRecord.TypeNameMatches(typeof(int?[])): - Assert.Equal(input, arrayRecord.GetArray(typeof(int?[]))); + SerializationRecord?[] nullableArray = (SerializationRecord?[])arrayRecord.GetArray(typeof(int?[])); + Assert.Equal(((int?[])input)[0], ((PrimitiveTypeRecord)nullableArray[0]).Value); + Assert.Equal(((int?[])input)[1], ((PrimitiveTypeRecord)nullableArray[1]).Value); + Assert.Equal(((int?[])input)[2], ((PrimitiveTypeRecord)nullableArray[2]).Value); + Assert.Null(nullableArray[3]); break; case ArrayRecord arrayRecord when arrayRecord.TypeNameMatches(typeof(EmptyClass[])): Assert.Equal(0, arrayRecord.Lengths.ToArray().Single()); @@ -262,11 +266,58 @@ public void UserCanReadEveryPossibleSerializationRecord(object input) static void VerifyDictionary(ClassRecord record) { - SZArrayRecord arrayRecord = (SZArrayRecord)record.GetSerializationRecord("KeyValuePairs")!; - ClassRecord[] keyValuePairs = arrayRecord.GetArray()!; + SZArrayRecord arrayRecord = (SZArrayRecord)record.GetSerializationRecord("KeyValuePairs")!; + ClassRecord[] keyValuePairs = arrayRecord.GetArray().OfType().ToArray(); Assert.True(keyValuePairs[0].TypeNameMatches(typeof(KeyValuePair))); } } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void UserCanReadArrayOfBaseType(bool mixed) + { + Mammal[] input = mixed + ? [new Dog { Name = "Buddy" }, new Cat { Name = "Luna" }] + : [new Dog { Name = "Buddy" }, new Dog { Name = "Rocky" }]; + + SZArrayRecord root = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input)); + + SerializationRecord[] output = (SerializationRecord[])root.GetArray(typeof(Mammal[])); + Assert.True(output[0].TypeNameMatches(typeof(Dog))); + Assert.True(output[1].TypeNameMatches(mixed ? typeof(Cat) : typeof(Dog))); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void UserCanReadArrayOfDerivedTypes(bool dogs) + { + Array input = dogs + ? new Dog[] { new Dog { Name = "Buddy" }, new Dog { Name = "Rocky" } } + : new Cat[] { new Cat { Name = "Luna" }, new Cat { Name = "Tiger" } }; + + SZArrayRecord root = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input)); + + Type expected = root.TypeName.GetElementType().FullName == typeof(Dog).FullName + ? typeof(Dog[]) + : typeof(Cat[]); + + SerializationRecord[] output = (SerializationRecord[])root.GetArray(expected); + Assert.All(output, record => record.TypeNameMatches(expected.GetElementType())); + } + + [Serializable] + public class Mammal + { + public string Name; + } + + [Serializable] + public class Cat : Mammal { } + + [Serializable] + public class Dog : Mammal { } } [Serializable] diff --git a/src/libraries/System.Formats.Nrbf/tests/ReadExactTypesTests.cs b/src/libraries/System.Formats.Nrbf/tests/ReadExactTypesTests.cs index ccf1dd402fc7b..027293bce05a6 100644 --- a/src/libraries/System.Formats.Nrbf/tests/ReadExactTypesTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/ReadExactTypesTests.cs @@ -263,11 +263,11 @@ public void CanRead_ArraysOfComplexTypes() new () { Long = 5 }, ]; - SZArrayRecord arrayRecord = ((SZArrayRecord)NrbfDecoder.Decode(Serialize(input))); + SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input)); Assert.Equal(typeof(CustomTypeWithPrimitiveFields[]).FullName, arrayRecord.TypeName.FullName); Assert.Equal(typeof(CustomTypeWithPrimitiveFields).Assembly.FullName, arrayRecord.TypeName.GetElementType().AssemblyName!.FullName); - ClassRecord?[] classRecords = arrayRecord.GetArray(); + ClassRecord?[] classRecords = arrayRecord.GetArray().OfType().ToArray(); for (int i = 0; i < input.Length; i++) { Verify(input[i], classRecords[i]!); @@ -298,8 +298,8 @@ public void CanRead_TypesWithArraysOfComplexTypes() ClassRecord classRecord = NrbfDecoder.DecodeClassRecord(Serialize(input)); - SZArrayRecord classRecords = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(CustomTypeWithArrayOfComplexTypes.Array))!; - ClassRecord?[] array = classRecords.GetArray(); + SZArrayRecord classRecords = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(CustomTypeWithArrayOfComplexTypes.Array))!; + SerializationRecord?[] array = classRecords.GetArray(); } [Theory] @@ -316,8 +316,8 @@ public void CanRead_TypesWithArraysOfComplexTypes_MultipleNulls(int nullCount) ClassRecord classRecord = NrbfDecoder.DecodeClassRecord(stream); - SZArrayRecord classRecords = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(CustomTypeWithArrayOfComplexTypes.Array))!; - ClassRecord?[] array = classRecords.GetArray(); + SZArrayRecord classRecords = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(CustomTypeWithArrayOfComplexTypes.Array))!; + SerializationRecord?[] array = classRecords.GetArray(); Assert.Equal(nullCount, array.Length); Assert.All(array, Assert.Null); @@ -337,7 +337,10 @@ public void CanRead_ArraysOfObjects() Assert.Equal(typeof(object[]).FullName, arrayRecord.TypeName.FullName); Assert.Equal("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089", arrayRecord.TypeName.GetElementType().AssemblyName!.FullName); - Assert.Equal(input, ((SZArrayRecord)arrayRecord).GetArray()); + SerializationRecord?[] output = ((SZArrayRecord)arrayRecord).GetArray(); + Assert.Equal(input[0], ((PrimitiveTypeRecord)output[0]).Value); + Assert.Equal(input[1], ((PrimitiveTypeRecord)output[1]).Value); + Assert.Null(output[2]); } [Theory] @@ -348,7 +351,7 @@ public void CanRead_ArraysOfObjects_MultipleNulls(int nullCount) object?[] input = Enumerable.Repeat(null!, nullCount).ToArray(); ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); - object?[] output = ((SZArrayRecord)arrayRecord).GetArray(); + SerializationRecord?[] output = ((SZArrayRecord)arrayRecord).GetArray(); Assert.Equal(nullCount, output.Length); Assert.All(output, Assert.Null); @@ -374,9 +377,13 @@ public void CanRead_CustomTypeWithArrayOfObjects() }; ClassRecord classRecord = NrbfDecoder.DecodeClassRecord(Serialize(input)); - SZArrayRecord arrayRecord = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(CustomTypeWithArrayOfObjects.Array))!; + SZArrayRecord arrayRecord = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(CustomTypeWithArrayOfObjects.Array))!; + SerializationRecord?[] output = arrayRecord.GetArray(); - Assert.Equal(input.Array, arrayRecord.GetArray()); + Assert.Equal(input.Array[0], ((PrimitiveTypeRecord)output[0]).Value); + Assert.Equal(input.Array[1], ((PrimitiveTypeRecord)output[1]).Value); + Assert.Equal(input.Array[2], ((PrimitiveTypeRecord)output[2]).Value); + Assert.Null(output[3]); } [Theory] diff --git a/src/libraries/System.Formats.Nrbf/tests/ReadTests.cs b/src/libraries/System.Formats.Nrbf/tests/ReadTests.cs index 0c7bd2045fa1f..8e5ce021db175 100644 --- a/src/libraries/System.Formats.Nrbf/tests/ReadTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/ReadTests.cs @@ -8,6 +8,40 @@ namespace System.Formats.Nrbf.Tests; [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.IsBinaryFormatterSupported))] public abstract class ReadTests { + public static bool IsPatched +#if NET + => true; +#else + => s_isPatched.Value; + + private static readonly Lazy s_isPatched = new(GetIsPatched); + + private static bool GetIsPatched() + { + Tuple tuple = new Tuple(42, new byte[] { 1, 2, 3, 4 }); +#pragma warning disable SYSLIB0011 // Type or member is obsolete + BinaryFormatter formatter = new(); +#pragma warning restore SYSLIB0011 // Type or member is obsolete + using MemoryStream stream = new(); + + // This particular scenario is going to throw on Full Framework + // if given machine has not installed the July 2024 cumulative update preview: + // https://learn.microsoft.com/dotnet/framework/release-notes/2024/07-25-july-preview-cumulative-update + + try + { + formatter.Serialize(stream, tuple); + stream.Position = 0; + Tuple deserialized = (Tuple)formatter.Deserialize(stream); + return tuple.Item1.Equals(deserialized.Item1); + } + catch (Exception) + { + return false; + } + } +#endif + protected static MemoryStream Serialize(T instance) where T : notnull { MemoryStream ms = new(); diff --git a/src/libraries/System.Formats.Nrbf/tests/RectangularArraysTests.cs b/src/libraries/System.Formats.Nrbf/tests/RectangularArraysTests.cs index 3191d57ba807c..b746faccbdb53 100644 --- a/src/libraries/System.Formats.Nrbf/tests/RectangularArraysTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/RectangularArraysTests.cs @@ -64,11 +64,18 @@ public void CanReadRectangularArraysOfObjects_2D() using MemoryStream stream = Serialize(array); ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(stream); + SerializationRecord[,] output = (SerializationRecord[,])arrayRecord.GetArray(array.GetType()); Verify(array, arrayRecord); Assert.True(arrayRecord.TypeNameMatches(typeof(object[,]))); Assert.False(arrayRecord.TypeNameMatches(typeof(int[,]))); - Assert.Equal(array, arrayRecord.GetArray(typeof(object[,]))); + + for (int i = 0; i < array.GetLength(0); i++) + { + Assert.Equal(array[i, 0], ((PrimitiveTypeRecord)output[i, 0]).Value); + Assert.Equal(array[i, 1], ((PrimitiveTypeRecord)output[i, 1]).Value); + Assert.Null(output[i, 2]); + } } [Serializable] @@ -176,7 +183,14 @@ public void CanReadRectangularArraysOfObjects_3D() Assert.True(arrayRecord.TypeNameMatches(typeof(object[,,]))); Assert.False(arrayRecord.TypeNameMatches(typeof(object[,]))); Assert.False(arrayRecord.TypeNameMatches(typeof(int[,,]))); - Assert.Equal(array, arrayRecord.GetArray(typeof(object[,,]))); + SerializationRecord[,,] output = (SerializationRecord[,,])arrayRecord.GetArray(typeof(object[,,])); + + for (int i = 0; i < array.GetLength(0); i++) + { + Assert.Equal(array[i, 0, 0], ((PrimitiveTypeRecord)output[i, 0, 0]).Value); + Assert.Equal(array[i, 1, 0], ((PrimitiveTypeRecord)output[i, 1, 0]).Value); + Assert.Null(output[i, 2, 0]); + } } [Serializable] @@ -223,13 +237,10 @@ public void CanReadRectangularArraysOfComplexTypes_3D() internal static void Verify(Array input, ArrayRecord arrayRecord) { Assert.Equal(input.Rank, arrayRecord.Lengths.Length); - long totalElementsCount = 1; for (int i = 0; i < input.Rank; i++) { Assert.Equal(input.GetLength(i), arrayRecord.Lengths[i]); - totalElementsCount *= input.GetLength(i); } - Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); Assert.Equal(input.GetType().FullName, arrayRecord.TypeName.FullName); Assert.Equal(input.GetType().GetAssemblyNameIncludingTypeForwards(), arrayRecord.TypeName.AssemblyName!.FullName); } diff --git a/src/libraries/System.Formats.Nrbf/tests/TypeMatchTests.cs b/src/libraries/System.Formats.Nrbf/tests/TypeMatchTests.cs index e0827c1225b42..c4c8018cdc464 100644 --- a/src/libraries/System.Formats.Nrbf/tests/TypeMatchTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/TypeMatchTests.cs @@ -362,7 +362,7 @@ private static void VerifySZArray(T input) where T : notnull } else { - Assert.True(arrayRecord is SZArrayRecord, userMessage: typeof(T).Name); + Assert.True(arrayRecord is SZArrayRecord, userMessage: typeof(T).Name); Assert.True(arrayRecord.TypeNameMatches(typeof(T[]))); Assert.Equal(arrayRecord.TypeName.GetElementType().AssemblyName.FullName, typeof(T).GetAssemblyNameIncludingTypeForwards()); } diff --git a/src/libraries/System.Resources.Extensions/src/System/Resources/Extensions/BinaryFormat/Deserializer/ArrayRecordDeserializer.cs b/src/libraries/System.Resources.Extensions/src/System/Resources/Extensions/BinaryFormat/Deserializer/ArrayRecordDeserializer.cs index 0c8e5d00fe454..18981a295abdc 100644 --- a/src/libraries/System.Resources.Extensions/src/System/Resources/Extensions/BinaryFormat/Deserializer/ArrayRecordDeserializer.cs +++ b/src/libraries/System.Resources.Extensions/src/System/Resources/Extensions/BinaryFormat/Deserializer/ArrayRecordDeserializer.cs @@ -121,8 +121,14 @@ internal override SerializationRecordId Continue() }; [RequiresUnreferencedCode("Calls System.Windows.Forms.BinaryFormat.BinaryFormattedObject.TypeResolver.GetType(TypeName)")] - internal static Array? GetSimpleBinaryArray(ArrayRecord arrayRecord, BinaryFormattedObject.ITypeResolver typeResolver) + internal static Array? GetRectangularArrayOfPrimitives(ArrayRecord arrayRecord, BinaryFormattedObject.ITypeResolver typeResolver) { + // Only rectangular, non-jagged BinaryArrayRecord can hit the lucky path below. + if (arrayRecord.Rank <= 1 || arrayRecord.TypeName.GetElementType().IsArray) + { + return null; + } + Type arrayRecordElementType = typeResolver.GetType(arrayRecord.TypeName.GetElementType()); Type elementType = arrayRecordElementType; while (elementType.IsArray) @@ -130,17 +136,12 @@ internal override SerializationRecordId Continue() elementType = elementType.GetElementType()!; } - if (!(HasBuiltInSupport(elementType) - || (Nullable.GetUnderlyingType(elementType) is Type nullable && HasBuiltInSupport(nullable)))) + if (!HasBuiltInSupport(elementType)) { return null; } - Type expectedArrayType = arrayRecord.Rank switch - { - 1 => arrayRecordElementType.MakeArrayType(), - _ => arrayRecordElementType.MakeArrayType(arrayRecord.Rank) - }; + Type expectedArrayType = arrayRecordElementType.MakeArrayType(arrayRecord.Rank); return arrayRecord.GetArray(expectedArrayType); diff --git a/src/libraries/System.Resources.Extensions/src/System/Resources/Extensions/BinaryFormat/Deserializer/Deserializer.cs b/src/libraries/System.Resources.Extensions/src/System/Resources/Extensions/BinaryFormat/Deserializer/Deserializer.cs index 8e763fe850d86..eee808bfc3606 100644 --- a/src/libraries/System.Resources.Extensions/src/System/Resources/Extensions/BinaryFormat/Deserializer/Deserializer.cs +++ b/src/libraries/System.Resources.Extensions/src/System/Resources/Extensions/BinaryFormat/Deserializer/Deserializer.cs @@ -229,7 +229,7 @@ object DeserializeNew(SerializationRecordId id) SerializationRecordType.MemberPrimitiveTyped => ((PrimitiveTypeRecord)record).Value, SerializationRecordType.ArraySingleString => ((SZArrayRecord)record).GetArray(), SerializationRecordType.ArraySinglePrimitive => ArrayRecordDeserializer.GetArraySinglePrimitive(record), - SerializationRecordType.BinaryArray => ArrayRecordDeserializer.GetSimpleBinaryArray((ArrayRecord)record, _typeResolver), + SerializationRecordType.BinaryArray => ArrayRecordDeserializer.GetRectangularArrayOfPrimitives((ArrayRecord)record, _typeResolver), _ => null }; diff --git a/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/Common/MultidimensionalArrayTests.cs b/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/Common/MultidimensionalArrayTests.cs index 8814a3184ef9e..efac507855764 100644 --- a/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/Common/MultidimensionalArrayTests.cs +++ b/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/Common/MultidimensionalArrayTests.cs @@ -52,6 +52,84 @@ public void IntegerArrays_Basic() Assert.Equal(threeDimensions, deserialized); } + [Serializable] + public class CustomComparable : IComparable, IEquatable + { + public int Integer; + + public int CompareTo(object? obj) + { + CustomComparable other = (CustomComparable)obj; + + return other.Integer.CompareTo(other.Integer); + } + + public bool Equals(CustomComparable? other) => Integer == other.Integer; + + public override int GetHashCode() => Integer; + + public override bool Equals(object? obj) => obj is CustomComparable other && Equals(other); + } + + [Fact] + public void MultiDimensionalArrayOfMultiDimensionalArrays_Integers() + => MultiDimensionalArrayOfMultiDimensionalArrays(static (x, y) => x * y); + + [Fact] + public void MultiDimensionalArrayOfMultiDimensionalArrays_Doubles() + => MultiDimensionalArrayOfMultiDimensionalArrays(static (x, y) => x * y / 10); + + [Fact] + public void MultiDimensionalArrayOfMultiDimensionalArrays_Strings() + => MultiDimensionalArrayOfMultiDimensionalArrays(static (x, y) => $"{x},{y}"); + + [Fact] + public void MultiDimensionalArrayOfMultiDimensionalArrays_Abstraction() + => MultiDimensionalArrayOfMultiDimensionalArrays(static (x, y) => x switch + { + 0 => x * y, // int + 1 => x + (double)y / 10, // double + 2 => $"{x},{y}", // string + _ => new CustomComparable() { Integer = x * y } + }); + + [Fact] + public void MultiDimensionalArrayOfMultiDimensionalArrays_Objects() + => MultiDimensionalArrayOfMultiDimensionalArrays(static (x, y) => x switch + { + 0 => x * y, // int + 1 => x + (double)y / 10, // double + 2 => $"{x},{y}", // string + _ => new CustomComparable() { Integer = x * y } + }); + + private static void MultiDimensionalArrayOfMultiDimensionalArrays(Func valueFactory) + { + TValue[,][,] input = new TValue[3, 3][,]; + for (int i = 0; i < input.GetLength(0); i++) + { + for (int j = 0; j < input.GetLength(1); j++) + { + TValue[,] contained = new TValue[i + 1, j + 1]; + for (int k = 0; k < contained.GetLength(0); k++) + { + for (int l = 0; l < contained.GetLength(1); l++) + { + contained[k, l] = valueFactory(k, l); + } + } + + input[i, j] = contained; + + object deserializedMd = Deserialize(Serialize(contained)); + Assert.Equal(contained, deserializedMd); + } + } + + object deserializedJagged = Deserialize(Serialize(input)); + Assert.Equal(input, deserializedJagged); + } + [Fact] public void EmptyDimensions() { diff --git a/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/BinaryFormattedObjectTests.cs b/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/BinaryFormattedObjectTests.cs index f59d2bd47c9ea..b08874b133620 100644 --- a/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/BinaryFormattedObjectTests.cs +++ b/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/BinaryFormattedObjectTests.cs @@ -38,9 +38,9 @@ public void ReadEmptyHashTable() ClassRecord systemClass = (ClassRecord)format[format.RootRecord.Id]; VerifyHashTable(systemClass, expectedVersion: 0, expectedHashSize: 3); - SZArrayRecord keys = (SZArrayRecord)systemClass.GetSerializationRecord("Keys")!; + SZArrayRecord keys = (SZArrayRecord)systemClass.GetSerializationRecord("Keys")!; Assert.Equal(0, keys.Length); - SZArrayRecord values = (SZArrayRecord)systemClass.GetSerializationRecord("Values")!; + SZArrayRecord values = (SZArrayRecord)systemClass.GetSerializationRecord("Values")!; Assert.Equal(0, values.Length); } @@ -77,12 +77,12 @@ public void ReadHashTableWithStringPair() ClassRecord systemClass = (ClassRecord)format[format.RootRecord.Id]; VerifyHashTable(systemClass, expectedVersion: 1, expectedHashSize: 3); - SZArrayRecord keys = (SZArrayRecord)format[systemClass.GetArrayRecord("Keys").Id]; + SZArrayRecord keys = (SZArrayRecord)format[systemClass.GetArrayRecord("Keys").Id]; Assert.Equal(1, keys.Length); - Assert.Equal("This", keys.GetArray().Single()); - SZArrayRecord values = (SZArrayRecord)format[systemClass.GetArrayRecord("Values").Id]; + Assert.Equal("This", ((PrimitiveTypeRecord)keys.GetArray().Single()).Value); + SZArrayRecord values = (SZArrayRecord)format[systemClass.GetArrayRecord("Values").Id]; Assert.Equal(1, values.Length); - Assert.Equal("That", values.GetArray().Single()); + Assert.Equal("That", ((PrimitiveTypeRecord)values.GetArray().Single()).Value); } [Fact] @@ -100,8 +100,9 @@ public void ReadHashTableWithRepeatedStrings() // The collections themselves get ids first before the strings do. // Everything in the second keys is a string reference. - SZArrayRecord array = (SZArrayRecord)systemClass.GetSerializationRecord("Keys")!; - Assert.Equivalent(new object[] { "TheOther", "That", "This" }, array.GetArray()); + SZArrayRecord arrayRecord = (SZArrayRecord)systemClass.GetSerializationRecord("Keys")!; + SerializationRecord[] array = arrayRecord.GetArray(); + Assert.Equivalent(new string[] { "TheOther", "That", "This" }, array.OfType>().Select(sr => sr.Value).ToArray()); } [Fact] @@ -119,11 +120,14 @@ public void ReadHashTableWithNullValues() // The collections themselves get ids first before the strings do. // Everything in the second keys is a string reference. - SZArrayRecord keys = (SZArrayRecord)systemClass.GetSerializationRecord("Keys")!; - Assert.Equivalent(new object[] { "Yowza", "Youza", "Meeza" }, keys.GetArray()); - - SZArrayRecord values = (SZArrayRecord)systemClass.GetSerializationRecord("Values")!; - Assert.Equal(new object?[] { null, null, null }, values.GetArray()); + SZArrayRecord keysRecord = (SZArrayRecord)systemClass.GetSerializationRecord("Keys")!; + SerializationRecord[] keysRecords = keysRecord.GetArray(); + Assert.Equivalent(new string[] { "Yowza", "Youza", "Meeza" }, keysRecords.OfType>().Select(sr => sr.Value).ToArray()); + + SZArrayRecord valuesRecord = (SZArrayRecord)systemClass.GetSerializationRecord("Values")!; + SerializationRecord[] valuesRecords = valuesRecord.GetArray(); + Assert.Equal(3, valuesRecords.Length); + Assert.All(valuesRecords, Assert.Null); } [Fact] diff --git a/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/HashTableTests.cs b/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/HashTableTests.cs index dad76abff8e91..f0b3b5adf83b8 100644 --- a/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/HashTableTests.cs +++ b/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/HashTableTests.cs @@ -72,8 +72,10 @@ public void HashTable_CustomComparer() Assert.Equal("System.Collections.Hashtable", systemClass.TypeName.FullName); Assert.Equal("System.OrdinalComparer", systemClass.GetClassRecord("Comparer")!.TypeName.FullName); Assert.Equal("System.Resources.Extensions.Tests.FormattedObject.HashtableTests+CustomHashCodeProvider", systemClass.GetClassRecord("HashCodeProvider")!.TypeName.FullName); - Assert.True(systemClass.GetSerializationRecord("Keys") is SZArrayRecord); - Assert.True(systemClass.GetSerializationRecord("Values") is SZArrayRecord); + Assert.True(systemClass.GetSerializationRecord("Keys") is SZArrayRecord); + Assert.Equal(SerializationRecordType.ArraySingleObject, systemClass.GetSerializationRecord("Keys").RecordType); + Assert.True(systemClass.GetSerializationRecord("Values") is SZArrayRecord); + Assert.Equal(SerializationRecordType.ArraySingleObject, systemClass.GetSerializationRecord("Values").RecordType); } [Serializable] diff --git a/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/ListTests.cs b/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/ListTests.cs index 6d43e04498dd3..cb3fdc2927641 100644 --- a/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/ListTests.cs +++ b/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/ListTests.cs @@ -19,7 +19,8 @@ public void BinaryFormattedObject_ParseEmptyArrayList() VerifyArrayList((ClassRecord)format[format.RootRecord.Id]); - Assert.True(format[((ClassRecord)format.RootRecord).GetArrayRecord("_items").Id] is SZArrayRecord); + Assert.True(format[((ClassRecord)format.RootRecord).GetArrayRecord("_items").Id] is SZArrayRecord); + Assert.Equal(SerializationRecordType.ArraySingleObject, format[((ClassRecord)format.RootRecord).GetArrayRecord("_items").Id].RecordType); } private static void VerifyArrayList(ClassRecord systemClass) @@ -28,7 +29,8 @@ private static void VerifyArrayList(ClassRecord systemClass) Assert.Equal(typeof(ArrayList).FullName, systemClass.TypeName.FullName); Assert.Equal(["_items", "_size", "_version"], systemClass.MemberNames); - Assert.True(systemClass.GetSerializationRecord("_items") is SZArrayRecord); + Assert.True(systemClass.GetSerializationRecord("_items") is SZArrayRecord); + Assert.Equal(SerializationRecordType.ArraySingleObject, systemClass.GetSerializationRecord("_items").RecordType); } [Theory] @@ -43,9 +45,9 @@ public void BinaryFormattedObject_ParsePrimitivesArrayList(object value) ClassRecord listRecord = (ClassRecord)format[format.RootRecord.Id]; VerifyArrayList(listRecord); - SZArrayRecord array = (SZArrayRecord)format[listRecord.GetArrayRecord("_items").Id]; + SZArrayRecord array = (SZArrayRecord)format[listRecord.GetArrayRecord("_items").Id]; - Assert.Equal(new[] { value }, array.GetArray().Take(listRecord.GetInt32("_size"))); + Assert.Equal(value, ((PrimitiveTypeRecord)array.GetArray().Take(listRecord.GetInt32("_size")).Single()).Value); } [Fact] @@ -59,8 +61,8 @@ public void BinaryFormattedObject_ParseStringArrayList() ClassRecord listRecord = (ClassRecord)format[format.RootRecord.Id]; VerifyArrayList(listRecord); - SZArrayRecord array = (SZArrayRecord)format[listRecord.GetArrayRecord("_items").Id]; - Assert.Equal(new object[] { "JarJar" }, array.GetArray().Take(listRecord.GetInt32("_size"))); + SZArrayRecord array = (SZArrayRecord)format[listRecord.GetArrayRecord("_items").Id]; + Assert.Equal("JarJar", ((PrimitiveTypeRecord)array.GetArray().Take(listRecord.GetInt32("_size")).Single()).Value); } public static TheoryData ArrayList_Primitive_Data => new()