diff --git a/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs b/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs index 45e3f75e2a49c..7299b5e07b7f6 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs @@ -1552,7 +1552,7 @@ ValueTask> Factory(long offset, bool force forceStructuredMessage, async, cancellationToken); - async ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData)> StructuredMessageFactory( + async ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.RawDecodedData DecodedData)> StructuredMessageFactory( long offset, bool async, CancellationToken cancellationToken) { Response result = await Factory(offset, forceStructuredMessage: true, async, cancellationToken).ConfigureAwait(false); @@ -1561,11 +1561,12 @@ ValueTask> Factory(long offset, bool force Stream stream; if (response.GetRawResponse().Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)) { - (Stream decodingStream, StructuredMessageDecodingStream.DecodedData decodedData) = StructuredMessageDecodingStream.WrapStream( + (Stream decodingStream, StructuredMessageDecodingStream.RawDecodedData decodedData) = StructuredMessageDecodingStream.WrapStream( response.Value.Content, response.Value.Details.ContentLength); stream = new StructuredMessageDecodingRetriableStream( decodingStream, decodedData, + StructuredMessage.Flags.StorageCrc64, startOffset => StructuredMessageFactory(startOffset, async: false, cancellationToken) .EnsureCompleted(), async startOffset => await StructuredMessageFactory(startOffset, async: true, cancellationToken) diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StorageCrc64Composer.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StorageCrc64Composer.cs index ab6b76d78a87e..307ff23b21144 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StorageCrc64Composer.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StorageCrc64Composer.cs @@ -12,22 +12,52 @@ namespace Azure.Storage /// internal static class StorageCrc64Composer { - public static Memory Compose(params (byte[] Crc64, long OriginalDataLength)[] partitions) + public static byte[] Compose(params (byte[] Crc64, long OriginalDataLength)[] partitions) + => Compose(partitions.AsEnumerable()); + + public static byte[] Compose(IEnumerable<(byte[] Crc64, long OriginalDataLength)> partitions) { - return Compose(partitions.AsEnumerable()); + ulong result = Compose(partitions.Select(tup => (BitConverter.ToUInt64(tup.Crc64, 0), tup.OriginalDataLength))); + return BitConverter.GetBytes(result); } - public static Memory Compose(IEnumerable<(byte[] Crc64, long OriginalDataLength)> partitions) + public static byte[] Compose(params (ReadOnlyMemory Crc64, long OriginalDataLength)[] partitions) + => Compose(partitions.AsEnumerable()); + + public static byte[] Compose(IEnumerable<(ReadOnlyMemory Crc64, long OriginalDataLength)> partitions) { - ulong result = Compose(partitions.Select(tup => (BitConverter.ToUInt64(tup.Crc64, 0), tup.OriginalDataLength))); - return new Memory(BitConverter.GetBytes(result)); +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + ulong result = Compose(partitions.Select(tup => (BitConverter.ToUInt64(tup.Crc64.Span), tup.OriginalDataLength))); +#else + ulong result = Compose(partitions.Select(tup => (System.BitConverter.ToUInt64(tup.Crc64.ToArray(), 0), tup.OriginalDataLength))); +#endif + return BitConverter.GetBytes(result); } + public static byte[] Compose( + ReadOnlySpan leftCrc64, long leftOriginalDataLength, + ReadOnlySpan rightCrc64, long rightOriginalDataLength) + { +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + ulong result = Compose( + (BitConverter.ToUInt64(leftCrc64), leftOriginalDataLength), + (BitConverter.ToUInt64(rightCrc64), rightOriginalDataLength)); +#else + ulong result = Compose( + (BitConverter.ToUInt64(leftCrc64.ToArray(), 0), leftOriginalDataLength), + (BitConverter.ToUInt64(rightCrc64.ToArray(), 0), rightOriginalDataLength)); +#endif + return BitConverter.GetBytes(result); + } + + public static ulong Compose(params (ulong Crc64, long OriginalDataLength)[] partitions) + => Compose(partitions.AsEnumerable()); + public static ulong Compose(IEnumerable<(ulong Crc64, long OriginalDataLength)> partitions) { ulong composedCrc = 0; long composedDataLength = 0; - foreach (var tup in partitions) + foreach ((ulong crc64, long originalDataLength) in partitions) { composedCrc = StorageCrc64Calculator.Concatenate( uInitialCrcAB: 0, @@ -35,9 +65,9 @@ public static ulong Compose(IEnumerable<(ulong Crc64, long OriginalDataLength)> uFinalCrcA: composedCrc, uSizeA: (ulong) composedDataLength, uInitialCrcB: 0, - uFinalCrcB: tup.Crc64, - uSizeB: (ulong)tup.OriginalDataLength); - composedDataLength += tup.OriginalDataLength; + uFinalCrcB: crc64, + uSizeB: (ulong)originalDataLength); + composedDataLength += originalDataLength; } return composedCrc; } diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessage.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessage.cs index 5e31dd4ac0ed8..89d3b0df05bfc 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessage.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessage.cs @@ -93,22 +93,18 @@ public static IDisposable GetStreamHeaderBytes( #endregion #region StreamFooter + public static int GetStreamFooterSize(Flags flags) + => flags.HasFlag(Flags.StorageCrc64) ? Crc64Length : 0; + public static void ReadStreamFooter( ReadOnlySpan buffer, - Span crc64 = default) + Flags flags, + out ulong crc64) { - int expectedBufferSize = 0; - if (!crc64.IsEmpty) - { - Errors.AssertBufferExactSize(crc64, Crc64Length, nameof(crc64)); - expectedBufferSize += Crc64Length; - } + int expectedBufferSize = GetSegmentFooterSize(flags); Errors.AssertBufferExactSize(buffer, expectedBufferSize, nameof(buffer)); - if (!crc64.IsEmpty) - { - buffer.Slice(0, Crc64Length).CopyTo(crc64); - } + crc64 = flags.HasFlag(Flags.StorageCrc64) ? BinaryPrimitives.ReadUInt64LittleEndian(buffer) : default; } public static int WriteStreamFooter(Span buffer, ReadOnlySpan crc64 = default) @@ -193,22 +189,18 @@ public static IDisposable GetSegmentHeaderBytes( #endregion #region SegmentFooter + public static int GetSegmentFooterSize(Flags flags) + => flags.HasFlag(Flags.StorageCrc64) ? Crc64Length : 0; + public static void ReadSegmentFooter( ReadOnlySpan buffer, - Span crc64 = default) + Flags flags, + out ulong crc64) { - int expectedBufferSize = 0; - if (!crc64.IsEmpty) - { - Errors.AssertBufferExactSize(crc64, Crc64Length, nameof(crc64)); - expectedBufferSize += Crc64Length; - } + int expectedBufferSize = GetSegmentFooterSize(flags); Errors.AssertBufferExactSize(buffer, expectedBufferSize, nameof(buffer)); - if (!crc64.IsEmpty) - { - buffer.Slice(0, Crc64Length).CopyTo(crc64); - } + crc64 = flags.HasFlag(Flags.StorageCrc64) ? BinaryPrimitives.ReadUInt64LittleEndian(buffer) : default; } public static int WriteSegmentFooter(Span buffer, ReadOnlySpan crc64 = default) diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs index fe2d6697a4621..22dfaef259972 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs @@ -3,6 +3,7 @@ using System; using System.Buffers; +using System.Buffers.Binary; using System.Collections.Generic; using System.IO; using System.Linq; @@ -15,21 +16,30 @@ namespace Azure.Storage.Shared; internal class StructuredMessageDecodingRetriableStream : Stream { + public class DecodedData + { + public ulong Crc { get; set; } + } + private readonly Stream _innerRetriable; private long _decodedBytesRead; - private readonly List _decodedDatas; - private readonly Action _onComplete; + private readonly StructuredMessage.Flags _expectedFlags; + private readonly List _decodedDatas; + private readonly Action _onComplete; + + private StorageCrc64HashAlgorithm _totalContentCrc; - private readonly Func _decodingStreamFactory; - private readonly Func> _decodingAsyncStreamFactory; + private readonly Func _decodingStreamFactory; + private readonly Func> _decodingAsyncStreamFactory; public StructuredMessageDecodingRetriableStream( Stream initialDecodingStream, - StructuredMessageDecodingStream.DecodedData initialDecodedData, - Func decodingStreamFactory, - Func> decodingAsyncStreamFactory, - Action onComplete, + StructuredMessageDecodingStream.RawDecodedData initialDecodedData, + StructuredMessage.Flags expectedFlags, + Func decodingStreamFactory, + Func> decodingAsyncStreamFactory, + Action onComplete, ResponseClassifier responseClassifier, int maxRetries) { @@ -37,13 +47,19 @@ public StructuredMessageDecodingRetriableStream( _decodingAsyncStreamFactory = decodingAsyncStreamFactory; _innerRetriable = RetriableStream.Create(initialDecodingStream, StreamFactory, StreamFactoryAsync, responseClassifier, maxRetries); _decodedDatas = new() { initialDecodedData }; + _expectedFlags = expectedFlags; _onComplete = onComplete; + + if (expectedFlags.HasFlag(StructuredMessage.Flags.StorageCrc64)) + { + _totalContentCrc = StorageCrc64HashAlgorithm.Create(); + } } private Stream StreamFactory(long _) { - long offset = _decodedDatas.Select(d => d.SegmentCrcs?.LastOrDefault().SegmentEnd ?? 0).Sum(); - (Stream decodingStream, StructuredMessageDecodingStream.DecodedData decodedData) = _decodingStreamFactory(offset); + long offset = _decodedDatas.SelectMany(d => d.SegmentCrcs).Select(s => s.SegmentLen).Sum(); + (Stream decodingStream, StructuredMessageDecodingStream.RawDecodedData decodedData) = _decodingStreamFactory(offset); _decodedDatas.Add(decodedData); FastForwardInternal(decodingStream, _decodedBytesRead - offset, false).EnsureCompleted(); return decodingStream; @@ -51,8 +67,8 @@ private Stream StreamFactory(long _) private async ValueTask StreamFactoryAsync(long _) { - long offset = _decodedDatas.Select(d => d.SegmentCrcs?.LastOrDefault().SegmentEnd ?? 0).Sum(); - (Stream decodingStream, StructuredMessageDecodingStream.DecodedData decodedData) = await _decodingAsyncStreamFactory(offset).ConfigureAwait(false); + long offset = _decodedDatas.SelectMany(d => d.SegmentCrcs).Select(s => s.SegmentLen).Sum(); + (Stream decodingStream, StructuredMessageDecodingStream.RawDecodedData decodedData) = await _decodingAsyncStreamFactory(offset).ConfigureAwait(false); _decodedDatas.Add(decodedData); await FastForwardInternal(decodingStream, _decodedBytesRead - offset, true).ConfigureAwait(false); return decodingStream; @@ -81,21 +97,41 @@ private static async ValueTask FastForwardInternal(Stream stream, long bytes, bo protected override void Dispose(bool disposing) { - foreach (IDisposable data in _decodedDatas) - { - data.Dispose(); - } _decodedDatas.Clear(); _innerRetriable.Dispose(); } private void OnCompleted() { - StructuredMessageDecodingStream.DecodedData final = new(); - // TODO + DecodedData final = new(); + if (_totalContentCrc != null) + { + final.Crc = ValidateCrc(); + } _onComplete?.Invoke(final); } + private ulong ValidateCrc() + { + using IDisposable _ = ArrayPool.Shared.RentDisposable(StructuredMessage.Crc64Length * 2, out byte[] buf); + Span calculatedBytes = new(buf, 0, StructuredMessage.Crc64Length); + _totalContentCrc.GetCurrentHash(calculatedBytes); + ulong calculated = BinaryPrimitives.ReadUInt64LittleEndian(calculatedBytes); + + ulong reported = _decodedDatas.Count == 1 + ? _decodedDatas.First().TotalCrc.Value + : StorageCrc64Composer.Compose(_decodedDatas.SelectMany(d => d.SegmentCrcs)); + + if (calculated != reported) + { + Span reportedBytes = new(buf, calculatedBytes.Length, StructuredMessage.Crc64Length); + BinaryPrimitives.WriteUInt64LittleEndian(reportedBytes, reported); + throw Errors.ChecksumMismatch(calculatedBytes, reportedBytes); + } + + return calculated; + } + #region Read public override int Read(byte[] buffer, int offset, int count) { @@ -105,6 +141,10 @@ public override int Read(byte[] buffer, int offset, int count) { OnCompleted(); } + else + { + _totalContentCrc?.Append(new ReadOnlySpan(buffer, offset, read)); + } return read; } @@ -116,6 +156,10 @@ public override async Task ReadAsync(byte[] buffer, int offset, int count, { OnCompleted(); } + else + { + _totalContentCrc?.Append(new ReadOnlySpan(buffer, offset, read)); + } return read; } @@ -128,6 +172,10 @@ public override int Read(Span buffer) { OnCompleted(); } + else + { + _totalContentCrc?.Append(buffer.Slice(0, read)); + } return read; } @@ -139,6 +187,10 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation { OnCompleted(); } + else + { + _totalContentCrc?.Append(buffer.Span.Slice(0, read)); + } return read; } #endif diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs index 37b15a2245750..439fcab0e80b8 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs @@ -3,6 +3,7 @@ using System; using System.Buffers; +using System.Buffers.Binary; using System.Collections.Generic; using System.IO; using System.Linq; @@ -38,55 +39,14 @@ namespace Azure.Storage.Shared; /// internal class StructuredMessageDecodingStream : Stream { - internal class DecodedData : IDisposable + internal class RawDecodedData { - private byte[] _crcBackingArray; - - public long? InnerStreamLength { get; private set; } - public int? TotalSegments { get; private set; } - public StructuredMessage.Flags? Flags { get; private set; } - public List<(ReadOnlyMemory SegmentCrc, long SegmentEnd)> SegmentCrcs { get; private set; } - public ReadOnlyMemory TotalCrc { get; private set; } - public bool DecodeCompleted { get; private set; } - - internal void SetStreamHeaderData(int totalSegments, long innerStreamLength, StructuredMessage.Flags flags) - { - TotalSegments = totalSegments; - InnerStreamLength = innerStreamLength; - Flags = flags; - - if (flags.HasFlag(StructuredMessage.Flags.StorageCrc64)) - { - _crcBackingArray = ArrayPool.Shared.Rent((totalSegments + 1) * StructuredMessage.Crc64Length); - SegmentCrcs = new(); - } - } - - internal void ReportSegmentCrc(ReadOnlySpan crc, int segmentNum, long segmentEnd) - { - int offset = (segmentNum - 1) * StructuredMessage.Crc64Length; - crc.CopyTo(new Span(_crcBackingArray, offset, StructuredMessage.Crc64Length)); - SegmentCrcs.Add((new ReadOnlyMemory(_crcBackingArray, offset, StructuredMessage.Crc64Length), segmentEnd)); - } - - internal void ReportTotalCrc(ReadOnlySpan crc) - { - int offset = (TotalSegments.Value) * StructuredMessage.Crc64Length; - crc.CopyTo(new Span(_crcBackingArray, offset, StructuredMessage.Crc64Length)); - TotalCrc = new ReadOnlyMemory(_crcBackingArray, offset, StructuredMessage.Crc64Length); - } - internal void MarkComplete() - { - DecodeCompleted = true; - } - - public void Dispose() - { - if (_crcBackingArray is not null) - { - ArrayPool.Shared.Return(_crcBackingArray); - } - } + public long? InnerStreamLength { get; set; } + public int? TotalSegments { get; set; } + public StructuredMessage.Flags? Flags { get; set; } + public List<(ulong SegmentCrc, long SegmentLen)> SegmentCrcs { get; } = new(); + public ulong? TotalCrc { get; set; } + public bool DecodeCompleted { get; set; } } private enum SMRegion @@ -113,7 +73,7 @@ private enum SMRegion private bool _disposed; - private readonly DecodedData _decodedData; + private readonly RawDecodedData _decodedData; private StorageCrc64HashAlgorithm _totalContentCrc; private StorageCrc64HashAlgorithm _segmentCrc; @@ -139,17 +99,17 @@ public override long Position set => throw new NotSupportedException(); } - public static (Stream DecodedStream, DecodedData DecodedData) WrapStream( + public static (Stream DecodedStream, RawDecodedData DecodedData) WrapStream( Stream innerStream, long? expextedStreamLength = default) { - DecodedData data = new(); + RawDecodedData data = new(); return (new StructuredMessageDecodingStream(innerStream, data, expextedStreamLength), data); } private StructuredMessageDecodingStream( Stream innerStream, - DecodedData decodedData, + RawDecodedData decodedData, long? expectedStreamLength) { Argument.AssertNotNull(innerStream, nameof(innerStream)); @@ -259,7 +219,7 @@ private void AssertDecodeFinished() { throw Errors.InvalidStructuredMessage("Premature end of stream."); } - _decodedData.MarkComplete(); + _decodedData.DecodeCompleted = true; } private long _innerStreamConsumed = 0; @@ -439,7 +399,9 @@ private int ProcessStreamHeader(ReadOnlySpan span) out StructuredMessage.Flags flags, out int totalSegments); - _decodedData.SetStreamHeaderData(totalSegments, streamLength, flags); + _decodedData.InnerStreamLength = streamLength; + _decodedData.Flags = flags; + _decodedData.TotalSegments = totalSegments; if (_expectedInnerStreamLength.HasValue && _expectedInnerStreamLength.Value != streamLength) { @@ -462,20 +424,24 @@ private int ProcessStreamHeader(ReadOnlySpan span) private int ProcessStreamFooter(ReadOnlySpan span) { - int totalProcessed = 0; + int footerLen = StructuredMessage.V1_0.GetStreamFooterSize(_decodedData.Flags.Value); + StructuredMessage.V1_0.ReadStreamFooter( + span.Slice(0, footerLen), + _decodedData.Flags.Value, + out ulong reportedCrc); if (_decodedData.Flags.Value.HasFlag(StructuredMessage.Flags.StorageCrc64)) { - totalProcessed += StructuredMessage.Crc64Length; - ReadOnlySpan expected = span.Slice(0, StructuredMessage.Crc64Length); - _decodedData.ReportTotalCrc(expected); + _decodedData.TotalCrc = reportedCrc; if (_validateChecksums) { - using (ArrayPool.Shared.RentAsSpanDisposable(StructuredMessage.Crc64Length, out Span calculated)) + using (ArrayPool.Shared.RentDisposable(StructuredMessage.Crc64Length * 2, out byte[] buf)) { + Span calculated = new(buf, 0, StructuredMessage.Crc64Length); _totalContentCrc.GetCurrentHash(calculated); - if (!calculated.SequenceEqual(expected)) + if (BinaryPrimitives.ReadUInt64LittleEndian(calculated) != reportedCrc) { - throw Errors.ChecksumMismatch(calculated, expected); + Span reportedAsBytes = new(buf, calculated.Length, StructuredMessage.Crc64Length); + throw Errors.ChecksumMismatch(calculated, reportedAsBytes); } } } @@ -490,8 +456,8 @@ private int ProcessStreamFooter(ReadOnlySpan span) throw Errors.InvalidStructuredMessage("Missing expected message segments."); } - _decodedData.MarkComplete(); - return totalProcessed; + _decodedData.DecodeCompleted = true; + return footerLen; } private int ProcessSegmentHeader(ReadOnlySpan span) @@ -512,27 +478,31 @@ private int ProcessSegmentHeader(ReadOnlySpan span) private int ProcessSegmentFooter(ReadOnlySpan span) { - int totalProcessed = 0; + int footerLen = StructuredMessage.V1_0.GetSegmentFooterSize(_decodedData.Flags.Value); + StructuredMessage.V1_0.ReadSegmentFooter( + span.Slice(0, footerLen), + _decodedData.Flags.Value, + out ulong reportedCrc); if (_decodedData.Flags.Value.HasFlag(StructuredMessage.Flags.StorageCrc64)) { - totalProcessed += StructuredMessage.Crc64Length; - ReadOnlySpan expected = span.Slice(0, StructuredMessage.Crc64Length); if (_validateChecksums) { - using (ArrayPool.Shared.RentAsSpanDisposable(StructuredMessage.Crc64Length, out Span calculated)) + using (ArrayPool.Shared.RentDisposable(StructuredMessage.Crc64Length * 2, out byte[] buf)) { + Span calculated = new(buf, 0, StructuredMessage.Crc64Length); _segmentCrc.GetCurrentHash(calculated); _segmentCrc = StorageCrc64HashAlgorithm.Create(); - if (!calculated.SequenceEqual(expected)) + if (BinaryPrimitives.ReadUInt64LittleEndian(calculated) != reportedCrc) { - throw Errors.ChecksumMismatch(calculated, expected); + Span reportedAsBytes = new(buf, calculated.Length, StructuredMessage.Crc64Length); + throw Errors.ChecksumMismatch(calculated, reportedAsBytes); } } } - _decodedData.ReportSegmentCrc(expected, _currentSegmentNum, _decodedContentConsumed); + _decodedData.SegmentCrcs.Add((reportedCrc, _currentSegmentContentLength)); } _currentRegion = _currentSegmentNum == _decodedData.TotalSegments ? SMRegion.StreamFooter : SMRegion.SegmentHeader; - return totalProcessed; + return footerLen; } #endregion diff --git a/sdk/storage/Azure.Storage.Common/tests/Shared/ObserveStructuredMessagePolicy.cs b/sdk/storage/Azure.Storage.Common/tests/Shared/ObserveStructuredMessagePolicy.cs index 1414f4ec80076..828c41179bba3 100644 --- a/sdk/storage/Azure.Storage.Common/tests/Shared/ObserveStructuredMessagePolicy.cs +++ b/sdk/storage/Azure.Storage.Common/tests/Shared/ObserveStructuredMessagePolicy.cs @@ -26,7 +26,7 @@ public override void OnSendingRequest(HttpMessage message) { byte[] encodedContent; byte[] underlyingContent; - StructuredMessageDecodingStream.DecodedData decodedData; + StructuredMessageDecodingStream.RawDecodedData decodedData; using (MemoryStream ms = new()) { message.Request.Content.WriteTo(ms, default); diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs index 39d2a5566b5ff..a0f9158040b11 100644 --- a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs @@ -2,12 +2,14 @@ // Licensed under the MIT License. using System; +using System.Buffers.Binary; using System.IO; using System.Threading; using System.Threading.Tasks; using Azure.Core; using Azure.Storage.Shared; using Azure.Storage.Test.Shared; +using Microsoft.Diagnostics.Tracing.Parsers.AspNet; using Moq; using NUnit.Framework; @@ -39,7 +41,7 @@ public async ValueTask UninterruptedStream() // mock with a simple MemoryStream rather than an actual StructuredMessageDecodingStream using (Stream src = new MemoryStream(data)) - using (Stream retriableSrc = new StructuredMessageDecodingRetriableStream(src, new(), default, default, default, default, 1)) + using (Stream retriableSrc = new StructuredMessageDecodingRetriableStream(src, new(), default, default, default, default, default, 1)) using (Stream dst = new MemoryStream(dest)) { await retriableSrc.CopyToInternal(dst, Async, default); @@ -61,12 +63,16 @@ public async Task Interrupt_DataIntact([Values(true, false)] bool multipleInterr byte[] dest = new byte[data.Length]; // Mock a decoded data for the mocked StructuredMessageDecodingStream - StructuredMessageDecodingStream.DecodedData initialDecodedData = new(); - initialDecodedData.SetStreamHeaderData(segments, data.Length, StructuredMessage.Flags.StorageCrc64); + StructuredMessageDecodingStream.RawDecodedData initialDecodedData = new() + { + TotalSegments = segments, + InnerStreamLength = data.Length, + Flags = StructuredMessage.Flags.StorageCrc64 + }; // for test purposes, initialize a DecodedData, since we are not actively decoding in this test - initialDecodedData.ReportSegmentCrc(r.NextBytesInline(StructuredMessage.Crc64Length), 1, segmentLen); + initialDecodedData.SegmentCrcs.Add((BinaryPrimitives.ReadUInt64LittleEndian(r.NextBytesInline(StructuredMessage.Crc64Length)), segmentLen)); - (Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData) Factory(long offset, bool faulty) + (Stream DecodingStream, StructuredMessageDecodingStream.RawDecodedData DecodedData) Factory(long offset, bool faulty) { Stream stream = new MemoryStream(data, (int)offset, data.Length - (int)offset); if (faulty) @@ -74,10 +80,14 @@ public async Task Interrupt_DataIntact([Values(true, false)] bool multipleInterr stream = new FaultyStream(stream, interruptPos, 1, new Exception(), () => { }); } // Mock a decoded data for the mocked StructuredMessageDecodingStream - StructuredMessageDecodingStream.DecodedData decodedData = new(); - decodedData.SetStreamHeaderData(segments, data.Length, StructuredMessage.Flags.StorageCrc64); + StructuredMessageDecodingStream.RawDecodedData decodedData = new() + { + TotalSegments = segments, + InnerStreamLength = data.Length, + Flags = StructuredMessage.Flags.StorageCrc64, + }; // for test purposes, initialize a DecodedData, since we are not actively decoding in this test - decodedData.ReportSegmentCrc(r.NextBytesInline(StructuredMessage.Crc64Length), 1, segmentLen); + initialDecodedData.SegmentCrcs.Add((BinaryPrimitives.ReadUInt64LittleEndian(r.NextBytesInline(StructuredMessage.Crc64Length)), segmentLen)); return (stream, decodedData); } @@ -87,8 +97,9 @@ public async Task Interrupt_DataIntact([Values(true, false)] bool multipleInterr using (Stream retriableSrc = new StructuredMessageDecodingRetriableStream( faultySrc, initialDecodedData, + default, offset => Factory(offset, multipleInterrupts), - offset => new ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData)>(Factory(offset, multipleInterrupts)), + offset => new ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.RawDecodedData DecodedData)>(Factory(offset, multipleInterrupts)), null, AllExceptionsRetry().Object, int.MaxValue)) @@ -112,10 +123,14 @@ public async Task Interrupt_AppropriateRewind() Random r = new(); // Mock a decoded data for the mocked StructuredMessageDecodingStream - StructuredMessageDecodingStream.DecodedData initialDecodedData = new(); - initialDecodedData.SetStreamHeaderData(segments, segments * segmentLen, StructuredMessage.Flags.StorageCrc64); + StructuredMessageDecodingStream.RawDecodedData initialDecodedData = new() + { + TotalSegments = segments, + InnerStreamLength = segments * segmentLen, + Flags = StructuredMessage.Flags.StorageCrc64, + }; // By the time of interrupt, there will be one segment reported - initialDecodedData.ReportSegmentCrc(r.NextBytesInline(StructuredMessage.Crc64Length), 1, segmentLen); + initialDecodedData.SegmentCrcs.Add((BinaryPrimitives.ReadUInt64LittleEndian(r.NextBytesInline(StructuredMessage.Crc64Length)), segmentLen)); Mock mock = new(MockBehavior.Strict); mock.SetupGet(s => s.CanRead).Returns(true); @@ -158,8 +173,9 @@ public async Task Interrupt_AppropriateRewind() Stream retriableSrc = new StructuredMessageDecodingRetriableStream( faultySrc, initialDecodedData, + default, offset => (mock.Object, new()), - offset => new(Task.FromResult((mock.Object, new StructuredMessageDecodingStream.DecodedData()))), + offset => new(Task.FromResult((mock.Object, new StructuredMessageDecodingStream.RawDecodedData()))), null, AllExceptionsRetry().Object, 1); @@ -200,7 +216,7 @@ public async Task Interrupt_ProperDecode([Values(true, false)] bool multipleInte byte[] data = r.NextBytesInline(segments * Constants.KB).ToArray(); byte[] dest = new byte[data.Length]; - (Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData) Factory(long offset, bool faulty) + (Stream DecodingStream, StructuredMessageDecodingStream.RawDecodedData DecodedData) Factory(long offset, bool faulty) { Stream stream = new MemoryStream(data, (int)offset, data.Length - (int)offset); stream = new StructuredMessageEncodingStream(stream, segmentLen, StructuredMessage.Flags.StorageCrc64); @@ -211,12 +227,13 @@ public async Task Interrupt_ProperDecode([Values(true, false)] bool multipleInte return StructuredMessageDecodingStream.WrapStream(stream); } - (Stream decodingStream, StructuredMessageDecodingStream.DecodedData decodedData) = Factory(0, true); + (Stream decodingStream, StructuredMessageDecodingStream.RawDecodedData decodedData) = Factory(0, true); using Stream retriableSrc = new StructuredMessageDecodingRetriableStream( decodingStream, decodedData, + default, offset => Factory(offset, multipleInterrupts), - offset => new ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData)>(Factory(offset, multipleInterrupts)), + offset => new ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.RawDecodedData DecodedData)>(Factory(offset, multipleInterrupts)), null, AllExceptionsRetry().Object, int.MaxValue); diff --git a/sdk/storage/Azure.Storage.Files.Shares/src/ShareFileClient.cs b/sdk/storage/Azure.Storage.Files.Shares/src/ShareFileClient.cs index 6977bb4c6b374..60d3eebe2ab7b 100644 --- a/sdk/storage/Azure.Storage.Files.Shares/src/ShareFileClient.cs +++ b/sdk/storage/Azure.Storage.Files.Shares/src/ShareFileClient.cs @@ -2287,7 +2287,7 @@ async ValueTask> Factory(long offset, bool async } return response; } - async ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData)> StructuredMessageFactory( + async ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.RawDecodedData DecodedData)> StructuredMessageFactory( long offset, bool async, CancellationToken cancellationToken) { Response result = await Factory(offset, async, cancellationToken).ConfigureAwait(false); @@ -2296,11 +2296,12 @@ async ValueTask> Factory(long offset, bool async if (initialResponse.GetRawResponse().Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)) { - (Stream decodingStream, StructuredMessageDecodingStream.DecodedData decodedData) = StructuredMessageDecodingStream.WrapStream( + (Stream decodingStream, StructuredMessageDecodingStream.RawDecodedData decodedData) = StructuredMessageDecodingStream.WrapStream( initialResponse.Value.Content, initialResponse.Value.ContentLength); initialResponse.Value.Content = new StructuredMessageDecodingRetriableStream( decodingStream, decodedData, + StructuredMessage.Flags.StorageCrc64, startOffset => StructuredMessageFactory(startOffset, async: false, cancellationToken) .EnsureCompleted(), async startOffset => await StructuredMessageFactory(startOffset, async: true, cancellationToken)