From d68a423b94631c4dc8e80ede45e213f24ba2fa0d Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Fri, 13 Sep 2024 09:11:08 +0200 Subject: [PATCH] address code review feedback --- .../Nrbf/ArraySinglePrimitiveRecord.cs | 51 ++++++++++++++----- .../Formats/Nrbf/SerializationRecordId.cs | 8 +-- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs index 15dce6b2781aa..a13507b97015a 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs @@ -72,7 +72,13 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c // which is a sufficient defense against DoS. long requiredBytes = count; - if (typeof(T) != typeof(char)) + if (typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan)) + { + // We can't assume DateTime as represented by the runtime is 8 bytes. + // The only assumption we can make is that it's 8 bytes on the wire. + requiredBytes *= 8; + } + else if (typeof(T) != typeof(char)) { requiredBytes *= Unsafe.SizeOf(); } @@ -97,6 +103,10 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c { return (T[])(object)reader.ParseChars(count); } + else if (typeof(T) == typeof(TimeSpan) || typeof(T) == typeof(DateTime)) + { + return DecodeTime(reader, count); + } // It's safe to pre-allocate, as we have ensured there is enough bytes in the stream. T[] result = new T[count]; @@ -148,8 +158,7 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c } #endif } - else if (typeof(T) == typeof(long) || typeof(T) == typeof(ulong) || typeof(T) == typeof(double) - || typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan)) + else if (typeof(T) == typeof(long) || typeof(T) == typeof(ulong) || typeof(T) == typeof(double)) { Span span = MemoryMarshal.Cast(result); #if NET @@ -167,24 +176,16 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c { // See DontCastBytesToBooleans test to see what could go wrong. bool[] booleans = (bool[])(object)result; + resultAsBytes = MemoryMarshal.AsBytes(result); for (int i = 0; i < booleans.Length; i++) { - if (booleans[i]) // it can be any byte different than 0 + // We don't use the bool array to get the value, as an optimizing compiler or JIT could elide this. + if (resultAsBytes[i] != 0) // it can be any byte different than 0 { booleans[i] = true; // set it to 1 in explicit way } } } - else if (typeof(T) == typeof(DateTime)) - { - DateTime[] dateTimes = (DateTime[])(object)result; - Span span = MemoryMarshal.Cast(result); - for (int i = 0; i < dateTimes.Length; i++) - { - // The value needs to get validated. - dateTimes[i] = BinaryReaderExtensions.CreateDateTimeFromData(span[i]); - } - } return result; } @@ -199,6 +200,28 @@ private static List DecodeDecimals(BinaryReader reader, int count) return values; } + private static T[] DecodeTime(BinaryReader reader, int count) + { + T[] values = new T[count]; + for (int i = 0; i < values.Length; i++) + { + if (typeof(T) == typeof(DateTime)) + { + values[i] = (T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64()); + } + else if (typeof(T) == typeof(TimeSpan)) + { + values[i] = (T)(object)new TimeSpan(reader.ReadInt64()); + } + else + { + throw new InvalidOperationException(); + } + } + + return values; + } + private static List DecodeFromNonSeekableStream(BinaryReader reader, int count) { // The count arg could originate from untrusted input, so we shouldn't diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecordId.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecordId.cs index 2a8592862e327..a8318cb72d11d 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecordId.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecordId.cs @@ -32,11 +32,11 @@ internal static SerializationRecordId Decode(BinaryReader reader) int id = reader.ReadInt32(); // Many object ids are required to be positive. See: - // - https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-nrbf/8fac763f-e46d-43a1-b360-80eb83d2c5fb - // - https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-nrbf/eb503ca5-e1f6-4271-a7ee-c4ca38d07996 - // - https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-nrbf/7fcf30e1-4ad4-4410-8f1a-901a4a1ea832 (for library id) + // - https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/8fac763f-e46d-43a1-b360-80eb83d2c5fb + // - https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/eb503ca5-e1f6-4271-a7ee-c4ca38d07996 + // - https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/7fcf30e1-4ad4-4410-8f1a-901a4a1ea832 (for library id) // - // Exception: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-nrbf/0a192be0-58a1-41d0-8a54-9c91db0ab7bf may be negative + // Exception: https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/0a192be0-58a1-41d0-8a54-9c91db0ab7bf may be negative // The problem is that input generated with FormatterTypeStyle.XsdString ends up generating negative Ids anyway. // That information is not reflected in payload in anyway, so we just always allow for negative Ids.