Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NRBF] Address issues discovered by Threat Model #106629

Merged
merged 13 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public abstract partial class ArrayRecord : System.Formats.Nrbf.SerializationRec
internal ArrayRecord() { }
public override System.Formats.Nrbf.SerializationRecordId Id { get { throw null; } }
public abstract System.ReadOnlySpan<int> Lengths { get; }
public virtual long FlattenedLength { get; }
public int Rank { get { throw null; } }
[System.Diagnostics.CodeAnalysis.RequiresDynamicCode("The code for an array of the specified type might not be available.")]
public System.Array GetArray(System.Type expectedArrayType, bool allowNulls = true) { throw null; }
Expand Down
11 changes: 4 additions & 7 deletions src/libraries/System.Formats.Nrbf/src/Resources/Strings.resx
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on what we have discussed and what @bartonjs wrote here: #103713 (comment)

I believe the type names and assembly names should not be provided in the exception messages.

Original file line number Diff line number Diff line change
Expand Up @@ -126,26 +126,23 @@
<data name="Serialization_UnexpectedNullRecordCount" xml:space="preserve">
<value>Unexpected Null Record count.</value>
</data>
<data name="Serialization_MaxArrayLength" xml:space="preserve">
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this resource was not being used for a while

<value>The serialized array length ({0}) was larger than the configured limit {1}.</value>
</data>
<data name="NotSupported_RecordType" xml:space="preserve">
<value>{0} Record Type is not supported by design.</value>
</data>
<data name="Serialization_InvalidReference" xml:space="preserve">
<value>Member reference was pointing to a record of unexpected type.</value>
</data>
<data name="Serialization_InvalidTypeName" xml:space="preserve">
<value>Invalid type name: `{0}`.</value>
<value>Invalid type name.</value>
</data>
<data name="Serialization_TypeMismatch" xml:space="preserve">
<value>Expected the array to be of type {0}, but its element type was {1}.</value>
</data>
<data name="Serialization_InvalidTypeOrAssemblyName" xml:space="preserve">
<value>Invalid type or assembly name: `{0},{1}`.</value>
<value>Invalid type or assembly name.</value>
</data>
<data name="Serialization_DuplicateMemberName" xml:space="preserve">
<value>Duplicate member name: `{0}`.</value>
<value>Duplicate member name.</value>
</data>
<data name="Argument_NonSeekableStream" xml:space="preserve">
<value>Stream does not support seeking.</value>
Expand All @@ -160,6 +157,6 @@
<value>Only arrays with zero offsets are supported.</value>
</data>
<data name="Serialization_InvalidAssemblyName" xml:space="preserve">
<value>Invalid assembly name: `{0}`.</value>
<value>Invalid assembly name.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,23 @@ internal readonly struct ArrayInfo
internal ArrayInfo(SerializationRecordId id, long totalElementsCount, BinaryArrayType arrayType = BinaryArrayType.Single, int rank = 1)
{
Id = id;
TotalElementsCount = totalElementsCount;
FlattenedLength = totalElementsCount;
ArrayType = arrayType;
Rank = rank;
}

internal SerializationRecordId Id { get; }

internal long TotalElementsCount { get; }
internal long FlattenedLength { get; }

internal BinaryArrayType ArrayType { get; }

internal int Rank { get; }

internal int GetSZArrayLength()
{
Debug.Assert(TotalElementsCount <= MaxArrayLength);
return (int)TotalElementsCount;
Debug.Assert(FlattenedLength <= MaxArrayLength);
return (int)FlattenedLength;
}

internal static ArrayInfo Decode(BinaryReader reader)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public abstract class ArrayRecord : SerializationRecord
private protected ArrayRecord(ArrayInfo arrayInfo)
{
ArrayInfo = arrayInfo;
ValuesToRead = arrayInfo.TotalElementsCount;
ValuesToRead = arrayInfo.FlattenedLength;
}

/// <summary>
Expand All @@ -27,6 +27,12 @@ private protected ArrayRecord(ArrayInfo arrayInfo)
/// <value>A buffer of integers that represent the number of elements in every dimension.</value>
public abstract ReadOnlySpan<int> Lengths { get; }

/// <summary>
/// When overridden in a derived class, gets the total number of all elements in every dimension.
/// </summary>
/// <value>A number that represent the total number of all elements in every dimension.</value>
public virtual long FlattenedLength => ArrayInfo.FlattenedLength;

/// <summary>
/// Gets the rank of the array.
/// </summary>
Expand All @@ -46,6 +52,11 @@ private protected ArrayRecord(ArrayInfo arrayInfo)

private protected ArrayInfo ArrayInfo { get; }

