Skip to content

Commit

Permalink
[NRBF] Comments and bug fixes from internal code review (dotnet#107735)
Browse files Browse the repository at this point in the history
* copy comments and asserts from Levis internal code review

* apply Levis suggestion: don't store Array.MaxLength as a const, as it may change in the future

* add missing and fix some of the existing comments

* first bug fix: SerializationRecord.TypeNameMatches should throw ArgumentNullException for null Type argument

* second bug fix: SerializationRecord.TypeNameMatches should know the difference between SZArray and single-dimension, non-zero offset arrays (example: int[] and int[*])

* third bug fix: don't cast bytes to booleans

* fourth bug fix: don't cast bytes to DateTimes

* add one test case that I've forgot in previous PR
  • Loading branch information
adamsitnik authored and sirntar committed Sep 30, 2024
1 parent 1f08360 commit aebd108
Show file tree
Hide file tree
Showing 28 changed files with 392 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

namespace System.Formats.Nrbf;

// See [MS-NRBF] Sec. 2.7 for more information.
// https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/ca3ad2bc-777b-413a-a72a-9ba6ced76bc3

[Flags]
internal enum AllowedRecordTypes : uint
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ namespace System.Formats.Nrbf;
[DebuggerDisplay("{ArrayType}, rank={Rank}")]
internal readonly struct ArrayInfo
{
internal const int MaxArrayLength = 2147483591; // Array.MaxLength
#if NET8_0_OR_GREATER
internal static int MaxArrayLength => Array.MaxLength; // dynamic lookup in case the value changes in a future runtime
#else
internal const int MaxArrayLength = 2147483591; // hardcode legacy Array.MaxLength for downlevel runtimes
#endif

internal ArrayInfo(SerializationRecordId id, long totalElementsCount, BinaryArrayType arrayType = BinaryArrayType.Single, int rank = 1)
{
Expand Down Expand Up @@ -47,7 +51,7 @@ internal static int ParseValidArrayLength(BinaryReader reader)
{
int length = reader.ReadInt32();

if (length is < 0 or > MaxArrayLength)
if (length < 0 || length > MaxArrayLength)
{
ThrowHelper.ThrowInvalidValue(length);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;

namespace System.Formats.Nrbf;

Expand Down Expand Up @@ -54,6 +55,7 @@ public override TypeName TypeName
}

int nullCount = ((NullsRecord)actual).NullCount;
Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount.");
do
{
result[resultIndex++] = null;
Expand All @@ -63,6 +65,8 @@ public override TypeName TypeName
}
}

Debug.Assert(resultIndex == result.Length, "We should have traversed the entirety of the newly created array.");

return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.IO;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;

namespace System.Formats.Nrbf;

Expand Down Expand Up @@ -33,13 +34,15 @@ public override TypeName TypeName
{
object?[] values = new object?[Length];

for (int recordIndex = 0, valueIndex = 0; recordIndex < Records.Count; recordIndex++)
int valueIndex = 0;
for (int recordIndex = 0; recordIndex < Records.Count; recordIndex++)
{
SerializationRecord record = Records[recordIndex];

int nullCount = record is NullsRecord nullsRecord ? nullsRecord.NullCount : 0;
if (nullCount == 0)
{
// "new object[] { <SELF> }" is special cased because it allows for storing reference to itself.
values[valueIndex++] = record is MemberReferenceRecord referenceRecord && referenceRecord.Reference.Equals(Id)
? values // a reference to self, and a way to get StackOverflow exception ;)
: record.GetValue();
Expand All @@ -59,6 +62,8 @@ public override TypeName TypeName
while (nullCount > 0);
}

Debug.Assert(valueIndex == values.Length, "We should have traversed the entirety of the newly created array.");

return values;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,32 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
return (List<T>)(object)DecodeDecimals(reader, count);
}

// char[] has a unique representation in NRBF streams. Typical strings are transcoded
// to UTF-8 and prefixed with the number of bytes in the UTF-8 representation. char[]
// is also serialized as UTF-8, but it is instead prefixed with the number of chars
// in the UTF-16 representation, not the number of bytes in the UTF-8 representation.
// This number doesn't directly precede the UTF-8 contents in the NRBF stream; it's
// instead contained within the ArrayInfo structure (passed to this method as the
// 'count' argument).
//
// The practical consequence of this is that we don't actually know how many UTF-8
// bytes we need to consume in order to ensure we've read 'count' chars. We know that
// an n-length UTF-16 string turns into somewhere between [n .. 3n] UTF-8 bytes.
// The best we can do is that when reading an n-element char[], we'll ensure that
// there are at least n bytes remaining in the input stream. We'll still need to
// account for that even with this check, we might hit EOF before fully populating
// the char[]. But from a safety perspective, it does appropriately limit our
// allocations to be proportional to the amount of data present in the input stream,
// which is a sufficient defense against DoS.

long requiredBytes = count;
if (typeof(T) != typeof(char)) // the input is UTF8
if (typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan))
{
// We can't assume DateTime as represented by the runtime is 8 bytes.
// The only assumption we can make is that it's 8 bytes on the wire.
requiredBytes *= 8;
}
else if (typeof(T) != typeof(char))
{
requiredBytes *= Unsafe.SizeOf<T>();
}
Expand All @@ -79,6 +103,10 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
{
return (T[])(object)reader.ParseChars(count);
}
else if (typeof(T) == typeof(TimeSpan) || typeof(T) == typeof(DateTime))
{
return DecodeTime(reader, count);
}

// It's safe to pre-allocate, as we have ensured there is enough bytes in the stream.
T[] result = new T[count];
Expand Down Expand Up @@ -130,8 +158,7 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
}
#endif
}
else if (typeof(T) == typeof(long) || typeof(T) == typeof(ulong) || typeof(T) == typeof(double)
|| typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan))
else if (typeof(T) == typeof(long) || typeof(T) == typeof(ulong) || typeof(T) == typeof(double))
{
Span<long> span = MemoryMarshal.Cast<T, long>(result);
#if NET
Expand All @@ -145,6 +172,21 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
}
}

