diff --git a/Snappier.Benchmarks/VarIntEncodingRead.cs b/Snappier.Benchmarks/VarIntEncodingRead.cs index 9e0fbf9..9f621ab 100644 --- a/Snappier.Benchmarks/VarIntEncodingRead.cs +++ b/Snappier.Benchmarks/VarIntEncodingRead.cs @@ -18,18 +18,10 @@ public void GlobalSetup() VarIntEncoding.Write(_source, Value); } - [Benchmark(Baseline = true)] - public (int, uint) Previous() - { - var length = VarIntEncoding.ReadSlow(_source, out var result); - - return (length, result); - } - [Benchmark] - public (int, uint) New() + public (int, uint) TryRead() { - var length = VarIntEncoding.Read(_source, out var result); + _ = VarIntEncoding.TryRead(_source, out var result, out var length); return (length, result); } diff --git a/Snappier.Tests/Internal/VarIntEncodingReadTests.cs b/Snappier.Tests/Internal/VarIntEncodingReadTests.cs index 2c98287..87a0014 100644 --- a/Snappier.Tests/Internal/VarIntEncodingReadTests.cs +++ b/Snappier.Tests/Internal/VarIntEncodingReadTests.cs @@ -1,4 +1,5 @@ using System; +using System.Buffers; using Snappier.Internal; using Xunit; @@ -23,39 +24,73 @@ public static TheoryData TestData() => { 0xFFFFFFFF, [ 0xFF, 0xFF, 0xFF, 0xFF, 0x0F ] }, }; + public static TheoryData IncompleteTestData() => + new() { + { [ 0x80 ] }, + { [ 0xD5 ] }, + { [ 0xFF, 0xFF ] }, + { [ 0xFF, 0xFF ] }, + { [ 0XFF, 0xFF ] }, + { [ 0x80, 0x80 ] }, + { [ 0xD5, 0xAA ] }, + { [ 0x80, 0xDE, 0xBF ] }, + { [ 0x8D, 0xE0, 0xFB, 0xD7 ] }, + { [ 0xFF, 0xFF, 0xFF, 0xFF ] }, + }; + [Theory] [MemberData(nameof(TestData))] - public void Test_Read(uint expected, byte[] input) + public void Test_TryRead(uint expected, byte[] input) { - var length = VarIntEncoding.Read(input, out var result); - Assert.Equal(input.Length, length); + var status = VarIntEncoding.TryRead(input, out var result, out var bytesRead); + Assert.Equal(OperationStatus.Done, status); + Assert.Equal(input.Length, bytesRead); Assert.Equal(expected, result); } [Theory] [MemberData(nameof(TestData))] - public void Test_Read_ZeroPadding(uint expected, byte[] input) + public void Test_TryRead_ZeroPadding(uint expected, byte[] input) { var bytes = new byte[16]; input.AsSpan().CopyTo(bytes); - var length = VarIntEncoding.Read(bytes, out var result); - Assert.Equal(input.Length, length); + var status = VarIntEncoding.TryRead(bytes, out var result, out var bytesRead); + Assert.Equal(OperationStatus.Done, status); + Assert.Equal(input.Length, bytesRead); Assert.Equal(expected, result); } [Theory] [MemberData(nameof(TestData))] - public void Test_Read_OnePadding(uint expected, byte[] input) + public void Test_TryRead_OnePadding(uint expected, byte[] input) { var bytes = new byte[16]; bytes.AsSpan().Fill(0xff); input.AsSpan().CopyTo(bytes); - var length = VarIntEncoding.Read(bytes, out var result); - Assert.Equal(input.Length, length); + var status = VarIntEncoding.TryRead(bytes, out var result, out int bytesRead); + Assert.Equal(OperationStatus.Done, status); + Assert.Equal(input.Length, bytesRead); Assert.Equal(expected, result); } + + [Theory] + [MemberData(nameof(IncompleteTestData))] + public void Test_TryRead_Incomplete(byte[] input) + { + var status = VarIntEncoding.TryRead(input, out _, out var bytesRead); + Assert.Equal(OperationStatus.NeedMoreData, status); + Assert.Equal(0, bytesRead); + } + + [Fact] + public void Test_TryRead_BadData() + { + var status = VarIntEncoding.TryRead([0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF], out _, out var bytesRead); + Assert.Equal(OperationStatus.InvalidData, status); + Assert.Equal(0, bytesRead); + } } } diff --git a/Snappier/Internal/VarIntEncoding.Read.cs b/Snappier/Internal/VarIntEncoding.Read.cs index 445309d..daf9984 100644 --- a/Snappier/Internal/VarIntEncoding.Read.cs +++ b/Snappier/Internal/VarIntEncoding.Read.cs @@ -1,4 +1,5 @@ using System; +using System.Buffers; #if NET7_0_OR_GREATER using System.Buffers.Binary; @@ -13,34 +14,44 @@ namespace Snappier.Internal { internal static partial class VarIntEncoding { - public static int Read(ReadOnlySpan input, out uint result) + public static uint Read(ReadOnlySpan input, out int bytesRead) + { + if (TryRead(input, out var result, out bytesRead) != OperationStatus.Done) + { + ThrowHelper.ThrowInvalidDataException("Invalid stream length"); + } + + return result; + } + + public static OperationStatus TryRead(ReadOnlySpan input, out uint result, out int bytesRead) { #if NET7_0_OR_GREATER if (Sse2.IsSupported && Bmi2.IsSupported && BitConverter.IsLittleEndian && input.Length >= Vector128.Count) { - return ReadFast(input, out result); + return ReadFast(input, out result, out bytesRead); } #endif - return ReadSlow(input, out result); + return TryReadSlow(input, out result, out bytesRead); } - public static int ReadSlow(ReadOnlySpan input, out uint result) + private static OperationStatus TryReadSlow(ReadOnlySpan input, out uint result, out int bytesRead) { result = 0; int shift = 0; bool foundEnd = false; - int i = 0; - while (input.Length > 0) + bytesRead = 0; + while (input.Length > bytesRead) { - byte c = input[i]; - i += 1; + byte c = input[bytesRead++]; int val = c & 0x7f; if (Helpers.LeftShiftOverflows((byte) val, shift)) { - ThrowHelper.ThrowInvalidDataException("Invalid stream length"); + bytesRead = 0; + return OperationStatus.InvalidData; } result |= (uint)(val << shift); @@ -54,16 +65,18 @@ public static int ReadSlow(ReadOnlySpan input, out uint result) if (shift >= 32) { - ThrowHelper.ThrowInvalidDataException("Invalid stream length"); + bytesRead = 0; + return OperationStatus.InvalidData; } } if (!foundEnd) { - ThrowHelper.ThrowInvalidDataException("Invalid stream length"); + bytesRead = 0; + return OperationStatus.NeedMoreData; } - return shift / 7; + return OperationStatus.Done; } #if NET7_0_OR_GREATER @@ -78,7 +91,7 @@ public static int ReadSlow(ReadOnlySpan input, out uint result) 0xffffffff ]; - private static int ReadFast(ReadOnlySpan input, out uint result) + private static OperationStatus ReadFast(ReadOnlySpan input, out uint result, out int bytesRead) { Debug.Assert(Sse2.IsSupported); Debug.Assert(Bmi2.IsSupported); @@ -86,7 +99,7 @@ private static int ReadFast(ReadOnlySpan input, out uint result) Debug.Assert(BitConverter.IsLittleEndian); var mask = ~Sse2.MoveMask(Vector128.LoadUnsafe(ref MemoryMarshal.GetReference(input))); - var bytesRead = BitOperations.TrailingZeroCount(mask) + 1; + bytesRead = BitOperations.TrailingZeroCount(mask) + 1; uint shuffledBits = Bmi2.X64.IsSupported ? unchecked((uint)Bmi2.X64.ParallelBitExtract(BinaryPrimitives.ReadUInt64LittleEndian(input), 0x7F7F7F7F7Fu)) @@ -101,14 +114,14 @@ private static int ReadFast(ReadOnlySpan input, out uint result) { // Currently, JIT doesn't optimize the bounds check away in the branch above, // but we'll leave it written this way in case JIT improves in the future to avoid - // checking the bounds twice. We could just let it throw an IndexOutOfRangeException, - // but that would be inconsistent with the other code paths. + // checking the bounds twice. - ThrowHelper.ThrowInvalidDataException("Invalid stream length"); result = 0; + bytesRead = 0; + return OperationStatus.InvalidData; } - return bytesRead; + return OperationStatus.Done; } #endif