internal bool IsJagged
=> ArrayInfo.ArrayType == BinaryArrayType.Jagged
// It is possible to have binary array records have an element type of array without being marked as jagged.
|| TypeName.GetElementType().IsArray;
Comment on lines +57 to +58
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JeremyKuhne this handles the scenario you have mentioned offline (there is also a test for that). Thank you again for pointing this out!


/// <summary>
/// Allocates an array and fills it with the data provided in the serialized records (in case of primitive types like <see cref="string"/> or <see cref="int"/>) or the serialized records themselves.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.IO;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;

namespace System.Formats.Nrbf;

Expand All @@ -27,19 +28,38 @@ internal sealed class BinaryArrayRecord : ArrayRecord
];

private TypeName? _typeName;
private long _totalElementsCount;

private BinaryArrayRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo)
: base(arrayInfo)
{
MemberTypeInfo = memberTypeInfo;
Values = [];
// We need to parse all elements of the jagged array to obtain total elements count.
_totalElementsCount = -1;
}

public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;

/// <inheritdoc/>
public override ReadOnlySpan<int> Lengths => new int[1] { Length };

/// <inheritdoc/>
public override long FlattenedLength
{
get
{
if (_totalElementsCount < 0)
{
_totalElementsCount = IsJagged
? GetJaggedArrayFlattenedLength(this)
: ArrayInfo.FlattenedLength;
}

return _totalElementsCount;
}
}

public override TypeName TypeName
=> _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo);

Expand Down Expand Up @@ -157,6 +177,74 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay
: new BinaryArrayRecord(arrayInfo, memberTypeInfo);
}

private static long GetJaggedArrayFlattenedLength(BinaryArrayRecord jaggedArrayRecord)
{
long result = 0;
Queue<BinaryArrayRecord>? jaggedArrayRecords = null;

do
{
if (jaggedArrayRecords is not null)
{
jaggedArrayRecord = jaggedArrayRecords.Dequeue();
}

Debug.Assert(jaggedArrayRecord.IsJagged);

foreach (object value in jaggedArrayRecord.Values)
{
object item = value is MemberReferenceRecord referenceRecord
? referenceRecord.GetReferencedRecord()
: value;

if (item is not SerializationRecord record)
{
result++;
continue;
}

switch (record.RecordType)
{
case SerializationRecordType.BinaryArray:
case SerializationRecordType.ArraySinglePrimitive:
case SerializationRecordType.ArraySingleObject:
case SerializationRecordType.ArraySingleString:
ArrayRecord nestedArrayRecord = (ArrayRecord)record;
if (nestedArrayRecord.IsJagged)
{
(jaggedArrayRecords ??= new()).Enqueue((BinaryArrayRecord)nestedArrayRecord);
}
else
{
Debug.Assert(nestedArrayRecord is not BinaryArrayRecord, "Ensure lack of recursive call");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can read the assert, but I don't understand it. Why can't a BinaryArrayRecord be nested within another BinaryArrayRecord? Why would that necessarily mean that a recursive call happened?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've eliminated the potential recursive call and added a comment.

checked
{
// In theory somebody could create a payload that would represent
// a very nested array with total elements count > long.MaxValue.
result += nestedArrayRecord.FlattenedLength;
}
}
break;
Copy link
Member

@bartonjs bartonjs Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't counting the arrays themselves. Assuming that's intentional, you should have a comment at the top of the method (or before this switch) that explains what does, and what doesn't count.

object[][] objs = new object[1_000_000][];
objs.Fill(Array.Empty<object>());

Doesn't feel like it should have a smaller count than

object[][] objs = new object[1_000_000][];
objs.Fill(null);

And right now it feels like the first one says 0, and the second says a million.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, I can understand not counting the arrays as answering how many int values might come out of an int[][][]; but then it gets weird when the nulls get counted, because they're not int values.

So I guess there's a bit of "what is this number supposed to mean?". The comment can explain that, and then the code can be judged against the comment.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I've fixed that.

case SerializationRecordType.ObjectNull:
case SerializationRecordType.ObjectNullMultiple256:
case SerializationRecordType.ObjectNullMultiple:
// All nulls need to be included, as it's another form of possible attack.
checked
{
result += ((NullsRecord)item).NullCount;
}
break;
default:
result++;
break;
}
Copy link
Member

@bartonjs bartonjs Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't the arrays already know their lengths? I'd expect FlattenedLength to be a very simple

long length = here.Length;

foreach (element in here.Children)
{
    if (element.IsArray)
    {
        length += element.FlattenedLength;
    }
}

return length;

(or the queue variant)

And that approach would already have counted the null elements. Ensuring that the number of records matches the length is a different problem (one I feel GetArray tries to solve?).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bartonjs You are right. This suggestion has greatly simplified the code and solved the other problem (arrays not being included themselves)

}
}
while (jaggedArrayRecords is not null && jaggedArrayRecords.Count > 0);

return result;
}