if (typeof(T) == typeof(bool))
{
// See DontCastBytesToBooleans test to see what could go wrong.
bool[] booleans = (bool[])(object)result;
resultAsBytes = MemoryMarshal.AsBytes<T>(result);
for (int i = 0; i < booleans.Length; i++)
{
// We don't use the bool array to get the value, as an optimizing compiler or JIT could elide this.
if (resultAsBytes[i] != 0) // it can be any byte different than 0
{
booleans[i] = true; // set it to 1 in explicit way
}
}
}

return result;
}

Expand All @@ -158,8 +200,34 @@ private static List<decimal> DecodeDecimals(BinaryReader reader, int count)
return values;
}

private static T[] DecodeTime(BinaryReader reader, int count)
{
T[] values = new T[count];
for (int i = 0; i < values.Length; i++)
{
if (typeof(T) == typeof(DateTime))
{
values[i] = (T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64());
}
else if (typeof(T) == typeof(TimeSpan))
{
values[i] = (T)(object)new TimeSpan(reader.ReadInt64());
}
else
{
throw new InvalidOperationException();
}
}

return values;
}

private static List<T> DecodeFromNonSeekableStream(BinaryReader reader, int count)
{
// The count arg could originate from untrusted input, so we shouldn't
// pass it as-is to the ctor's capacity arg. We'll instead rely on
// List<T>.Add's O(1) amortization to keep the entire loop O(count).

List<T> values = new List<T>(Math.Min(count, 4));
for (int i = 0; i < count; i++)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.IO;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;

namespace System.Formats.Nrbf;

Expand Down Expand Up @@ -47,7 +48,8 @@ internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetA
{
string?[] values = new string?[Length];

for (int recordIndex = 0, valueIndex = 0; recordIndex < Records.Count; recordIndex++)
int valueIndex = 0;
for (int recordIndex = 0; recordIndex < Records.Count; recordIndex++)
{
SerializationRecord record = Records[recordIndex];

Expand All @@ -73,6 +75,7 @@ record = memberReference.GetReferencedRecord();
}

int nullCount = ((NullsRecord)record).NullCount;
Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount.");
do
{
values[valueIndex++] = null;
Expand All @@ -81,6 +84,8 @@ record = memberReference.GetReferencedRecord();
while (nullCount > 0);
}

Debug.Assert(valueIndex == values.Length, "We should have traversed the entirety of the newly created array.");

return values;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.IO;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;

namespace System.Formats.Nrbf;

Expand Down Expand Up @@ -84,6 +85,10 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls)
case SerializationRecordType.ArraySinglePrimitive:
case SerializationRecordType.ArraySingleObject:
case SerializationRecordType.ArraySingleString:

// Recursion depth is bounded by the depth of arrayType, which is
// a trustworthy Type instance. Don't need to worry about stack overflow.

ArrayRecord nestedArrayRecord = (ArrayRecord)record;
Array nestedArray = nestedArrayRecord.GetArray(actualElementType, allowNulls);
array.SetValue(nestedArray, resultIndex++);
Expand All @@ -97,6 +102,7 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls)
}

int nullCount = ((NullsRecord)item).NullCount;
Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount.");
do
{
array.SetValue(null, resultIndex++);
Expand All @@ -110,6 +116,8 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls)
}
}

Debug.Assert(resultIndex == array.Length, "We should have traversed the entirety of the newly created array.");

return array;
}

Expand All @@ -122,6 +130,7 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay
bool isRectangular = arrayType is BinaryArrayType.Rectangular;

