Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NRBF] Address issues discovered by Threat Model #106629

Merged
merged 13 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public abstract partial class ArrayRecord : System.Formats.Nrbf.SerializationRec
internal ArrayRecord() { }
public override System.Formats.Nrbf.SerializationRecordId Id { get { throw null; } }
public abstract System.ReadOnlySpan<int> Lengths { get; }
public virtual long FlattenedLength { get; }
public int Rank { get { throw null; } }
[System.Diagnostics.CodeAnalysis.RequiresDynamicCode("The code for an array of the specified type might not be available.")]
public System.Array GetArray(System.Type expectedArrayType, bool allowNulls = true) { throw null; }
Expand Down
11 changes: 4 additions & 7 deletions src/libraries/System.Formats.Nrbf/src/Resources/Strings.resx
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on what we have discussed and what @bartonjs wrote here: #103713 (comment)

I believe the type names and assembly names should not be provided in the exception messages.

Original file line number Diff line number Diff line change
Expand Up @@ -126,26 +126,23 @@
<data name="Serialization_UnexpectedNullRecordCount" xml:space="preserve">
<value>Unexpected Null Record count.</value>
</data>
<data name="Serialization_MaxArrayLength" xml:space="preserve">
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this resource was not being used for a while

<value>The serialized array length ({0}) was larger than the configured limit {1}.</value>
</data>
<data name="NotSupported_RecordType" xml:space="preserve">
<value>{0} Record Type is not supported by design.</value>
</data>
<data name="Serialization_InvalidReference" xml:space="preserve">
<value>Invalid member reference.</value>
</data>
<data name="Serialization_InvalidTypeName" xml:space="preserve">
<value>Invalid type name: `{0}`.</value>
<value>Invalid type name.</value>
</data>
<data name="Serialization_TypeMismatch" xml:space="preserve">
<value>Expected the array to be of type {0}, but its element type was {1}.</value>
</data>
<data name="Serialization_InvalidTypeOrAssemblyName" xml:space="preserve">
<value>Invalid type or assembly name: `{0},{1}`.</value>
<value>Invalid type or assembly name.</value>
</data>
<data name="Serialization_DuplicateMemberName" xml:space="preserve">
<value>Duplicate member name: `{0}`.</value>
<value>Duplicate member name.</value>
</data>
<data name="Argument_NonSeekableStream" xml:space="preserve">
<value>Stream does not support seeking.</value>
Expand All @@ -160,7 +157,7 @@
<value>Only arrays with zero offsets are supported.</value>
</data>
<data name="Serialization_InvalidAssemblyName" xml:space="preserve">
<value>Invalid assembly name: `{0}`.</value>
<value>Invalid assembly name.</value>
</data>
<data name="Serialization_InvalidFormat" xml:space="preserve">
<value>Invalid format.</value>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,23 @@ internal readonly struct ArrayInfo
internal ArrayInfo(SerializationRecordId id, long totalElementsCount, BinaryArrayType arrayType = BinaryArrayType.Single, int rank = 1)
{
Id = id;
TotalElementsCount = totalElementsCount;
FlattenedLength = totalElementsCount;
ArrayType = arrayType;
Rank = rank;
}

internal SerializationRecordId Id { get; }

internal long TotalElementsCount { get; }
internal long FlattenedLength { get; }

internal BinaryArrayType ArrayType { get; }

internal int Rank { get; }

internal int GetSZArrayLength()
{
Debug.Assert(TotalElementsCount <= MaxArrayLength);
return (int)TotalElementsCount;
Debug.Assert(FlattenedLength <= MaxArrayLength);
return (int)FlattenedLength;
}

internal static ArrayInfo Decode(BinaryReader reader)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public abstract class ArrayRecord : SerializationRecord
private protected ArrayRecord(ArrayInfo arrayInfo)
{
ArrayInfo = arrayInfo;
ValuesToRead = arrayInfo.TotalElementsCount;
ValuesToRead = arrayInfo.FlattenedLength;
}