private protected override void AddValue(object value) => Values.Add(value);

internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ internal static BinaryLibraryRecord Decode(BinaryReader reader, PayloadOptions o
}
else if (!options.UndoTruncatedTypeNames)
{
ThrowHelper.ThrowInvalidAssemblyName(rawName);
ThrowHelper.ThrowInvalidAssemblyName();
}

return new BinaryLibraryRecord(id, rawName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ internal static ClassInfo Decode(BinaryReader reader)
continue;
}
#endif
throw new SerializationException(SR.Format(SR.Serialization_DuplicateMemberName, memberName));
ThrowHelper.ThrowDuplicateMemberName();
}

return new ClassInfo(id, typeName, memberNames);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo,
#if DEBUG
_values = new LinkedList<object>();
#else
_values = arrayInfo.TotalElementsCount <= ArrayInfo.MaxArrayLength
_values = arrayInfo.FlattenedLength <= ArrayInfo.MaxArrayLength
? new List<object>(canPreAllocate ? arrayInfo.GetSZArrayLength() : Math.Min(4, arrayInfo.GetSZArrayLength()))
: new LinkedList<object>();
#endif
Expand Down Expand Up @@ -181,7 +181,7 @@ internal static RectangularArrayRecord Create(BinaryReader reader, ArrayInfo arr

if (sizeOfSingleValue > 0)
{
long size = arrayInfo.TotalElementsCount * sizeOfSingleValue;
long size = arrayInfo.FlattenedLength * sizeOfSingleValue;
bool? isDataAvailable = reader.IsDataAvailable(size);
if (isDataAvailable.HasValue)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,30 @@

namespace System.Formats.Nrbf.Utils;

// The exception messages do not contain member/type/assembly names on purpose,
// as it's most likely corrupted/tampered/malicious data.
internal static class ThrowHelper
{
internal static void ThrowInvalidValue(object value)
internal static void ThrowDuplicateMemberName()
=> throw new SerializationException(SR.Serialization_DuplicateMemberName);

internal static void ThrowInvalidValue(int value)
=> throw new SerializationException(SR.Format(SR.Serialization_InvalidValue, value));

internal static void ThrowInvalidReference()
=> throw new SerializationException(SR.Serialization_InvalidReference);

internal static void ThrowInvalidTypeName(string name)
=> throw new SerializationException(SR.Format(SR.Serialization_InvalidTypeName, name));
internal static void ThrowInvalidTypeName()
=> throw new SerializationException(SR.Serialization_InvalidTypeName);

internal static void ThrowUnexpectedNullRecordCount()
=> throw new SerializationException(SR.Serialization_UnexpectedNullRecordCount);

internal static void ThrowMaxArrayLength(long limit, long actual)
=> throw new SerializationException(SR.Format(SR.Serialization_MaxArrayLength, actual, limit));

internal static void ThrowArrayContainedNulls()
=> throw new SerializationException(SR.Serialization_ArrayContainedNulls);

internal static void ThrowInvalidAssemblyName(string rawName)
=> throw new SerializationException(SR.Format(SR.Serialization_InvalidAssemblyName, rawName));
internal static void ThrowInvalidAssemblyName()
=> throw new SerializationException(SR.Serialization_InvalidAssemblyName);

internal static void ThrowEndOfStreamException()
=> throw new EndOfStreamException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ internal static TypeName ParseNonSystemClassRecordTypeName(this string rawName,

if (typeName is null)
{
throw new SerializationException(SR.Format(SR.Serialization_InvalidTypeOrAssemblyName, rawName, libraryRecord.RawLibraryName));
throw new SerializationException(SR.Serialization_InvalidTypeOrAssemblyName);
}

if (typeName.AssemblyName is null)
Expand Down Expand Up @@ -169,7 +169,7 @@ private static TypeName With(this TypeName typeName, AssemblyNameInfo assemblyNa
else
{
// BinaryFormatter can not serialize pointers or references.
ThrowHelper.ThrowInvalidTypeName(typeName.FullName);
ThrowHelper.ThrowInvalidTypeName();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ private void Test<T>(int size, bool canSeek)
SZArrayRecord<T> arrayRecord = (SZArrayRecord<T>)NrbfDecoder.Decode(stream);

Assert.Equal(size, arrayRecord.Length);
Assert.Equal(size, arrayRecord.FlattenedLength);
T?[] output = arrayRecord.GetArray();
Assert.Equal(input, output);
Assert.Same(output, arrayRecord.GetArray());
Expand Down
Loading