From 26d8a12f062bb5dbe74b5694e7b9b165097bfb58 Mon Sep 17 00:00:00 2001 From: Jocelyn <41338290+jaschrep-msft@users.noreply.github.com> Date: Wed, 6 Mar 2024 13:13:14 -0500 Subject: [PATCH] Structured Message Decode: Validate Content Length (#42370) * validate stream length * tests --- .../Shared/StructuredMessageDecodingStream.cs | 17 +++++++- .../StructuredMessageDecodingStreamTests.cs | 40 +++++++++++++++++++ 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs index 1e2af40a594f2..aa94b8df350d2 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs @@ -89,9 +89,12 @@ public override long Position } public StructuredMessageDecodingStream( - Stream innerStream) + Stream innerStream, + long? expectedStreamLength = default) { Argument.AssertNotNull(innerStream, nameof(innerStream)); + + _innerStreamLength = expectedStreamLength ?? -1; _innerBufferedStream = new BufferedStream(innerStream); // Assumes stream will be structured message 1.0. Will validate this when consuming stream. @@ -366,9 +369,19 @@ private int ProcessStreamHeader(ReadOnlySpan span) { StructuredMessage.V1_0.ReadStreamHeader( span.Slice(0, _streamHeaderLength), - out _innerStreamLength, + out long streamLength, out _flags, out _totalSegments); + + if (_innerStreamLength > 0 && streamLength != _innerStreamLength) + { + throw Errors.InvalidStructuredMessage("Unexpected message size."); + } + else + { + _innerStreamLength = streamLength; + } + if (_flags.HasFlag(StructuredMessage.Flags.StorageCrc64)) { _segmentFooterLength = _flags.HasFlag(StructuredMessage.Flags.StorageCrc64) ? StructuredMessage.Crc64Length : 0; diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs index b115253009705..f881a70c8e78f 100644 --- a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs @@ -234,6 +234,46 @@ public void BadStreamWrongSegmentNum() Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); } + [Test] + [Combinatorial] + public async Task BadStreamWrongContentLength( + [Values(-1, 1)] int difference, + [Values(true, false)] bool lengthProvided) + { + byte[] originalData = new byte[1024]; + new Random().NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, 256, Flags.StorageCrc64); + + BinaryPrimitives.WriteInt64LittleEndian( + new Span(encodedData, V1_0.StreamHeaderMessageLengthOffset, 8), + encodedData.Length + difference); + + Stream decodingStream = new StructuredMessageDecodingStream( + new MemoryStream(encodedData), + lengthProvided ? (long?)encodedData.Length : default); + + // manual try/catch with tiny buffer to validate the proccess failed mid-stream rather than the end + const int copyBufferSize = 4; + bool caught = false; + try + { + await CopyStream(decodingStream, Stream.Null, copyBufferSize); + } + catch (CopyStreamException ex) + { + caught = true; + if (lengthProvided) + { + Assert.That(ex.TotalCopied, Is.EqualTo(0)); + } + else + { + Assert.That(ex.TotalCopied, Is.EqualTo(originalData.Length)); + } + } + Assert.That(caught); + } + [Test] public void BadStreamMissingExpectedStreamFooter() {