/// <summary>
Expand All @@ -27,6 +27,12 @@ private protected ArrayRecord(ArrayInfo arrayInfo)
/// <value>A buffer of integers that represent the number of elements in every dimension.</value>
public abstract ReadOnlySpan<int> Lengths { get; }

/// <summary>
/// When overridden in a derived class, gets the total number of all elements in every dimension.
/// </summary>
/// <value>A number that represent the total number of all elements in every dimension.</value>
public virtual long FlattenedLength => ArrayInfo.FlattenedLength;

/// <summary>
/// Gets the rank of the array.
/// </summary>
Expand All @@ -44,7 +50,12 @@ private protected ArrayRecord(ArrayInfo arrayInfo)

internal long ValuesToRead { get; private protected set; }

private protected ArrayInfo ArrayInfo { get; }
internal ArrayInfo ArrayInfo { get; }

internal bool IsJagged
=> ArrayInfo.ArrayType == BinaryArrayType.Jagged
// It is possible to have binary array records have an element type of array without being marked as jagged.
|| TypeName.GetElementType().IsArray;
Comment on lines +57 to +58
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JeremyKuhne this handles the scenario you have mentioned offline (there is also a test for that). Thank you again for pointing this out!


/// <summary>
/// Allocates an array and fills it with the data provided in the serialized records (in case of primitive types like <see cref="string"/> or <see cref="int"/>) or the serialized records themselves.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.IO;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;

namespace System.Formats.Nrbf;

Expand All @@ -27,19 +28,38 @@ internal sealed class BinaryArrayRecord : ArrayRecord
];

private TypeName? _typeName;
private long _totalElementsCount;

private BinaryArrayRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo)
: base(arrayInfo)
{
MemberTypeInfo = memberTypeInfo;
Values = [];
// We need to parse all elements of the jagged array to obtain total elements count.
_totalElementsCount = -1;
}

public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;

/// <inheritdoc/>
public override ReadOnlySpan<int> Lengths => new int[1] { Length };

/// <inheritdoc/>
public override long FlattenedLength
{
get
{
if (_totalElementsCount < 0)
{
_totalElementsCount = IsJagged
? GetJaggedArrayFlattenedLength(this)
: ArrayInfo.FlattenedLength;
}

return _totalElementsCount;
}
}

public override TypeName TypeName
=> _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo);

Expand Down Expand Up @@ -157,6 +177,65 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay
: new BinaryArrayRecord(arrayInfo, memberTypeInfo);
}

private static long GetJaggedArrayFlattenedLength(BinaryArrayRecord jaggedArrayRecord)
{
long result = 0;
Queue<BinaryArrayRecord>? jaggedArrayRecords = null;

do
{
if (jaggedArrayRecords is not null)
{
jaggedArrayRecord = jaggedArrayRecords.Dequeue();
}

Debug.Assert(jaggedArrayRecord.IsJagged);

// In theory somebody could create a payload that would represent
// a very nested array with total elements count > long.MaxValue.
// That is why this method is using checked arithmetic.
result = checked(result + jaggedArrayRecord.Length); // count the arrays themselves

foreach (object value in jaggedArrayRecord.Values)
{
if (value is not SerializationRecord record)
{
continue;
}

if (record.RecordType == SerializationRecordType.MemberReference)
{
record = ((MemberReferenceRecord)record).GetReferencedRecord();
}

switch (record.RecordType)
{
case SerializationRecordType.ArraySinglePrimitive:
case SerializationRecordType.ArraySingleObject:
case SerializationRecordType.ArraySingleString:
case SerializationRecordType.BinaryArray:
ArrayRecord nestedArrayRecord = (ArrayRecord)record;
if (nestedArrayRecord.IsJagged)
{
(jaggedArrayRecords ??= new()).Enqueue((BinaryArrayRecord)nestedArrayRecord);
}
else
{
// Don't call nestedArrayRecord.FlattenedLength to avoid any potential recursion,
// just call nestedArrayRecord.ArrayInfo.FlattenedLength that returns pre-computed value.
result = checked(result + nestedArrayRecord.ArrayInfo.FlattenedLength);
}
break;
default:
break;
}
}
}
while (jaggedArrayRecords is not null && jaggedArrayRecords.Count > 0);

return result;
}

