diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/Errors.cs b/sdk/storage/Azure.Storage.Common/src/Shared/Errors.cs index 6e50037782e0d..e3372665928c1 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/Errors.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/Errors.cs @@ -72,6 +72,9 @@ public static ArgumentException CannotDeferTransactionalHashVerification() public static ArgumentException CannotInitializeWriteStreamWithData() => new ArgumentException("Initialized buffer for StorageWriteStream must be empty."); + public static InvalidDataException InvalidStructuredMessage(string optionalMessage = default) + => new InvalidDataException(("Invalid structured message data. " + optionalMessage ?? "").Trim()); + internal static void VerifyStreamPosition(Stream stream, string streamName) { if (stream != null && stream.CanSeek && stream.Length > 0 && stream.Position >= stream.Length) diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessage.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessage.cs index d963224536fb3..5e31dd4ac0ed8 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessage.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessage.cs @@ -25,7 +25,14 @@ public static class V1_0 public const byte MessageVersionByte = 1; public const int StreamHeaderLength = 13; + public const int StreamHeaderVersionOffset = 0; + public const int StreamHeaderMessageLengthOffset = 1; + public const int StreamHeaderFlagsOffset = 9; + public const int StreamHeaderSegmentCountOffset = 11; + public const int SegmentHeaderLength = 10; + public const int SegmentHeaderNumOffset = 0; + public const int SegmentHeaderContentLengthOffset = 2; #region Stream Header public static void ReadStreamHeader( @@ -35,13 +42,13 @@ public static void ReadStreamHeader( out int totalSegments) { Errors.AssertBufferExactSize(buffer, 13, nameof(buffer)); - if (buffer[0] != 1) + if (buffer[StreamHeaderVersionOffset] != 1) { throw new InvalidDataException("Unrecognized version of structured message."); } - messageLength = (long)BinaryPrimitives.ReadUInt64LittleEndian(buffer.Slice(1, 8)); - flags = (Flags)BinaryPrimitives.ReadUInt16LittleEndian(buffer.Slice(9, 2)); - totalSegments = BinaryPrimitives.ReadUInt16LittleEndian(buffer.Slice(11, 2)); + messageLength = (long)BinaryPrimitives.ReadUInt64LittleEndian(buffer.Slice(StreamHeaderMessageLengthOffset, 8)); + flags = (Flags)BinaryPrimitives.ReadUInt16LittleEndian(buffer.Slice(StreamHeaderFlagsOffset, 2)); + totalSegments = BinaryPrimitives.ReadUInt16LittleEndian(buffer.Slice(StreamHeaderSegmentCountOffset, 2)); } public static int WriteStreamHeader( diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs index ef6e1af60604d..806a968cf4ec7 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs @@ -38,6 +38,15 @@ namespace Azure.Storage.Shared; /// internal class StructuredMessageDecodingStream : Stream { + private enum SMRegion + { + StreamHeader, + StreamFooter, + SegmentHeader, + SegmentFooter, + SegmentContent, + } + private readonly Stream _innerBufferedStream; private byte[] _metadataBuffer = ArrayPool.Shared.Rent(Constants.KB); @@ -48,9 +57,12 @@ internal class StructuredMessageDecodingStream : Stream private int _streamFooterLength; private int _segmentHeaderLength; private int _segmentFooterLength; + private int _totalSegments; + private long _innerStreamLength; private StructuredMessage.Flags _flags; + private bool _processedFooter = false; private bool _disposed; private StorageCrc64HashAlgorithm _totalContentCrc; @@ -76,19 +88,6 @@ public override long Position set => throw new NotSupportedException(); } - #region Position - private long _innerStreamLength; - - private enum SMRegion - { - StreamHeader, - StreamFooter, - SegmentHeader, - SegmentFooter, - SegmentContent, - } - #endregion - public StructuredMessageDecodingStream( Stream innerStream) { @@ -116,9 +115,15 @@ public override int Read(byte[] buf, int offset, int count) do { read = _innerBufferedStream.Read(buf, offset, count); + _innerStreamConsumed += read; decodedRead = Decode(new Span(buf, offset, read)); } while (decodedRead <= 0 && read > 0); + if (read <= 0) + { + AssertDecodeFinished(); + } + return decodedRead; } @@ -129,9 +134,15 @@ public override async Task ReadAsync(byte[] buf, int offset, int count, Can do { read = await _innerBufferedStream.ReadAsync(buf, offset, count, cancellationToken).ConfigureAwait(false); + _innerStreamConsumed += read; decodedRead = Decode(new Span(buf, offset, read)); } while (decodedRead <= 0 && read > 0); + if (read <= 0) + { + AssertDecodeFinished(); + } + return decodedRead; } @@ -143,9 +154,15 @@ public override int Read(Span buf) do { read = _innerBufferedStream.Read(buf); + _innerStreamConsumed += read; decodedRead = Decode(buf.Slice(0, read)); } while (decodedRead <= 0 && read > 0); + if (read <= 0) + { + AssertDecodeFinished(); + } + return decodedRead; } @@ -156,15 +173,31 @@ public override async ValueTask ReadAsync(Memory buf, CancellationTok do { read = await _innerBufferedStream.ReadAsync(buf).ConfigureAwait(false); + _innerStreamConsumed += read; decodedRead = Decode(buf.Slice(0, read).Span); } while (decodedRead <= 0 && read > 0); + if (read <= 0) + { + AssertDecodeFinished(); + } + return decodedRead; } #endif - private SMRegion _currentRegion; - private int _currentSegmentNum; + private void AssertDecodeFinished() + { + if (_streamFooterLength > 0 && !_processedFooter) + { + throw Errors.InvalidStructuredMessage("Missing or incomplete trailer."); + } + _processedFooter = true; + } + + private long _innerStreamConsumed = 0; + private SMRegion _currentRegion = SMRegion.StreamHeader; + private int _currentSegmentNum = 0; private long _currentSegmentContentLength; private long _currentSegmentContentRemaining; private long CurrentRegionLength => _currentRegion switch @@ -363,6 +396,17 @@ private int ProcessStreamFooter(ReadOnlySpan span) } } } + + if (_innerStreamConsumed != _innerStreamLength) + { + throw Errors.InvalidStructuredMessage("Unexpected message size."); + } + if (_currentSegmentNum != _totalSegments) + { + throw Errors.InvalidStructuredMessage("Missing expected message segments."); + } + + _processedFooter = true; return totalProcessed; } @@ -375,7 +419,7 @@ private int ProcessSegmentHeader(ReadOnlySpan span) _currentSegmentContentRemaining = _currentSegmentContentLength; if (newSegNum != _currentSegmentNum + 1) { - throw new InvalidDataException("Unexpected segment number in structured message."); + throw Errors.InvalidStructuredMessage("Unexpected segment number in structured message."); } _currentSegmentNum = newSegNum; _currentRegion = SMRegion.SegmentContent; diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs index 04a69d68169e3..15c36f19d8763 100644 --- a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs @@ -2,6 +2,8 @@ // Licensed under the MIT License. using System; +using System.Buffers.Binary; +using System.Dynamic; using System.IO; using System.Linq; using System.Threading; @@ -40,40 +42,63 @@ public StructuredMessageDecodingStreamTests(ReadMethod method) Method = method; } - private async ValueTask CopyStream(Stream source, Stream destination, int bufferSize = 81920) // number default for CopyTo impl + private class CopyStreamException : Exception + { + public long TotalCopied { get; } + + public CopyStreamException(Exception inner, long totalCopied) + : base($"Failed read after {totalCopied}-many bytes.", inner) + { + TotalCopied = totalCopied; + } + } + private async ValueTask CopyStream(Stream source, Stream destination, int bufferSize = 81920) // number default for CopyTo impl { byte[] buf = new byte[bufferSize]; int read; - switch (Method) + long totalRead = 0; + try { - case ReadMethod.SyncArray: - while ((read = source.Read(buf, 0, bufferSize)) > 0) - { - destination.Write(buf, 0, read); - } - break; - case ReadMethod.AsyncArray: - while ((read = await source.ReadAsync(buf, 0, bufferSize)) > 0) - { - await destination.WriteAsync(buf, 0, read); - } - break; + switch (Method) + { + case ReadMethod.SyncArray: + while ((read = source.Read(buf, 0, bufferSize)) > 0) + { + totalRead += read; + destination.Write(buf, 0, read); + } + break; + case ReadMethod.AsyncArray: + while ((read = await source.ReadAsync(buf, 0, bufferSize)) > 0) + { + totalRead += read; + await destination.WriteAsync(buf, 0, read); + } + break; #if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER - case ReadMethod.SyncSpan: - while ((read = source.Read(new Span(buf))) > 0) - { - destination.Write(new Span(buf, 0, read)); - } - break; - case ReadMethod.AsyncMemory: - while ((read = await source.ReadAsync(new Memory(buf))) > 0) - { - await destination.WriteAsync(new Memory(buf, 0, read)); - } - break; + case ReadMethod.SyncSpan: + while ((read = source.Read(new Span(buf))) > 0) + { + totalRead += read; + destination.Write(new Span(buf, 0, read)); + } + break; + case ReadMethod.AsyncMemory: + while ((read = await source.ReadAsync(new Memory(buf))) > 0) + { + totalRead += read; + await destination.WriteAsync(new Memory(buf, 0, read)); + } + break; #endif + } + destination.Flush(); } - destination.Flush(); + catch (Exception ex) + { + throw new CopyStreamException(ex, totalRead); + } + return totalRead; } [Test] @@ -102,6 +127,120 @@ public async Task DecodesData( Assert.That(new Span(decodedData).SequenceEqual(originalData)); } + [Test] + public void BadStreamBadVersion() + { + byte[] originalData = new byte[1024]; + new Random().NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, 256, Flags.StorageCrc64); + + encodedData[0] = byte.MaxValue; + + Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData)); + Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); + } + + [Test] + public async Task BadSegmentCrcThrows() + { + const int segmentLength = 256; + Random r = new(); + + byte[] originalData = new byte[2048]; + r.NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, segmentLength, Flags.StorageCrc64); + + const int badBytePos = 1024; + encodedData[badBytePos] = (byte)~encodedData[badBytePos]; + + MemoryStream encodedDataStream = new(encodedData); + Stream decodingStream = new StructuredMessageDecodingStream(encodedDataStream); + + // manual try/catch to validate the proccess failed mid-stream rather than the end + const int copyBufferSize = 4; + bool caught = false; + try + { + await CopyStream(decodingStream, Stream.Null, copyBufferSize); + } + catch (CopyStreamException ex) + { + caught = true; + Assert.That(ex.TotalCopied, Is.LessThanOrEqualTo(badBytePos)); + } + Assert.That(caught); + } + + [Test] + public void BadStreamCrcThrows() + { + const int segmentLength = 256; + Random r = new(); + + byte[] originalData = new byte[2048]; + r.NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, segmentLength, Flags.StorageCrc64); + + encodedData[originalData.Length - 1] = (byte)~encodedData[originalData.Length - 1]; + + Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData)); + Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); + } + + [Test] + public void BadStreamWrongContentLength() + { + byte[] originalData = new byte[1024]; + new Random().NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, 256, Flags.StorageCrc64); + + BinaryPrimitives.WriteInt64LittleEndian(new Span(encodedData, V1_0.StreamHeaderMessageLengthOffset, 8), 123456789L); + + Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData)); + Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); + } + + [Test] + public void BadStreamWrongSegmentCount() + { + byte[] originalData = new byte[1024]; + new Random().NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, 256, Flags.StorageCrc64); + + BinaryPrimitives.WriteInt16LittleEndian(new Span(encodedData, V1_0.StreamHeaderSegmentCountOffset, 2), 123); + + Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData)); + Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); + } + + [Test] + public void BadStreamWrongSegmentNum() + { + byte[] originalData = new byte[1024]; + new Random().NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, 256, Flags.StorageCrc64); + + BinaryPrimitives.WriteInt16LittleEndian( + new Span(encodedData, V1_0.StreamHeaderLength + V1_0.SegmentHeaderNumOffset, 2), 123); + + Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData)); + Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); + } + + [Test] + public void BadStreamMissingExpectedStreamFooter() + { + byte[] originalData = new byte[1024]; + new Random().NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, 256, Flags.StorageCrc64); + + byte[] brokenData = new byte[encodedData.Length - Crc64Length]; + new Span(encodedData, 0, encodedData.Length - Crc64Length).CopyTo(brokenData); + + Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(brokenData)); + Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); + } + [Test] public void NoSeek() { diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageStreamRoundtripTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageStreamRoundtripTests.cs new file mode 100644 index 0000000000000..633233db2e73c --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageStreamRoundtripTests.cs @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using Azure.Storage.Shared; +using NUnit.Framework; +using static Azure.Storage.Shared.StructuredMessage; + +namespace Azure.Storage.Tests +{ + [TestFixture(ReadMethod.SyncArray)] + [TestFixture(ReadMethod.AsyncArray)] +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + [TestFixture(ReadMethod.SyncSpan)] + [TestFixture(ReadMethod.AsyncMemory)] +#endif + public class StructuredMessageStreamRoundtripTests + { + // Cannot just implement as passthru in the stream + // Must test each one + public enum ReadMethod + { + SyncArray, + AsyncArray, +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + SyncSpan, + AsyncMemory +#endif + } + + public ReadMethod Method { get; } + + public StructuredMessageStreamRoundtripTests(ReadMethod method) + { + Method = method; + } + + private class CopyStreamException : Exception + { + public long TotalCopied { get; } + + public CopyStreamException(Exception inner, long totalCopied) + : base($"Failed read after {totalCopied}-many bytes.", inner) + { + TotalCopied = totalCopied; + } + } + private async ValueTask CopyStream(Stream source, Stream destination, int bufferSize = 81920) // number default for CopyTo impl + { + byte[] buf = new byte[bufferSize]; + int read; + long totalRead = 0; + try + { + switch (Method) + { + case ReadMethod.SyncArray: + while ((read = source.Read(buf, 0, bufferSize)) > 0) + { + totalRead += read; + destination.Write(buf, 0, read); + } + break; + case ReadMethod.AsyncArray: + while ((read = await source.ReadAsync(buf, 0, bufferSize)) > 0) + { + totalRead += read; + await destination.WriteAsync(buf, 0, read); + } + break; +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + case ReadMethod.SyncSpan: + while ((read = source.Read(new Span(buf))) > 0) + { + totalRead += read; + destination.Write(new Span(buf, 0, read)); + } + break; + case ReadMethod.AsyncMemory: + while ((read = await source.ReadAsync(new Memory(buf))) > 0) + { + totalRead += read; + await destination.WriteAsync(new Memory(buf, 0, read)); + } + break; +#endif + } + destination.Flush(); + } + catch (Exception ex) + { + throw new CopyStreamException(ex, totalRead); + } + return totalRead; + } + + [Test] + [Pairwise] + public async Task RoundTrip( + [Values(2048, 2005)] int dataLength, + [Values(default, 512)] int? seglen, + [Values(8 * Constants.KB, 512, 530, 3)] int readLen, + [Values(true, false)] bool useCrc) + { + int segmentLength = seglen ?? int.MaxValue; + Flags flags = useCrc ? Flags.StorageCrc64 : Flags.None; + + byte[] originalData = new byte[dataLength]; + new Random().NextBytes(originalData); + + byte[] roundtripData; + using (MemoryStream source = new(originalData)) + using (StructuredMessageEncodingStream encode = new(source, segmentLength, flags)) + using (StructuredMessageDecodingStream decode = new(encode)) + using (MemoryStream dest = new()) + { + await CopyStream(source, dest, readLen); + roundtripData = dest.ToArray(); + } + + Assert.That(originalData.SequenceEqual(roundtripData)); + } + } +}