From 05b2bdfca06e9dae71005df35fa561a00e6e19d1 Mon Sep 17 00:00:00 2001 From: Jocelyn Schreppler Date: Mon, 4 Mar 2024 12:05:21 -0500 Subject: [PATCH 1/2] validate stream length --- .../Shared/StructuredMessageDecodingStream.cs | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs index 806a968cf4ec7..cfcea4a942e16 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs @@ -89,9 +89,12 @@ public override long Position } public StructuredMessageDecodingStream( - Stream innerStream) + Stream innerStream, + long? expectedStreamLength = default) { Argument.AssertNotNull(innerStream, nameof(innerStream)); + + _innerStreamLength = expectedStreamLength ?? -1; _innerBufferedStream = new BufferedStream(innerStream); // Assumes stream will be structured message 1.0. Will validate this when consuming stream. @@ -366,9 +369,19 @@ private int ProcessStreamHeader(ReadOnlySpan span) { StructuredMessage.V1_0.ReadStreamHeader( span.Slice(0, _streamHeaderLength), - out _innerStreamLength, + out long streamLength, out _flags, out _totalSegments); + + if (_innerStreamLength > 0 && streamLength != _innerStreamLength) + { + throw Errors.InvalidStructuredMessage("Unexpected message size."); + } + else + { + _innerStreamLength = streamLength; + } + if (_flags.HasFlag(StructuredMessage.Flags.StorageCrc64)) { _segmentFooterLength = _flags.HasFlag(StructuredMessage.Flags.StorageCrc64) ? StructuredMessage.Crc64Length : 0; From b817460e196f9ea563b22520d121ecd57dd854c3 Mon Sep 17 00:00:00 2001 From: Jocelyn Schreppler Date: Mon, 4 Mar 2024 12:31:20 -0500 Subject: [PATCH 2/2] tests --- .../StructuredMessageDecodingStreamTests.cs | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs index 15c36f19d8763..daf691114e989 100644 --- a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs @@ -227,6 +227,46 @@ public void BadStreamWrongSegmentNum() Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); } + [Test] + [Combinatorial] + public async Task BadStreamWrongContentLength( + [Values(-1, 1)] int difference, + [Values(true, false)] bool lengthProvided) + { + byte[] originalData = new byte[1024]; + new Random().NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, 256, Flags.StorageCrc64); + + BinaryPrimitives.WriteInt64LittleEndian( + new Span(encodedData, V1_0.StreamHeaderMessageLengthOffset, 8), + encodedData.Length + difference); + + Stream decodingStream = new StructuredMessageDecodingStream( + new MemoryStream(encodedData), + lengthProvided ? (long?)encodedData.Length : default); + + // manual try/catch with tiny buffer to validate the proccess failed mid-stream rather than the end + const int copyBufferSize = 4; + bool caught = false; + try + { + await CopyStream(decodingStream, Stream.Null, copyBufferSize); + } + catch (CopyStreamException ex) + { + caught = true; + if (lengthProvided) + { + Assert.That(ex.TotalCopied, Is.EqualTo(0)); + } + else + { + Assert.That(ex.TotalCopied, Is.EqualTo(originalData.Length)); + } + } + Assert.That(caught); + } + [Test] public void BadStreamMissingExpectedStreamFooter() {