private protected override void AddValue(object value) => Values.Add(value);

internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ internal static BinaryLibraryRecord Decode(BinaryReader reader, PayloadOptions o
}
else if (!options.UndoTruncatedTypeNames)
{
ThrowHelper.ThrowInvalidAssemblyName(rawName);
ThrowHelper.ThrowInvalidAssemblyName();
}

return new BinaryLibraryRecord(id, rawName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ internal static ClassInfo Decode(BinaryReader reader)
continue;
}
#endif
throw new SerializationException(SR.Format(SR.Serialization_DuplicateMemberName, memberName));
ThrowHelper.ThrowDuplicateMemberName();
}

return new ClassInfo(id, typeName, memberNames);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ internal static RectangularArrayRecord Create(BinaryReader reader, ArrayInfo arr

if (sizeOfSingleValue > 0)
{
long size = arrayInfo.TotalElementsCount * sizeOfSingleValue;
long size = arrayInfo.FlattenedLength * sizeOfSingleValue;
bool? isDataAvailable = reader.IsDataAvailable(size);
if (isDataAvailable.HasValue)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,30 @@

namespace System.Formats.Nrbf.Utils;

// The exception messages do not contain member/type/assembly names on purpose,
// as it's most likely corrupted/tampered/malicious data.
internal static class ThrowHelper
{
internal static void ThrowInvalidValue(object value)
internal static void ThrowDuplicateMemberName()
=> throw new SerializationException(SR.Serialization_DuplicateMemberName);

internal static void ThrowInvalidValue(int value)
=> throw new SerializationException(SR.Format(SR.Serialization_InvalidValue, value));

internal static void ThrowInvalidReference()
=> throw new SerializationException(SR.Serialization_InvalidReference);

internal static void ThrowInvalidTypeName(string name)
=> throw new SerializationException(SR.Format(SR.Serialization_InvalidTypeName, name));
internal static void ThrowInvalidTypeName()
=> throw new SerializationException(SR.Serialization_InvalidTypeName);

internal static void ThrowUnexpectedNullRecordCount()
=> throw new SerializationException(SR.Serialization_UnexpectedNullRecordCount);

internal static void ThrowMaxArrayLength(long limit, long actual)
=> throw new SerializationException(SR.Format(SR.Serialization_MaxArrayLength, actual, limit));

internal static void ThrowArrayContainedNulls()
=> throw new SerializationException(SR.Serialization_ArrayContainedNulls);

internal static void ThrowInvalidAssemblyName(string rawName)
=> throw new SerializationException(SR.Format(SR.Serialization_InvalidAssemblyName, rawName));
internal static void ThrowInvalidAssemblyName()
=> throw new SerializationException(SR.Serialization_InvalidAssemblyName);

internal static void ThrowInvalidFormat()
=> throw new SerializationException(SR.Serialization_InvalidFormat);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ internal static TypeName ParseNonSystemClassRecordTypeName(this string rawName,

if (typeName is null)
{
throw new SerializationException(SR.Format(SR.Serialization_InvalidTypeOrAssemblyName, rawName, libraryRecord.RawLibraryName));
throw new SerializationException(SR.Serialization_InvalidTypeOrAssemblyName);
}

if (typeName.AssemblyName is null)
Expand Down Expand Up @@ -168,7 +168,7 @@ private static TypeName With(this TypeName typeName, AssemblyNameInfo assemblyNa
else
{
// BinaryFormatter can not serialize pointers or references.
ThrowHelper.ThrowInvalidTypeName(typeName.FullName);
ThrowHelper.ThrowInvalidTypeName();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ private void Test<T>(int size, bool canSeek)
SZArrayRecord<T> arrayRecord = (SZArrayRecord<T>)NrbfDecoder.Decode(stream);

Assert.Equal(size, arrayRecord.Length);
Assert.Equal(size, arrayRecord.FlattenedLength);
T?[] output = arrayRecord.GetArray();
Assert.Equal(input, output);
Assert.Same(output, arrayRecord.GetArray());
Expand Down
Loading