// It is an arbitrary limit in the current CoreCLR type loader.
// Don't change this value without reviewing the loop a few lines below.
const int MaxSupportedArrayRank = 32;

if (rank < 1 || rank > MaxSupportedArrayRank
Expand All @@ -132,18 +141,26 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay
}

int[] lengths = new int[rank]; // adversary-controlled, but acceptable since upper limit of 32
long totalElementCount = 1;
long totalElementCount = 1; // to avoid integer overflow during the multiplication below
for (int i = 0; i < lengths.Length; i++)
{
lengths[i] = ArrayInfo.ParseValidArrayLength(reader);
totalElementCount *= lengths[i];

// n.b. This forbids "new T[Array.MaxLength, Array.MaxLength, Array.MaxLength, ..., 0]"
// but allows "new T[0, Array.MaxLength, Array.MaxLength, Array.MaxLength, ...]". But
// that's the same behavior that newarr and Array.CreateInstance exhibit, so at least
// we're consistent.

if (totalElementCount > ArrayInfo.MaxArrayLength)
{
ThrowHelper.ThrowInvalidValue(lengths[i]); // max array size exceeded
}
}

// Per BinaryReaderExtensions.ReadArrayType, we do not support nonzero offsets, so
// we don't need to read the NRBF stream 'LowerBounds' field here.

MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, 1, options, recordMap);
ArrayInfo arrayInfo = new(objectId, totalElementCount, arrayType, rank);

Expand Down Expand Up @@ -186,6 +203,9 @@ private static Type MapElementType(Type arrayType, out bool isClassRecord)
Type elementType = arrayType;
int arrayNestingDepth = 0;

// Loop iteration counts are bound by the nesting depth of arrayType,
// which is a trustworthy input. No DoS concerns.

while (elementType.IsArray)
{
elementType = elementType.GetElementType()!;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ internal static ClassInfo Decode(BinaryReader reader)

// Use Dictionary instead of List so that searching for member IDs by name
// is O(n) instead of O(m * n), where m = memberCount and n = memberNameLength,
// in degenerate cases.
// in degenerate cases. Since memberCount may be hostile, don't allow it to be
// used as the initial capacity in the collection instance.
Dictionary<string, int> memberNames = new(StringComparer.Ordinal);
for (int i = 0; i < memberCount; i++)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
namespace System.Formats.Nrbf;

/// <summary>
/// Identifies a class by it's name and library id.
/// Identifies a class by its name and library id.
/// </summary>
/// <remarks>
/// ClassTypeInfo structures are described in <see href="https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/844b24dd-9f82-426e-9b98-05334307a239">[MS-NRBF] 2.1.1.8</see>.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ internal bool ShouldBeRepresentedAsArrayOfClassRecords()
{
// This library tries to minimize the number of concepts the users need to learn to use it.
// Since SZArrays are most common, it provides an SZArrayRecord<T> abstraction.
// Every other array (jagged, multi-dimensional etc) is represented using SZArrayRecord.
// Every other array (jagged, multi-dimensional etc) is represented using ArrayRecord.
// The goal of this method is to determine whether given array can be represented as SZArrayRecord<ClassRecord>.

(BinaryType binaryType, object? additionalInfo) = Infos[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,5 @@ internal NextInfo(AllowedRecordTypes allowed, SerializationRecord parent,
internal PrimitiveType PrimitiveType { get; }

internal NextInfo With(AllowedRecordTypes allowed, PrimitiveType primitiveType)
=> allowed == Allowed && primitiveType == PrimitiveType
? this // previous record was of the same type
: new(allowed, Parent, Stack, primitiveType);
=> new(allowed, Parent, Stack, primitiveType);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public static class NrbfDecoder
// The header consists of:
// - a byte that describes the record type (SerializationRecordType.SerializedStreamHeader)
// - four 32 bit integers:
// - root Id (every value is valid)
// - root Id (every value except of 0 is valid)
// - header Id (value is ignored)
// - major version, it has to be equal 1.
// - minor version, it has to be equal 0.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,17 @@ public PayloadOptions() { }
/// </summary>
/// <value><see langword="true" /> if truncated type names should be reassembled; otherwise, <see langword="false" />.</value>
/// <remarks>
/// <para>
/// Example:
/// TypeName: "Namespace.TypeName`1[[Namespace.GenericArgName"
/// LibraryName: "AssemblyName]]"
/// Is combined into "Namespace.TypeName`1[[Namespace.GenericArgName, AssemblyName]]"
/// </para>
/// <para>
/// Setting this to <see langword="true" /> can render <see cref="NrbfDecoder"/> susceptible to Denial of Service
/// attacks when parsing or handling malicious input.
/// </para>
/// <para>The default value is <see langword="false" />.</para>
/// </remarks>
public bool UndoTruncatedTypeNames { get; set; }
}
Loading

0 comments on commit aebd108

Please sign in to comment.