From 99c02de3fdf9942e1b0e65ed99c98a653ae732f4 Mon Sep 17 00:00:00 2001 From: Jocelyn <41338290+jaschrep-msft@users.noreply.github.com> Date: Wed, 25 Sep 2024 17:00:25 -0400 Subject: [PATCH] Structured message cherrypick stg96 (#45626) --- ...e.Storage.Blobs.Batch.Samples.Tests.csproj | 1 + .../Azure.Storage.Blobs.Batch.Tests.csproj | 3 +- ...rage.Blobs.ChangeFeed.Samples.Tests.csproj | 3 +- ...zure.Storage.Blobs.ChangeFeed.Tests.csproj | 3 +- .../api/Azure.Storage.Blobs.net6.0.cs | 7 +- .../api/Azure.Storage.Blobs.netstandard2.0.cs | 7 +- .../api/Azure.Storage.Blobs.netstandard2.1.cs | 7 +- sdk/storage/Azure.Storage.Blobs/assets.json | 2 +- .../Azure.Storage.Blobs.Samples.Tests.csproj | 1 + .../src/AppendBlobClient.cs | 45 +- .../src/Azure.Storage.Blobs.csproj | 7 + .../Azure.Storage.Blobs/src/BlobBaseClient.cs | 110 +++- .../src/BlobClientOptions.cs | 2 + .../src/BlobClientSideDecryptor.cs | 2 +- .../src/BlockBlobClient.cs | 92 ++- .../src/Models/BlobDownloadDetails.cs | 8 + .../src/Models/BlobDownloadInfo.cs | 10 + .../src/Models/BlobDownloadStreamingResult.cs | 8 + .../Azure.Storage.Blobs/src/PageBlobClient.cs | 49 +- .../src/PartitionedDownloader.cs | 95 +-- .../Azure.Storage.Blobs/src/autorest.md | 4 +- .../tests/Azure.Storage.Blobs.Tests.csproj | 3 + .../BlobBaseClientTransferValidationTests.cs | 114 ++-- .../tests/ClientSideEncryptionTests.cs | 2 +- .../tests/PartitionedDownloaderTests.cs | 2 +- .../Azure.Storage.Common.Samples.Tests.csproj | 1 + .../src/Shared/ChecksumExtensions.cs | 22 + .../src/Shared/Constants.cs | 9 + .../src/Shared/ContentRange.cs | 18 +- .../src/Shared/ContentRangeExtensions.cs | 14 + .../src/Shared/Errors.Clients.cs | 10 + .../Azure.Storage.Common/src/Shared/Errors.cs | 19 + .../src/Shared/LazyLoadingReadOnlyStream.cs | 40 +- .../src/Shared/PooledMemoryStream.cs | 2 +- .../src/Shared/StorageCrc64Composer.cs | 48 +- .../StorageRequestValidationPipelinePolicy.cs | 29 + .../src/Shared/StorageVersionExtensions.cs | 2 +- 
.../src/Shared/StreamExtensions.cs | 22 +- .../src/Shared/StructuredMessage.cs | 244 ++++++++ ...tructuredMessageDecodingRetriableStream.cs | 264 +++++++++ .../Shared/StructuredMessageDecodingStream.cs | 542 +++++++++++++++++ .../Shared/StructuredMessageEncodingStream.cs | 545 ++++++++++++++++++ ...redMessagePrecalculatedCrcWrapperStream.cs | 451 +++++++++++++++ .../TransferValidationOptionsExtensions.cs | 7 - .../tests/Azure.Storage.Common.Tests.csproj | 9 + .../tests/Shared/FaultyStream.cs | 13 +- .../Shared/ObserveStructuredMessagePolicy.cs | 85 +++ .../tests/Shared/RequestExtensions.cs | 27 + .../Shared/TamperStreamContentsPolicy.cs | 11 +- .../Shared/TransferValidationTestBase.cs | 325 ++++++++--- ...uredMessageDecodingRetriableStreamTests.cs | 246 ++++++++ .../StructuredMessageDecodingStreamTests.cs | 323 +++++++++++ .../StructuredMessageEncodingStreamTests.cs | 271 +++++++++ .../tests/StructuredMessageHelper.cs | 68 +++ .../StructuredMessageStreamRoundtripTests.cs | 127 ++++ .../tests/StructuredMessageTests.cs | 114 ++++ ...ge.DataMovement.Blobs.Samples.Tests.csproj | 1 + .../Azure.Storage.DataMovement.Blobs.csproj | 1 + .../src/DataMovementBlobsExtensions.cs | 4 +- ...re.Storage.DataMovement.Blobs.Tests.csproj | 5 + ...taMovement.Blobs.Files.Shares.Tests.csproj | 1 + ...Movement.Files.Shares.Samples.Tests.csproj | 3 +- .../src/DataMovementSharesExtensions.cs | 4 +- ...age.DataMovement.Files.Shares.Tests.csproj | 1 + .../tests/Shared/DisposingShare.cs | 2 +- .../src/Azure.Storage.DataMovement.csproj | 2 +- .../Azure.Storage.DataMovement.Tests.csproj | 1 + .../Azure.Storage.Files.DataLake.net6.0.cs | 2 +- ...e.Storage.Files.DataLake.netstandard2.0.cs | 2 +- .../Azure.Storage.Files.DataLake/assets.json | 2 +- ...torage.Files.DataLake.Samples.Tests.csproj | 1 + .../src/Azure.Storage.Files.DataLake.csproj | 5 + .../src/DataLakeFileClient.cs | 43 +- .../src/autorest.md | 4 +- .../Azure.Storage.Files.DataLake.Tests.csproj | 3 + 
...taLakeFileClientTransferValidationTests.cs | 5 +- .../api/Azure.Storage.Files.Shares.net6.0.cs | 3 +- ...ure.Storage.Files.Shares.netstandard2.0.cs | 3 +- .../Azure.Storage.Files.Shares/assets.json | 2 +- ....Storage.Files.Shares.Samples.Tests.csproj | 1 + .../src/Azure.Storage.Files.Shares.csproj | 8 +- .../src/Models/ShareFileDownloadInfo.cs | 6 + .../src/ShareErrors.cs | 15 - .../src/ShareFileClient.cs | 165 ++++-- .../src/autorest.md | 4 +- .../Azure.Storage.Files.Shares.Tests.csproj | 1 + .../ShareFileClientTransferValidationTests.cs | 42 +- .../api/Azure.Storage.Queues.net6.0.cs | 4 +- .../Azure.Storage.Queues.netstandard2.0.cs | 4 +- .../Azure.Storage.Queues.netstandard2.1.cs | 4 +- .../Azure.Storage.Queues.Samples.Tests.csproj | 1 + .../tests/Azure.Storage.Queues.Tests.csproj | 1 + 92 files changed, 4446 insertions(+), 405 deletions(-) create mode 100644 sdk/storage/Azure.Storage.Common/src/Shared/ChecksumExtensions.cs create mode 100644 sdk/storage/Azure.Storage.Common/src/Shared/ContentRangeExtensions.cs create mode 100644 sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessage.cs create mode 100644 sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs create mode 100644 sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs create mode 100644 sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageEncodingStream.cs create mode 100644 sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessagePrecalculatedCrcWrapperStream.cs create mode 100644 sdk/storage/Azure.Storage.Common/tests/Shared/ObserveStructuredMessagePolicy.cs create mode 100644 sdk/storage/Azure.Storage.Common/tests/Shared/RequestExtensions.cs create mode 100644 sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs create mode 100644 sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs create mode 100644 
sdk/storage/Azure.Storage.Common/tests/StructuredMessageEncodingStreamTests.cs create mode 100644 sdk/storage/Azure.Storage.Common/tests/StructuredMessageHelper.cs create mode 100644 sdk/storage/Azure.Storage.Common/tests/StructuredMessageStreamRoundtripTests.cs create mode 100644 sdk/storage/Azure.Storage.Common/tests/StructuredMessageTests.cs diff --git a/sdk/storage/Azure.Storage.Blobs.Batch/samples/Azure.Storage.Blobs.Batch.Samples.Tests.csproj b/sdk/storage/Azure.Storage.Blobs.Batch/samples/Azure.Storage.Blobs.Batch.Samples.Tests.csproj index 3dea34a02b7ea..6009a5336b8b9 100644 --- a/sdk/storage/Azure.Storage.Blobs.Batch/samples/Azure.Storage.Blobs.Batch.Samples.Tests.csproj +++ b/sdk/storage/Azure.Storage.Blobs.Batch/samples/Azure.Storage.Blobs.Batch.Samples.Tests.csproj @@ -17,6 +17,7 @@ + PreserveNewest diff --git a/sdk/storage/Azure.Storage.Blobs.Batch/tests/Azure.Storage.Blobs.Batch.Tests.csproj b/sdk/storage/Azure.Storage.Blobs.Batch/tests/Azure.Storage.Blobs.Batch.Tests.csproj index 2b77907e9aaac..286ab317256bf 100644 --- a/sdk/storage/Azure.Storage.Blobs.Batch/tests/Azure.Storage.Blobs.Batch.Tests.csproj +++ b/sdk/storage/Azure.Storage.Blobs.Batch/tests/Azure.Storage.Blobs.Batch.Tests.csproj @@ -23,6 +23,7 @@ + PreserveNewest @@ -42,4 +43,4 @@ - \ No newline at end of file + diff --git a/sdk/storage/Azure.Storage.Blobs.ChangeFeed/samples/Azure.Storage.Blobs.ChangeFeed.Samples.Tests.csproj b/sdk/storage/Azure.Storage.Blobs.ChangeFeed/samples/Azure.Storage.Blobs.ChangeFeed.Samples.Tests.csproj index 7711cae537db6..6f8fcaf6528b3 100644 --- a/sdk/storage/Azure.Storage.Blobs.ChangeFeed/samples/Azure.Storage.Blobs.ChangeFeed.Samples.Tests.csproj +++ b/sdk/storage/Azure.Storage.Blobs.ChangeFeed/samples/Azure.Storage.Blobs.ChangeFeed.Samples.Tests.csproj @@ -1,4 +1,4 @@ - + $(RequiredTargetFrameworks) Microsoft Azure.Storage.Blobs.ChangeFeed client library samples @@ -14,6 +14,7 @@ + diff --git 
a/sdk/storage/Azure.Storage.Blobs.ChangeFeed/tests/Azure.Storage.Blobs.ChangeFeed.Tests.csproj b/sdk/storage/Azure.Storage.Blobs.ChangeFeed/tests/Azure.Storage.Blobs.ChangeFeed.Tests.csproj index 9682ab15ecd60..8cf13cd60744f 100644 --- a/sdk/storage/Azure.Storage.Blobs.ChangeFeed/tests/Azure.Storage.Blobs.ChangeFeed.Tests.csproj +++ b/sdk/storage/Azure.Storage.Blobs.ChangeFeed/tests/Azure.Storage.Blobs.ChangeFeed.Tests.csproj @@ -17,6 +17,7 @@ + @@ -28,4 +29,4 @@ PreserveNewest - \ No newline at end of file + diff --git a/sdk/storage/Azure.Storage.Blobs/api/Azure.Storage.Blobs.net6.0.cs b/sdk/storage/Azure.Storage.Blobs/api/Azure.Storage.Blobs.net6.0.cs index 25640917de5bb..822d5b41d1404 100644 --- a/sdk/storage/Azure.Storage.Blobs/api/Azure.Storage.Blobs.net6.0.cs +++ b/sdk/storage/Azure.Storage.Blobs/api/Azure.Storage.Blobs.net6.0.cs @@ -51,7 +51,7 @@ public BlobClient(System.Uri blobUri, Azure.Storage.StorageSharedKeyCredential c } public partial class BlobClientOptions : Azure.Core.ClientOptions { - public BlobClientOptions(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion version = Azure.Storage.Blobs.BlobClientOptions.ServiceVersion.V2024_11_04) { } + public BlobClientOptions(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion version = Azure.Storage.Blobs.BlobClientOptions.ServiceVersion.V2025_01_05) { } public Azure.Storage.Blobs.Models.BlobAudience? Audience { get { throw null; } set { } } public Azure.Storage.Blobs.Models.CustomerProvidedKey? 
CustomerProvidedKey { get { throw null; } set { } } public bool EnableTenantDiscovery { get { throw null; } set { } } @@ -522,6 +522,7 @@ public BlobDownloadDetails() { } public long BlobSequenceNumber { get { throw null; } } public Azure.Storage.Blobs.Models.BlobType BlobType { get { throw null; } } public string CacheControl { get { throw null; } } + public byte[] ContentCrc { get { throw null; } } public string ContentDisposition { get { throw null; } } public string ContentEncoding { get { throw null; } } public byte[] ContentHash { get { throw null; } } @@ -567,6 +568,7 @@ internal BlobDownloadInfo() { } [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] public string ContentType { get { throw null; } } public Azure.Storage.Blobs.Models.BlobDownloadDetails Details { get { throw null; } } + public bool ExpectTrailingDetails { get { throw null; } } public void Dispose() { } } public partial class BlobDownloadOptions @@ -588,6 +590,7 @@ public partial class BlobDownloadStreamingResult : System.IDisposable internal BlobDownloadStreamingResult() { } public System.IO.Stream Content { get { throw null; } } public Azure.Storage.Blobs.Models.BlobDownloadDetails Details { get { throw null; } } + public bool ExpectTrailingDetails { get { throw null; } } public void Dispose() { } } public partial class BlobDownloadToOptions @@ -1850,7 +1853,7 @@ public PageBlobClient(System.Uri blobUri, Azure.Storage.StorageSharedKeyCredenti } public partial class SpecializedBlobClientOptions : Azure.Storage.Blobs.BlobClientOptions { - public SpecializedBlobClientOptions(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion version = Azure.Storage.Blobs.BlobClientOptions.ServiceVersion.V2024_11_04) : base (default(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion)) { } + public SpecializedBlobClientOptions(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion version = Azure.Storage.Blobs.BlobClientOptions.ServiceVersion.V2025_01_05) : base 
(default(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion)) { } public Azure.Storage.ClientSideEncryptionOptions ClientSideEncryption { get { throw null; } set { } } } public static partial class SpecializedBlobExtensions diff --git a/sdk/storage/Azure.Storage.Blobs/api/Azure.Storage.Blobs.netstandard2.0.cs b/sdk/storage/Azure.Storage.Blobs/api/Azure.Storage.Blobs.netstandard2.0.cs index 25640917de5bb..822d5b41d1404 100644 --- a/sdk/storage/Azure.Storage.Blobs/api/Azure.Storage.Blobs.netstandard2.0.cs +++ b/sdk/storage/Azure.Storage.Blobs/api/Azure.Storage.Blobs.netstandard2.0.cs @@ -51,7 +51,7 @@ public BlobClient(System.Uri blobUri, Azure.Storage.StorageSharedKeyCredential c } public partial class BlobClientOptions : Azure.Core.ClientOptions { - public BlobClientOptions(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion version = Azure.Storage.Blobs.BlobClientOptions.ServiceVersion.V2024_11_04) { } + public BlobClientOptions(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion version = Azure.Storage.Blobs.BlobClientOptions.ServiceVersion.V2025_01_05) { } public Azure.Storage.Blobs.Models.BlobAudience? Audience { get { throw null; } set { } } public Azure.Storage.Blobs.Models.CustomerProvidedKey? 
CustomerProvidedKey { get { throw null; } set { } } public bool EnableTenantDiscovery { get { throw null; } set { } } @@ -522,6 +522,7 @@ public BlobDownloadDetails() { } public long BlobSequenceNumber { get { throw null; } } public Azure.Storage.Blobs.Models.BlobType BlobType { get { throw null; } } public string CacheControl { get { throw null; } } + public byte[] ContentCrc { get { throw null; } } public string ContentDisposition { get { throw null; } } public string ContentEncoding { get { throw null; } } public byte[] ContentHash { get { throw null; } } @@ -567,6 +568,7 @@ internal BlobDownloadInfo() { } [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] public string ContentType { get { throw null; } } public Azure.Storage.Blobs.Models.BlobDownloadDetails Details { get { throw null; } } + public bool ExpectTrailingDetails { get { throw null; } } public void Dispose() { } } public partial class BlobDownloadOptions @@ -588,6 +590,7 @@ public partial class BlobDownloadStreamingResult : System.IDisposable internal BlobDownloadStreamingResult() { } public System.IO.Stream Content { get { throw null; } } public Azure.Storage.Blobs.Models.BlobDownloadDetails Details { get { throw null; } } + public bool ExpectTrailingDetails { get { throw null; } } public void Dispose() { } } public partial class BlobDownloadToOptions @@ -1850,7 +1853,7 @@ public PageBlobClient(System.Uri blobUri, Azure.Storage.StorageSharedKeyCredenti } public partial class SpecializedBlobClientOptions : Azure.Storage.Blobs.BlobClientOptions { - public SpecializedBlobClientOptions(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion version = Azure.Storage.Blobs.BlobClientOptions.ServiceVersion.V2024_11_04) : base (default(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion)) { } + public SpecializedBlobClientOptions(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion version = Azure.Storage.Blobs.BlobClientOptions.ServiceVersion.V2025_01_05) : base 
(default(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion)) { } public Azure.Storage.ClientSideEncryptionOptions ClientSideEncryption { get { throw null; } set { } } } public static partial class SpecializedBlobExtensions diff --git a/sdk/storage/Azure.Storage.Blobs/api/Azure.Storage.Blobs.netstandard2.1.cs b/sdk/storage/Azure.Storage.Blobs/api/Azure.Storage.Blobs.netstandard2.1.cs index 25640917de5bb..822d5b41d1404 100644 --- a/sdk/storage/Azure.Storage.Blobs/api/Azure.Storage.Blobs.netstandard2.1.cs +++ b/sdk/storage/Azure.Storage.Blobs/api/Azure.Storage.Blobs.netstandard2.1.cs @@ -51,7 +51,7 @@ public BlobClient(System.Uri blobUri, Azure.Storage.StorageSharedKeyCredential c } public partial class BlobClientOptions : Azure.Core.ClientOptions { - public BlobClientOptions(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion version = Azure.Storage.Blobs.BlobClientOptions.ServiceVersion.V2024_11_04) { } + public BlobClientOptions(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion version = Azure.Storage.Blobs.BlobClientOptions.ServiceVersion.V2025_01_05) { } public Azure.Storage.Blobs.Models.BlobAudience? Audience { get { throw null; } set { } } public Azure.Storage.Blobs.Models.CustomerProvidedKey? 
CustomerProvidedKey { get { throw null; } set { } } public bool EnableTenantDiscovery { get { throw null; } set { } } @@ -522,6 +522,7 @@ public BlobDownloadDetails() { } public long BlobSequenceNumber { get { throw null; } } public Azure.Storage.Blobs.Models.BlobType BlobType { get { throw null; } } public string CacheControl { get { throw null; } } + public byte[] ContentCrc { get { throw null; } } public string ContentDisposition { get { throw null; } } public string ContentEncoding { get { throw null; } } public byte[] ContentHash { get { throw null; } } @@ -567,6 +568,7 @@ internal BlobDownloadInfo() { } [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] public string ContentType { get { throw null; } } public Azure.Storage.Blobs.Models.BlobDownloadDetails Details { get { throw null; } } + public bool ExpectTrailingDetails { get { throw null; } } public void Dispose() { } } public partial class BlobDownloadOptions @@ -588,6 +590,7 @@ public partial class BlobDownloadStreamingResult : System.IDisposable internal BlobDownloadStreamingResult() { } public System.IO.Stream Content { get { throw null; } } public Azure.Storage.Blobs.Models.BlobDownloadDetails Details { get { throw null; } } + public bool ExpectTrailingDetails { get { throw null; } } public void Dispose() { } } public partial class BlobDownloadToOptions @@ -1850,7 +1853,7 @@ public PageBlobClient(System.Uri blobUri, Azure.Storage.StorageSharedKeyCredenti } public partial class SpecializedBlobClientOptions : Azure.Storage.Blobs.BlobClientOptions { - public SpecializedBlobClientOptions(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion version = Azure.Storage.Blobs.BlobClientOptions.ServiceVersion.V2024_11_04) : base (default(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion)) { } + public SpecializedBlobClientOptions(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion version = Azure.Storage.Blobs.BlobClientOptions.ServiceVersion.V2025_01_05) : base 
(default(Azure.Storage.Blobs.BlobClientOptions.ServiceVersion)) { } public Azure.Storage.ClientSideEncryptionOptions ClientSideEncryption { get { throw null; } set { } } } public static partial class SpecializedBlobExtensions diff --git a/sdk/storage/Azure.Storage.Blobs/assets.json b/sdk/storage/Azure.Storage.Blobs/assets.json index 0facb33e2a026..1994292f7b658 100644 --- a/sdk/storage/Azure.Storage.Blobs/assets.json +++ b/sdk/storage/Azure.Storage.Blobs/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "net", "TagPrefix": "net/storage/Azure.Storage.Blobs", - "Tag": "net/storage/Azure.Storage.Blobs_5c382dfb14" + "Tag": "net/storage/Azure.Storage.Blobs_c5174c4663" } diff --git a/sdk/storage/Azure.Storage.Blobs/samples/Azure.Storage.Blobs.Samples.Tests.csproj b/sdk/storage/Azure.Storage.Blobs/samples/Azure.Storage.Blobs.Samples.Tests.csproj index 77fd767c3486c..568dd6cba9516 100644 --- a/sdk/storage/Azure.Storage.Blobs/samples/Azure.Storage.Blobs.Samples.Tests.csproj +++ b/sdk/storage/Azure.Storage.Blobs/samples/Azure.Storage.Blobs.Samples.Tests.csproj @@ -16,6 +16,7 @@ + diff --git a/sdk/storage/Azure.Storage.Blobs/src/AppendBlobClient.cs b/sdk/storage/Azure.Storage.Blobs/src/AppendBlobClient.cs index e70d5e02c82d7..9a110cf8eb13a 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/AppendBlobClient.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/AppendBlobClient.cs @@ -1242,14 +1242,39 @@ internal async Task> AppendBlockInternal( BlobErrors.VerifyHttpsCustomerProvidedKey(Uri, ClientConfiguration.CustomerProvidedKey); Errors.VerifyStreamPosition(content, nameof(content)); - // compute hash BEFORE attaching progress handler - ContentHasher.GetHashResult hashResult = await ContentHasher.GetHashOrDefaultInternal( - content, - validationOptions, - async, - cancellationToken).ConfigureAwait(false); - - content = content.WithNoDispose().WithProgress(progressHandler); + ContentHasher.GetHashResult hashResult = null; + long contentLength = 
(content?.Length - content?.Position) ?? 0; + long? structuredContentLength = default; + string structuredBodyType = null; + if (validationOptions != null && + validationOptions.ChecksumAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 && + ClientSideEncryption == null) // don't allow feature combination + { + // report progress in terms of caller bytes, not encoded bytes + structuredContentLength = contentLength; + contentLength = (content?.Length - content?.Position) ?? 0; + structuredBodyType = Constants.StructuredMessage.CrcStructuredMessage; + content = content.WithNoDispose().WithProgress(progressHandler); + content = validationOptions.PrecalculatedChecksum.IsEmpty + ? new StructuredMessageEncodingStream( + content, + Constants.StructuredMessage.DefaultSegmentContentLength, + StructuredMessage.Flags.StorageCrc64) + : new StructuredMessagePrecalculatedCrcWrapperStream( + content, + validationOptions.PrecalculatedChecksum.Span); + contentLength = (content?.Length - content?.Position) ?? 0; + } + else + { + // compute hash BEFORE attaching progress handler + hashResult = await ContentHasher.GetHashOrDefaultInternal( + content, + validationOptions, + async, + cancellationToken).ConfigureAwait(false); + content = content.WithNoDispose().WithProgress(progressHandler); + } ResponseWithHeaders response; @@ -1267,6 +1292,8 @@ internal async Task> AppendBlockInternal( encryptionKeySha256: ClientConfiguration.CustomerProvidedKey?.EncryptionKeyHash, encryptionAlgorithm: ClientConfiguration.CustomerProvidedKey?.EncryptionAlgorithm == null ? 
null : EncryptionAlgorithmTypeInternal.AES256, encryptionScope: ClientConfiguration.EncryptionScope, + structuredBodyType: structuredBodyType, + structuredContentLength: structuredContentLength, ifModifiedSince: conditions?.IfModifiedSince, ifUnmodifiedSince: conditions?.IfUnmodifiedSince, ifMatch: conditions?.IfMatch?.ToString(), @@ -1289,6 +1316,8 @@ internal async Task> AppendBlockInternal( encryptionKeySha256: ClientConfiguration.CustomerProvidedKey?.EncryptionKeyHash, encryptionAlgorithm: ClientConfiguration.CustomerProvidedKey?.EncryptionAlgorithm == null ? null : EncryptionAlgorithmTypeInternal.AES256, encryptionScope: ClientConfiguration.EncryptionScope, + structuredBodyType: structuredBodyType, + structuredContentLength: structuredContentLength, ifModifiedSince: conditions?.IfModifiedSince, ifUnmodifiedSince: conditions?.IfUnmodifiedSince, ifMatch: conditions?.IfMatch?.ToString(), diff --git a/sdk/storage/Azure.Storage.Blobs/src/Azure.Storage.Blobs.csproj b/sdk/storage/Azure.Storage.Blobs/src/Azure.Storage.Blobs.csproj index 8b09c620d1654..e29acc40ca38b 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/Azure.Storage.Blobs.csproj +++ b/sdk/storage/Azure.Storage.Blobs/src/Azure.Storage.Blobs.csproj @@ -52,6 +52,8 @@ + + @@ -91,6 +93,11 @@ + + + + + diff --git a/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs b/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs index aa91edb9f6c41..6b95b04c703db 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs @@ -1031,6 +1031,7 @@ private async Task> DownloadInternal( ContentHash = blobDownloadDetails.ContentHash, ContentLength = blobDownloadDetails.ContentLength, ContentType = blobDownloadDetails.ContentType, + ExpectTrailingDetails = blobDownloadStreamingResult.ExpectTrailingDetails, }, response.GetRawResponse()); } #endregion @@ -1547,30 +1548,52 @@ internal virtual async ValueTask> Download // Wrap the response Content in a 
RetriableStream so we // can return it before it's finished downloading, but still // allow retrying if it fails. - Stream stream = RetriableStream.Create( - response.Value.Content, - startOffset => - StartDownloadAsync( - range, - conditionsWithEtag, - validationOptions, - startOffset, - async, - cancellationToken) - .EnsureCompleted() - .Value.Content, - async startOffset => - (await StartDownloadAsync( - range, - conditionsWithEtag, - validationOptions, - startOffset, - async, - cancellationToken) - .ConfigureAwait(false)) - .Value.Content, - ClientConfiguration.Pipeline.ResponseClassifier, - Constants.MaxReliabilityRetries); + ValueTask> Factory(long offset, bool async, CancellationToken cancellationToken) + => StartDownloadAsync( + range, + conditionsWithEtag, + validationOptions, + offset, + async, + cancellationToken); + async ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.RawDecodedData DecodedData)> StructuredMessageFactory( + long offset, bool async, CancellationToken cancellationToken) + { + Response result = await Factory(offset, async, cancellationToken).ConfigureAwait(false); + return StructuredMessageDecodingStream.WrapStream(result.Value.Content, result.Value.Details.ContentLength); + } + Stream stream; + if (response.GetRawResponse().Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)) + { + (Stream decodingStream, StructuredMessageDecodingStream.RawDecodedData decodedData) = StructuredMessageDecodingStream.WrapStream( + response.Value.Content, response.Value.Details.ContentLength); + stream = new StructuredMessageDecodingRetriableStream( + decodingStream, + decodedData, + StructuredMessage.Flags.StorageCrc64, + startOffset => StructuredMessageFactory(startOffset, async: false, cancellationToken) + .EnsureCompleted(), + async startOffset => await StructuredMessageFactory(startOffset, async: true, cancellationToken) + .ConfigureAwait(false), + decodedData => + { + response.Value.Details.ContentCrc = new 
byte[StructuredMessage.Crc64Length]; + decodedData.Crc.WriteCrc64(response.Value.Details.ContentCrc); + }, + ClientConfiguration.Pipeline.ResponseClassifier, + Constants.MaxReliabilityRetries); + } + else + { + stream = RetriableStream.Create( + response.Value.Content, + startOffset => Factory(startOffset, async: false, cancellationToken) + .EnsureCompleted().Value.Content, + async startOffset => (await Factory(startOffset, async: true, cancellationToken) + .ConfigureAwait(false)).Value.Content, + ClientConfiguration.Pipeline.ResponseClassifier, + Constants.MaxReliabilityRetries); + } stream = stream.WithNoDispose().WithProgress(progressHandler); @@ -1578,7 +1601,11 @@ internal virtual async ValueTask> Download * Buffer response stream and ensure it matches the transactional checksum if any. * Storage will not return a checksum for payload >4MB, so this buffer is capped similarly. * Checksum validation is opt-in, so this buffer is part of that opt-in. */ - if (validationOptions != default && validationOptions.ChecksumAlgorithm != StorageChecksumAlgorithm.None && validationOptions.AutoValidateChecksum) + if (validationOptions != default && + validationOptions.ChecksumAlgorithm != StorageChecksumAlgorithm.None && + validationOptions.AutoValidateChecksum && + // structured message decoding does the validation for us + !response.GetRawResponse().Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)) { // safe-buffer; transactional hash download limit well below maxInt var readDestStream = new MemoryStream((int)response.Value.Details.ContentLength); @@ -1649,8 +1676,8 @@ await ContentHasher.AssertResponseHashMatchInternal( /// notifications that the operation should be cancelled. /// /// - /// A describing the - /// downloaded blob. contains + /// A describing the + /// downloaded blob. contains /// the blob's data. 
/// /// @@ -1689,13 +1716,29 @@ private async ValueTask> StartDownloadAsyn operationName: nameof(BlobBaseClient.Download), parameterName: nameof(conditions)); + bool? rangeGetContentMD5 = null; + bool? rangeGetContentCRC64 = null; + string structuredBodyType = null; + switch (validationOptions?.ChecksumAlgorithm.ResolveAuto()) + { + case StorageChecksumAlgorithm.MD5: + rangeGetContentMD5 = true; + break; + case StorageChecksumAlgorithm.StorageCrc64: + structuredBodyType = Constants.StructuredMessage.CrcStructuredMessage; + break; + default: + break; + } + if (async) { response = await BlobRestClient.DownloadAsync( range: pageRange?.ToString(), leaseId: conditions?.LeaseId, - rangeGetContentMD5: validationOptions?.ChecksumAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.MD5 ? true : null, - rangeGetContentCRC64: validationOptions?.ChecksumAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 ? true : null, + rangeGetContentMD5: rangeGetContentMD5, + rangeGetContentCRC64: rangeGetContentCRC64, + structuredBodyType: structuredBodyType, encryptionKey: ClientConfiguration.CustomerProvidedKey?.EncryptionKey, encryptionKeySha256: ClientConfiguration.CustomerProvidedKey?.EncryptionKeyHash, encryptionAlgorithm: ClientConfiguration.CustomerProvidedKey?.EncryptionAlgorithm == null ? null : EncryptionAlgorithmTypeInternal.AES256, @@ -1712,8 +1755,9 @@ private async ValueTask> StartDownloadAsyn response = BlobRestClient.Download( range: pageRange?.ToString(), leaseId: conditions?.LeaseId, - rangeGetContentMD5: validationOptions?.ChecksumAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.MD5 ? true : null, - rangeGetContentCRC64: validationOptions?.ChecksumAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 ? 
true : null, + rangeGetContentMD5: rangeGetContentMD5, + rangeGetContentCRC64: rangeGetContentCRC64, + structuredBodyType: structuredBodyType, encryptionKey: ClientConfiguration.CustomerProvidedKey?.EncryptionKey, encryptionKeySha256: ClientConfiguration.CustomerProvidedKey?.EncryptionKeyHash, encryptionAlgorithm: ClientConfiguration.CustomerProvidedKey?.EncryptionAlgorithm == null ? null : EncryptionAlgorithmTypeInternal.AES256, @@ -1729,9 +1773,11 @@ private async ValueTask> StartDownloadAsyn long length = response.IsUnavailable() ? 0 : response.Headers.ContentLength ?? 0; ClientConfiguration.Pipeline.LogTrace($"Response: {response.GetRawResponse().Status}, ContentLength: {length}"); - return Response.FromValue( + Response result = Response.FromValue( response.ToBlobDownloadStreamingResult(), response.GetRawResponse()); + result.Value.ExpectTrailingDetails = structuredBodyType != null; + return result; } #endregion diff --git a/sdk/storage/Azure.Storage.Blobs/src/BlobClientOptions.cs b/sdk/storage/Azure.Storage.Blobs/src/BlobClientOptions.cs index b16cefc83a535..f312e621bffc4 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/BlobClientOptions.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/BlobClientOptions.cs @@ -318,6 +318,8 @@ private void AddHeadersAndQueryParameters() Diagnostics.LoggedHeaderNames.Add("x-ms-encryption-key-sha256"); Diagnostics.LoggedHeaderNames.Add("x-ms-copy-source-error-code"); Diagnostics.LoggedHeaderNames.Add("x-ms-copy-source-status-code"); + Diagnostics.LoggedHeaderNames.Add("x-ms-structured-body"); + Diagnostics.LoggedHeaderNames.Add("x-ms-structured-content-length"); Diagnostics.LoggedQueryParameters.Add("comp"); Diagnostics.LoggedQueryParameters.Add("maxresults"); diff --git a/sdk/storage/Azure.Storage.Blobs/src/BlobClientSideDecryptor.cs b/sdk/storage/Azure.Storage.Blobs/src/BlobClientSideDecryptor.cs index 9006282fab5b7..59b036d4b20bd 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/BlobClientSideDecryptor.cs +++ 
b/sdk/storage/Azure.Storage.Blobs/src/BlobClientSideDecryptor.cs @@ -186,7 +186,7 @@ private static bool CanIgnorePadding(ContentRange? contentRange) // did we request the last block? // end is inclusive/0-index, so end = n and size = n+1 means we requested the last block - if (contentRange.Value.Size - contentRange.Value.End == 1) + if (contentRange.Value.TotalResourceLength - contentRange.Value.End == 1) { return false; } diff --git a/sdk/storage/Azure.Storage.Blobs/src/BlockBlobClient.cs b/sdk/storage/Azure.Storage.Blobs/src/BlockBlobClient.cs index f5348303e57f0..00e6bf0780e2f 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/BlockBlobClient.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/BlockBlobClient.cs @@ -875,14 +875,35 @@ internal virtual async Task> UploadInternal( scope.Start(); Errors.VerifyStreamPosition(content, nameof(content)); - // compute hash BEFORE attaching progress handler - ContentHasher.GetHashResult hashResult = await ContentHasher.GetHashOrDefaultInternal( - content, - validationOptions, - async, - cancellationToken).ConfigureAwait(false); - - content = content?.WithNoDispose().WithProgress(progressHandler); + ContentHasher.GetHashResult hashResult = null; + long contentLength = (content?.Length - content?.Position) ?? 0; + long? 
structuredContentLength = default; + string structuredBodyType = null; + if (content != null && + validationOptions != null && + validationOptions.ChecksumAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 && + ClientSideEncryption == null) // don't allow feature combination + { + // report progress in terms of caller bytes, not encoded bytes + structuredContentLength = contentLength; + structuredBodyType = Constants.StructuredMessage.CrcStructuredMessage; + content = content.WithNoDispose().WithProgress(progressHandler); + content = new StructuredMessageEncodingStream( + content, + Constants.StructuredMessage.DefaultSegmentContentLength, + StructuredMessage.Flags.StorageCrc64); + contentLength = content.Length - content.Position; + } + else + { + // compute hash BEFORE attaching progress handler + hashResult = await ContentHasher.GetHashOrDefaultInternal( + content, + validationOptions, + async, + cancellationToken).ConfigureAwait(false); + content = content.WithNoDispose().WithProgress(progressHandler); + } ResponseWithHeaders response; @@ -921,6 +942,8 @@ internal virtual async Task> UploadInternal( legalHold: legalHold, transactionalContentMD5: hashResult?.MD5AsArray, transactionalContentCrc64: hashResult?.StorageCrc64AsArray, + structuredBodyType: structuredBodyType, + structuredContentLength: structuredContentLength, cancellationToken: cancellationToken) .ConfigureAwait(false); } @@ -953,6 +976,8 @@ internal virtual async Task> UploadInternal( legalHold: legalHold, transactionalContentMD5: hashResult?.MD5AsArray, transactionalContentCrc64: hashResult?.StorageCrc64AsArray, + structuredBodyType: structuredBodyType, + structuredContentLength: structuredContentLength, cancellationToken: cancellationToken); } @@ -1305,14 +1330,39 @@ internal virtual async Task> StageBlockInternal( Errors.VerifyStreamPosition(content, nameof(content)); - // compute hash BEFORE attaching progress handler - ContentHasher.GetHashResult hashResult = await 
ContentHasher.GetHashOrDefaultInternal( - content, - validationOptions, - async, - cancellationToken).ConfigureAwait(false); - - content = content.WithNoDispose().WithProgress(progressHandler); + ContentHasher.GetHashResult hashResult = null; + long contentLength = (content?.Length - content?.Position) ?? 0; + long? structuredContentLength = default; + string structuredBodyType = null; + if (validationOptions != null && + validationOptions.ChecksumAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 && + ClientSideEncryption == null) // don't allow feature combination + { + // report progress in terms of caller bytes, not encoded bytes + structuredContentLength = contentLength; + contentLength = (content?.Length - content?.Position) ?? 0; + structuredBodyType = Constants.StructuredMessage.CrcStructuredMessage; + content = content.WithNoDispose().WithProgress(progressHandler); + content = validationOptions.PrecalculatedChecksum.IsEmpty + ? new StructuredMessageEncodingStream( + content, + Constants.StructuredMessage.DefaultSegmentContentLength, + StructuredMessage.Flags.StorageCrc64) + : new StructuredMessagePrecalculatedCrcWrapperStream( + content, + validationOptions.PrecalculatedChecksum.Span); + contentLength = (content?.Length - content?.Position) ?? 0; + } + else + { + // compute hash BEFORE attaching progress handler + hashResult = await ContentHasher.GetHashOrDefaultInternal( + content, + validationOptions, + async, + cancellationToken).ConfigureAwait(false); + content = content.WithNoDispose().WithProgress(progressHandler); + } ResponseWithHeaders response; @@ -1320,7 +1370,7 @@ internal virtual async Task> StageBlockInternal( { response = await BlockBlobRestClient.StageBlockAsync( blockId: base64BlockId, - contentLength: (content?.Length - content?.Position) ?? 
0, + contentLength: contentLength, body: content, transactionalContentCrc64: hashResult?.StorageCrc64AsArray, transactionalContentMD5: hashResult?.MD5AsArray, @@ -1329,6 +1379,8 @@ internal virtual async Task> StageBlockInternal( encryptionKeySha256: ClientConfiguration.CustomerProvidedKey?.EncryptionKeyHash, encryptionAlgorithm: ClientConfiguration.CustomerProvidedKey?.EncryptionAlgorithm == null ? null : EncryptionAlgorithmTypeInternal.AES256, encryptionScope: ClientConfiguration.EncryptionScope, + structuredBodyType: structuredBodyType, + structuredContentLength: structuredContentLength, cancellationToken: cancellationToken) .ConfigureAwait(false); } @@ -1336,7 +1388,7 @@ internal virtual async Task> StageBlockInternal( { response = BlockBlobRestClient.StageBlock( blockId: base64BlockId, - contentLength: (content?.Length - content?.Position) ?? 0, + contentLength: contentLength, body: content, transactionalContentCrc64: hashResult?.StorageCrc64AsArray, transactionalContentMD5: hashResult?.MD5AsArray, @@ -1345,6 +1397,8 @@ internal virtual async Task> StageBlockInternal( encryptionKeySha256: ClientConfiguration.CustomerProvidedKey?.EncryptionKeyHash, encryptionAlgorithm: ClientConfiguration.CustomerProvidedKey?.EncryptionAlgorithm == null ? 
null : EncryptionAlgorithmTypeInternal.AES256, encryptionScope: ClientConfiguration.EncryptionScope, + structuredBodyType: structuredBodyType, + structuredContentLength: structuredContentLength, cancellationToken: cancellationToken); } @@ -2791,7 +2845,7 @@ internal async Task OpenWriteInternal( immutabilityPolicy: default, legalHold: default, progressHandler: default, - transferValidationOverride: default, + transferValidationOverride: new() { ChecksumAlgorithm = StorageChecksumAlgorithm.None }, operationName: default, async: async, cancellationToken: cancellationToken) diff --git a/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadDetails.cs b/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadDetails.cs index bc119822cdc12..0490ec239798e 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadDetails.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadDetails.cs @@ -34,6 +34,14 @@ public class BlobDownloadDetails public byte[] ContentHash { get; internal set; } #pragma warning restore CA1819 // Properties should not return arrays + /// + /// When requested using , this value contains the CRC for the download blob range. + /// This value may only become populated once the network stream is fully consumed. If this instance is accessed through + /// , the network stream has already been consumed. Otherwise, consume the content stream before + /// checking this value. + /// + public byte[] ContentCrc { get; internal set; } + /// /// Returns the date and time the container was last modified. Any operation that modifies the blob, including an update of the blob's metadata or properties, changes the last-modified time of the blob. 
/// diff --git a/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadInfo.cs b/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadInfo.cs index e034573b54b3a..b42801e36ab55 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadInfo.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadInfo.cs @@ -4,6 +4,8 @@ using System; using System.ComponentModel; using System.IO; +using System.Threading.Tasks; +using Azure.Core; using Azure.Storage.Shared; namespace Azure.Storage.Blobs.Models @@ -49,6 +51,14 @@ public class BlobDownloadInfo : IDisposable, IDownloadedContent /// public BlobDownloadDetails Details { get; internal set; } + /// + /// Indicates some contents of are mixed into the response stream. + /// They will not be set until has been fully consumed. These details + /// will be extracted from the content stream by the library before the calling code can + /// encounter them. + /// + public bool ExpectTrailingDetails { get; internal set; } + /// /// Constructor. /// diff --git a/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadStreamingResult.cs b/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadStreamingResult.cs index 4fbada6e67aad..9b7d4d4e00dad 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadStreamingResult.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadStreamingResult.cs @@ -24,6 +24,14 @@ internal BlobDownloadStreamingResult() { } /// public Stream Content { get; internal set; } + /// + /// Indicates some contents of are mixed into the response stream. + /// They will not be set until has been fully consumed. These details + /// will be extracted from the content stream by the library before the calling code can + /// encounter them. + /// + public bool ExpectTrailingDetails { get; internal set; } + /// /// Disposes the by calling Dispose on the underlying stream. 
/// diff --git a/sdk/storage/Azure.Storage.Blobs/src/PageBlobClient.cs b/sdk/storage/Azure.Storage.Blobs/src/PageBlobClient.cs index fa575e41b8ebe..7038897531fbb 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/PageBlobClient.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/PageBlobClient.cs @@ -1363,15 +1363,42 @@ internal async Task> UploadPagesInternal( scope.Start(); Errors.VerifyStreamPosition(content, nameof(content)); - // compute hash BEFORE attaching progress handler - ContentHasher.GetHashResult hashResult = await ContentHasher.GetHashOrDefaultInternal( - content, - validationOptions, - async, - cancellationToken).ConfigureAwait(false); - - content = content?.WithNoDispose().WithProgress(progressHandler); - HttpRange range = new HttpRange(offset, (content?.Length - content?.Position) ?? null); + ContentHasher.GetHashResult hashResult = null; + long contentLength = (content?.Length - content?.Position) ?? 0; + long? structuredContentLength = default; + string structuredBodyType = null; + HttpRange range; + if (validationOptions != null && + validationOptions.ChecksumAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 && + ClientSideEncryption == null) // don't allow feature combination + { + // report progress in terms of caller bytes, not encoded bytes + structuredContentLength = contentLength; + contentLength = (content?.Length - content?.Position) ?? 0; + range = new HttpRange(offset, (content?.Length - content?.Position) ?? null); + structuredBodyType = Constants.StructuredMessage.CrcStructuredMessage; + content = content?.WithNoDispose().WithProgress(progressHandler); + content = validationOptions.PrecalculatedChecksum.IsEmpty + ? 
new StructuredMessageEncodingStream( + content, + Constants.StructuredMessage.DefaultSegmentContentLength, + StructuredMessage.Flags.StorageCrc64) + : new StructuredMessagePrecalculatedCrcWrapperStream( + content, + validationOptions.PrecalculatedChecksum.Span); + contentLength = (content?.Length - content?.Position) ?? 0; + } + else + { + // compute hash BEFORE attaching progress handler + hashResult = await ContentHasher.GetHashOrDefaultInternal( + content, + validationOptions, + async, + cancellationToken).ConfigureAwait(false); + content = content?.WithNoDispose().WithProgress(progressHandler); + range = new HttpRange(offset, (content?.Length - content?.Position) ?? null); + } ResponseWithHeaders response; @@ -1388,6 +1415,8 @@ internal async Task> UploadPagesInternal( encryptionKeySha256: ClientConfiguration.CustomerProvidedKey?.EncryptionKeyHash, encryptionAlgorithm: ClientConfiguration.CustomerProvidedKey?.EncryptionAlgorithm == null ? null : EncryptionAlgorithmTypeInternal.AES256, encryptionScope: ClientConfiguration.EncryptionScope, + structuredBodyType: structuredBodyType, + structuredContentLength: structuredContentLength, ifSequenceNumberLessThanOrEqualTo: conditions?.IfSequenceNumberLessThanOrEqual, ifSequenceNumberLessThan: conditions?.IfSequenceNumberLessThan, ifSequenceNumberEqualTo: conditions?.IfSequenceNumberEqual, @@ -1412,6 +1441,8 @@ internal async Task> UploadPagesInternal( encryptionKeySha256: ClientConfiguration.CustomerProvidedKey?.EncryptionKeyHash, encryptionAlgorithm: ClientConfiguration.CustomerProvidedKey?.EncryptionAlgorithm == null ? 
null : EncryptionAlgorithmTypeInternal.AES256, encryptionScope: ClientConfiguration.EncryptionScope, + structuredBodyType: structuredBodyType, + structuredContentLength: structuredContentLength, ifSequenceNumberLessThanOrEqualTo: conditions?.IfSequenceNumberLessThanOrEqual, ifSequenceNumberLessThan: conditions?.IfSequenceNumberLessThan, ifSequenceNumberEqualTo: conditions?.IfSequenceNumberEqual, diff --git a/sdk/storage/Azure.Storage.Blobs/src/PartitionedDownloader.cs b/sdk/storage/Azure.Storage.Blobs/src/PartitionedDownloader.cs index 2c52d0c256e34..1b14bcf98ec04 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/PartitionedDownloader.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/PartitionedDownloader.cs @@ -22,6 +22,8 @@ internal class PartitionedDownloader private const string _operationName = nameof(BlobBaseClient) + "." + nameof(BlobBaseClient.DownloadTo); private const string _innerOperationName = nameof(BlobBaseClient) + "." + nameof(BlobBaseClient.DownloadStreaming); + private const int Crc64Len = Constants.StorageCrc64SizeInBytes; + /// /// The client used to download the blob. /// @@ -48,6 +50,7 @@ internal class PartitionedDownloader /// private readonly StorageChecksumAlgorithm _validationAlgorithm; private readonly int _checksumSize; + // TODO disabling master crc temporarily. segment CRCs still handled. 
private bool UseMasterCrc => _validationAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64; private StorageCrc64HashAlgorithm _masterCrcCalculator = null; @@ -200,20 +203,31 @@ public async Task DownloadToInternal( } // Destination wrapped in master crc step if needed (must wait until after encryption wrap check) - Memory composedCrc = default; + byte[] composedCrcBuf = default; if (UseMasterCrc) { _masterCrcCalculator = StorageCrc64HashAlgorithm.Create(); destination = ChecksumCalculatingStream.GetWriteStream(destination, _masterCrcCalculator.Append); - disposables.Add(_arrayPool.RentAsMemoryDisposable( - Constants.StorageCrc64SizeInBytes, out composedCrc)); - composedCrc.Span.Clear(); + disposables.Add(_arrayPool.RentDisposable(Crc64Len, out composedCrcBuf)); + composedCrcBuf.Clear(); } // If the first segment was the entire blob, we'll copy that to // the output stream and finish now - long initialLength = initialResponse.Value.Details.ContentLength; - long totalLength = ParseRangeTotalLength(initialResponse.Value.Details.ContentRange); + long initialLength; + long totalLength; + // Get blob content length downloaded from content range when available to handle transit encoding + if (string.IsNullOrWhiteSpace(initialResponse.Value.Details.ContentRange)) + { + initialLength = initialResponse.Value.Details.ContentLength; + totalLength = 0; + } + else + { + ContentRange recievedRange = ContentRange.Parse(initialResponse.Value.Details.ContentRange); + initialLength = recievedRange.GetRangeLength(); + totalLength = recievedRange.TotalResourceLength.Value; + } if (initialLength == totalLength) { await HandleOneShotDownload(initialResponse, destination, async, cancellationToken) @@ -239,15 +253,16 @@ await HandleOneShotDownload(initialResponse, destination, async, cancellationTok } else { - using (_arrayPool.RentAsMemoryDisposable(_checksumSize, out Memory partitionChecksum)) + using (_arrayPool.RentDisposable(_checksumSize, out byte[] partitionChecksum)) 
{ - await CopyToInternal(initialResponse, destination, partitionChecksum, async, cancellationToken).ConfigureAwait(false); + await CopyToInternal(initialResponse, destination, new(partitionChecksum, 0, _checksumSize), async, cancellationToken).ConfigureAwait(false); if (UseMasterCrc) { StorageCrc64Composer.Compose( - (composedCrc.ToArray(), 0L), - (partitionChecksum.ToArray(), initialResponse.Value.Details.ContentLength) - ).CopyTo(composedCrc); + (composedCrcBuf, 0L), + (partitionChecksum, initialResponse.Value.Details.ContentRange.GetContentRangeLengthOrDefault() + ?? initialResponse.Value.Details.ContentLength) + ).AsSpan(0, Crc64Len).CopyTo(composedCrcBuf); } } } @@ -286,15 +301,16 @@ await HandleOneShotDownload(initialResponse, destination, async, cancellationTok else { Response result = await responseValueTask.ConfigureAwait(false); - using (_arrayPool.RentAsMemoryDisposable(_checksumSize, out Memory partitionChecksum)) + using (_arrayPool.RentDisposable(_checksumSize, out byte[] partitionChecksum)) { - await CopyToInternal(result, destination, partitionChecksum, async, cancellationToken).ConfigureAwait(false); + await CopyToInternal(result, destination, new(partitionChecksum, 0, _checksumSize), async, cancellationToken).ConfigureAwait(false); if (UseMasterCrc) { StorageCrc64Composer.Compose( - (composedCrc.ToArray(), 0L), - (partitionChecksum.ToArray(), result.Value.Details.ContentLength) - ).CopyTo(composedCrc); + (composedCrcBuf, 0L), + (partitionChecksum, result.Value.Details.ContentRange.GetContentRangeLengthOrDefault() + ?? result.Value.Details.ContentLength) + ).AsSpan(0, Crc64Len).CopyTo(composedCrcBuf); } } } @@ -310,7 +326,7 @@ await HandleOneShotDownload(initialResponse, destination, async, cancellationTok } #pragma warning restore AZC0110 // DO NOT use await keyword in possibly synchronous scope. 
- await FinalizeDownloadInternal(destination, composedCrc, async, cancellationToken) + await FinalizeDownloadInternal(destination, composedCrcBuf?.AsMemory(0, Crc64Len) ?? default, async, cancellationToken) .ConfigureAwait(false); return initialResponse.GetRawResponse(); @@ -328,7 +344,7 @@ async Task ConsumeQueuedTask() // CopyToAsync causes ConsumeQueuedTask to wait until the // download is complete - using (_arrayPool.RentAsMemoryDisposable(_checksumSize, out Memory partitionChecksum)) + using (_arrayPool.RentDisposable(_checksumSize, out byte[] partitionChecksum)) { await CopyToInternal( response, @@ -337,13 +353,14 @@ await CopyToInternal( async, cancellationToken) .ConfigureAwait(false); - if (UseMasterCrc) - { - StorageCrc64Composer.Compose( - (composedCrc.ToArray(), 0L), - (partitionChecksum.ToArray(), response.Value.Details.ContentLength) - ).CopyTo(composedCrc); - } + if (UseMasterCrc) + { + StorageCrc64Composer.Compose( + (composedCrcBuf, 0L), + (partitionChecksum, response.Value.Details.ContentRange.GetContentRangeLengthOrDefault() + ?? 
response.Value.Details.ContentLength) + ).AsSpan(0, Crc64Len).CopyTo(composedCrcBuf); + } } } } @@ -379,7 +396,7 @@ await FinalizeDownloadInternal(destination, partitionChecksum, async, cancellati private async Task FinalizeDownloadInternal( Stream destination, - Memory composedCrc, + ReadOnlyMemory composedCrc, bool async, CancellationToken cancellationToken) { @@ -395,20 +412,6 @@ private async Task FinalizeDownloadInternal( } } - private static long ParseRangeTotalLength(string range) - { - if (range == null) - { - return 0; - } - int lengthSeparator = range.IndexOf("/", StringComparison.InvariantCultureIgnoreCase); - if (lengthSeparator == -1) - { - throw BlobErrors.ParsingFullHttpRangeFailed(range); - } - return long.Parse(range.Substring(lengthSeparator + 1), CultureInfo.InvariantCulture); - } - private async Task CopyToInternal( Response response, Stream destination, @@ -417,7 +420,9 @@ private async Task CopyToInternal( CancellationToken cancellationToken) { CancellationHelper.ThrowIfCancellationRequested(cancellationToken); - using IHasher hasher = ContentHasher.GetHasherFromAlgorithmId(_validationAlgorithm); + // if structured message, this crc is validated in the decoding process. don't decode it here. + bool structuredMessage = response.GetRawResponse().Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader); + using IHasher hasher = structuredMessage ? null : ContentHasher.GetHasherFromAlgorithmId(_validationAlgorithm); using Stream rawSource = response.Value.Content; using Stream source = hasher != null ? 
ChecksumCalculatingStream.GetReadStream(rawSource, hasher.AppendHash) @@ -429,7 +434,13 @@ await source.CopyToInternal( cancellationToken) .ConfigureAwait(false); - if (hasher != null) + // with structured message, the message integrity will already be validated, + // but we can still get the checksum out of the response object + if (structuredMessage) + { + response.Value.Details.ContentCrc?.CopyTo(checksumBuffer.Span); + } + else if (hasher != null) { hasher.GetFinalHash(checksumBuffer.Span); (ReadOnlyMemory checksum, StorageChecksumAlgorithm _) diff --git a/sdk/storage/Azure.Storage.Blobs/src/autorest.md b/sdk/storage/Azure.Storage.Blobs/src/autorest.md index 7160bd89aba05..a96db9856ca58 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/autorest.md +++ b/sdk/storage/Azure.Storage.Blobs/src/autorest.md @@ -34,7 +34,7 @@ directive: if (property.includes('/{containerName}/{blob}')) { $[property]["parameters"] = $[property]["parameters"].filter(function(param) { return (typeof param['$ref'] === "undefined") || (false == param['$ref'].endsWith("#/parameters/ContainerName") && false == param['$ref'].endsWith("#/parameters/Blob"))}); - } + } else if (property.includes('/{containerName}')) { $[property]["parameters"] = $[property]["parameters"].filter(function(param) { return (typeof param['$ref'] === "undefined") || (false == param['$ref'].endsWith("#/parameters/ContainerName"))}); @@ -158,7 +158,7 @@ directive: var newName = property.replace('/{containerName}/{blob}', ''); $[newName] = $[oldName]; delete $[oldName]; - } + } else if (property.includes('/{containerName}')) { var oldName = property; diff --git a/sdk/storage/Azure.Storage.Blobs/tests/Azure.Storage.Blobs.Tests.csproj b/sdk/storage/Azure.Storage.Blobs/tests/Azure.Storage.Blobs.Tests.csproj index 62c7b6d17e63e..1c3856c83b64e 100644 --- a/sdk/storage/Azure.Storage.Blobs/tests/Azure.Storage.Blobs.Tests.csproj +++ b/sdk/storage/Azure.Storage.Blobs/tests/Azure.Storage.Blobs.Tests.csproj @@ -6,6 +6,9 @@ Microsoft 
Azure.Storage.Blobs client library tests false + + BlobSDK + diff --git a/sdk/storage/Azure.Storage.Blobs/tests/BlobBaseClientTransferValidationTests.cs b/sdk/storage/Azure.Storage.Blobs/tests/BlobBaseClientTransferValidationTests.cs index 73d11612f1d8c..3ec448e6d1ed0 100644 --- a/sdk/storage/Azure.Storage.Blobs/tests/BlobBaseClientTransferValidationTests.cs +++ b/sdk/storage/Azure.Storage.Blobs/tests/BlobBaseClientTransferValidationTests.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; +using System.Buffers; using System.IO; using System.Threading.Tasks; using Azure.Core.TestFramework; @@ -37,7 +39,10 @@ protected override async Task> GetDispo StorageChecksumAlgorithm uploadAlgorithm = StorageChecksumAlgorithm.None, StorageChecksumAlgorithm downloadAlgorithm = StorageChecksumAlgorithm.None) { - var disposingContainer = await ClientBuilder.GetTestContainerAsync(service: service, containerName: containerName); + var disposingContainer = await ClientBuilder.GetTestContainerAsync( + service: service, + containerName: containerName, + publicAccessType: PublicAccessType.None); disposingContainer.Container.ClientConfiguration.TransferValidation.Upload.ChecksumAlgorithm = uploadAlgorithm; disposingContainer.Container.ClientConfiguration.TransferValidation.Download.ChecksumAlgorithm = downloadAlgorithm; @@ -91,57 +96,96 @@ public override void TestAutoResolve() } #region Added Tests - [TestCaseSource("GetValidationAlgorithms")] - public async Task ExpectedDownloadStreamingStreamTypeReturned(StorageChecksumAlgorithm algorithm) + [Test] + public virtual async Task OlderServiceVersionThrowsOnStructuredMessage() { - await using var test = await GetDisposingContainerAsync(); + // use service version before structured message was introduced + await using DisposingContainer disposingContainer = await ClientBuilder.GetTestContainerAsync( + service: ClientBuilder.GetServiceClient_SharedKey( + 
InstrumentClientOptions(new BlobClientOptions(BlobClientOptions.ServiceVersion.V2024_11_04))), + publicAccessType: PublicAccessType.None); // Arrange - var data = GetRandomBuffer(Constants.KB); - BlobClient blob = InstrumentClient(test.Container.GetBlobClient(GetNewResourceName())); - using (var stream = new MemoryStream(data)) + const int dataLength = Constants.KB; + var data = GetRandomBuffer(dataLength); + + var resourceName = GetNewResourceName(); + var blob = InstrumentClient(disposingContainer.Container.GetBlobClient(GetNewResourceName())); + await blob.UploadAsync(BinaryData.FromBytes(data)); + + var validationOptions = new DownloadTransferValidationOptions { - await blob.UploadAsync(stream); - } - // don't make options instance at all for no hash request - DownloadTransferValidationOptions transferValidation = algorithm == StorageChecksumAlgorithm.None - ? default - : new DownloadTransferValidationOptions { ChecksumAlgorithm = algorithm }; + ChecksumAlgorithm = StorageChecksumAlgorithm.StorageCrc64 + }; + AsyncTestDelegate operation = async () => await (await blob.DownloadStreamingAsync( + new BlobDownloadOptions + { + Range = new HttpRange(length: Constants.StructuredMessage.MaxDownloadCrcWithHeader + 1), + TransferValidation = validationOptions, + })).Value.Content.CopyToAsync(Stream.Null); + Assert.That(operation, Throws.TypeOf()); + } + + [Test] + public async Task StructuredMessagePopulatesCrcDownloadStreaming() + { + await using DisposingContainer disposingContainer = await ClientBuilder.GetTestContainerAsync( + publicAccessType: PublicAccessType.None); + + const int dataLength = Constants.KB; + byte[] data = GetRandomBuffer(dataLength); + byte[] dataCrc = new byte[8]; + StorageCrc64Calculator.ComputeSlicedSafe(data, 0L).WriteCrc64(dataCrc); + + var blob = disposingContainer.Container.GetBlobClient(GetNewResourceName()); + await blob.UploadAsync(BinaryData.FromBytes(data)); - // Act - Response response = await blob.DownloadStreamingAsync(new 
BlobDownloadOptions + Response response = await blob.DownloadStreamingAsync(new() { - TransferValidation = transferValidation, - Range = new HttpRange(length: data.Length) + TransferValidation = new DownloadTransferValidationOptions + { + ChecksumAlgorithm = StorageChecksumAlgorithm.StorageCrc64 + } }); - // Assert - // validated stream is buffered - Assert.AreEqual(typeof(MemoryStream), response.Value.Content.GetType()); + // crc is not present until response stream is consumed + Assert.That(response.Value.Details.ContentCrc, Is.Null); + + byte[] downloadedData; + using (MemoryStream ms = new()) + { + await response.Value.Content.CopyToAsync(ms); + downloadedData = ms.ToArray(); + } + + Assert.That(response.Value.Details.ContentCrc, Is.EqualTo(dataCrc)); + Assert.That(downloadedData, Is.EqualTo(data)); } [Test] - public async Task ExpectedDownloadStreamingStreamTypeReturned_None() + public async Task StructuredMessagePopulatesCrcDownloadContent() { - await using var test = await GetDisposingContainerAsync(); + await using DisposingContainer disposingContainer = await ClientBuilder.GetTestContainerAsync( + publicAccessType: PublicAccessType.None); - // Arrange - var data = GetRandomBuffer(Constants.KB); - BlobClient blob = InstrumentClient(test.Container.GetBlobClient(GetNewResourceName())); - using (var stream = new MemoryStream(data)) - { - await blob.UploadAsync(stream); - } + const int dataLength = Constants.KB; + byte[] data = GetRandomBuffer(dataLength); + byte[] dataCrc = new byte[8]; + StorageCrc64Calculator.ComputeSlicedSafe(data, 0L).WriteCrc64(dataCrc); + + var blob = disposingContainer.Container.GetBlobClient(GetNewResourceName()); + await blob.UploadAsync(BinaryData.FromBytes(data)); - // Act - Response response = await blob.DownloadStreamingAsync(new BlobDownloadOptions + Response response = await blob.DownloadContentAsync(new BlobDownloadOptions() { - Range = new HttpRange(length: data.Length) + TransferValidation = new 
DownloadTransferValidationOptions + { + ChecksumAlgorithm = StorageChecksumAlgorithm.StorageCrc64 + } }); - // Assert - // unvalidated stream type is private; just check we didn't get back a buffered stream - Assert.AreNotEqual(typeof(MemoryStream), response.Value.Content.GetType()); + Assert.That(response.Value.Details.ContentCrc, Is.EqualTo(dataCrc)); + Assert.That(response.Value.Content.ToArray(), Is.EqualTo(data)); } #endregion } diff --git a/sdk/storage/Azure.Storage.Blobs/tests/ClientSideEncryptionTests.cs b/sdk/storage/Azure.Storage.Blobs/tests/ClientSideEncryptionTests.cs index 5d391440ea1b6..e85ff3aa5473f 100644 --- a/sdk/storage/Azure.Storage.Blobs/tests/ClientSideEncryptionTests.cs +++ b/sdk/storage/Azure.Storage.Blobs/tests/ClientSideEncryptionTests.cs @@ -1343,7 +1343,7 @@ public void CanParseLargeContentRange() { long compareValue = (long)Int32.MaxValue + 1; //Increase max int32 by one ContentRange contentRange = ContentRange.Parse($"bytes 0 {compareValue} {compareValue}"); - Assert.AreEqual((long)Int32.MaxValue + 1, contentRange.Size); + Assert.AreEqual((long)Int32.MaxValue + 1, contentRange.TotalResourceLength); Assert.AreEqual(0, contentRange.Start); Assert.AreEqual((long)Int32.MaxValue + 1, contentRange.End); } diff --git a/sdk/storage/Azure.Storage.Blobs/tests/PartitionedDownloaderTests.cs b/sdk/storage/Azure.Storage.Blobs/tests/PartitionedDownloaderTests.cs index d8d4756a510c1..af408264c5bfa 100644 --- a/sdk/storage/Azure.Storage.Blobs/tests/PartitionedDownloaderTests.cs +++ b/sdk/storage/Azure.Storage.Blobs/tests/PartitionedDownloaderTests.cs @@ -305,7 +305,7 @@ public Response GetStream(HttpRange range, BlobRequ ContentHash = new byte[] { 1, 2, 3 }, LastModified = DateTimeOffset.Now, Metadata = new Dictionary() { { "meta", "data" } }, - ContentRange = $"bytes {range.Offset}-{range.Offset + contentLength}/{_length}", + ContentRange = $"bytes {range.Offset}-{Math.Max(1, range.Offset + contentLength - 1)}/{_length}", ETag = s_etag, 
ContentEncoding = "test", CacheControl = "test", diff --git a/sdk/storage/Azure.Storage.Common/samples/Azure.Storage.Common.Samples.Tests.csproj b/sdk/storage/Azure.Storage.Common/samples/Azure.Storage.Common.Samples.Tests.csproj index 7d454aeaa0af2..aeca4497a8770 100644 --- a/sdk/storage/Azure.Storage.Common/samples/Azure.Storage.Common.Samples.Tests.csproj +++ b/sdk/storage/Azure.Storage.Common/samples/Azure.Storage.Common.Samples.Tests.csproj @@ -19,6 +19,7 @@ + PreserveNewest diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/ChecksumExtensions.cs b/sdk/storage/Azure.Storage.Common/src/Shared/ChecksumExtensions.cs new file mode 100644 index 0000000000000..48304640eee43 --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/src/Shared/ChecksumExtensions.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers.Binary; + +namespace Azure.Storage; + +internal static class ChecksumExtensions +{ + public static void WriteCrc64(this ulong crc, Span dest) + => BinaryPrimitives.WriteUInt64LittleEndian(dest, crc); + + public static bool TryWriteCrc64(this ulong crc, Span dest) + => BinaryPrimitives.TryWriteUInt64LittleEndian(dest, crc); + + public static ulong ReadCrc64(this ReadOnlySpan crc) + => BinaryPrimitives.ReadUInt64LittleEndian(crc); + + public static bool TryReadCrc64(this ReadOnlySpan crc, out ulong value) + => BinaryPrimitives.TryReadUInt64LittleEndian(crc, out value); +} diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/Constants.cs b/sdk/storage/Azure.Storage.Common/src/Shared/Constants.cs index 3e00882188fba..35d5c1f1fde8c 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/Constants.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/Constants.cs @@ -657,6 +657,15 @@ internal static class AccountResources internal static readonly int[] PathStylePorts = { 10000, 10001, 10002, 10003, 10004, 10100, 10101, 10102, 10103, 10104, 11000, 11001, 
11002, 11003, 11004, 11100, 11101, 11102, 11103, 11104 }; } + internal static class StructuredMessage + { + public const string StructuredMessageHeader = "x-ms-structured-body"; + public const string StructuredContentLength = "x-ms-structured-content-length"; + public const string CrcStructuredMessage = "XSM/1.0; properties=crc64"; + public const int DefaultSegmentContentLength = 4 * MB; + public const int MaxDownloadCrcWithHeader = 4 * MB; + } + internal static class ClientSideEncryption { public const string HttpMessagePropertyKeyV1 = "Azure.Storage.StorageTelemetryPolicy.ClientSideEncryption.V1"; diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/ContentRange.cs b/sdk/storage/Azure.Storage.Common/src/Shared/ContentRange.cs index f656382efad2b..cb3b0a7bee189 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/ContentRange.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/ContentRange.cs @@ -82,20 +82,20 @@ public RangeUnit(string value) public long? End { get; } /// - /// Size of this range, measured in this instance's . + /// Size of the entire resource this range is from, measured in this instance's . /// - public long? Size { get; } + public long? TotalResourceLength { get; } /// /// Unit this range is measured in. Generally "bytes". /// public RangeUnit Unit { get; } - public ContentRange(RangeUnit unit, long? start, long? end, long? size) + public ContentRange(RangeUnit unit, long? start, long? end, long? totalResourceLength) { Start = start; End = end; - Size = size; + TotalResourceLength = totalResourceLength; Unit = unit; } @@ -113,7 +113,7 @@ public static ContentRange Parse(string headerValue) string unit = default; long? start = default; long? end = default; - long? size = default; + long? 
resourceSize = default; try { @@ -136,10 +136,10 @@ public static ContentRange Parse(string headerValue) var rawSize = tokens[blobSizeIndex]; if (rawSize != WildcardMarker) { - size = long.Parse(rawSize, CultureInfo.InvariantCulture); + resourceSize = long.Parse(rawSize, CultureInfo.InvariantCulture); } - return new ContentRange(unit, start, end, size); + return new ContentRange(unit, start, end, resourceSize); } catch (IndexOutOfRangeException) { @@ -165,7 +165,7 @@ public static HttpRange ToHttpRange(ContentRange contentRange) /// /// Indicates whether this instance and a specified are equal /// - public bool Equals(ContentRange other) => (other.Start == Start) && (other.End == End) && (other.Unit == Unit) && (other.Size == Size); + public bool Equals(ContentRange other) => (other.Start == Start) && (other.End == End) && (other.Unit == Unit) && (other.TotalResourceLength == TotalResourceLength); /// /// Determines if two values are the same. @@ -185,6 +185,6 @@ public static HttpRange ToHttpRange(ContentRange contentRange) /// [EditorBrowsable(EditorBrowsableState.Never)] - public override int GetHashCode() => HashCodeBuilder.Combine(Start, End, Size, Unit.GetHashCode()); + public override int GetHashCode() => HashCodeBuilder.Combine(Start, End, TotalResourceLength, Unit.GetHashCode()); } } diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/ContentRangeExtensions.cs b/sdk/storage/Azure.Storage.Common/src/Shared/ContentRangeExtensions.cs new file mode 100644 index 0000000000000..160a69b19a9c8 --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/src/Shared/ContentRangeExtensions.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Azure.Storage.Cryptography; + +internal static class ContentRangeExtensions +{ + public static long? GetContentRangeLengthOrDefault(this string contentRange) + => string.IsNullOrWhiteSpace(contentRange) + ? 
default : ContentRange.Parse(contentRange).GetRangeLength(); + + public static long GetRangeLength(this ContentRange contentRange) + => contentRange.End.Value - contentRange.Start.Value + 1; +} diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/Errors.Clients.cs b/sdk/storage/Azure.Storage.Common/src/Shared/Errors.Clients.cs index 2a5fe38668104..867607e551e6a 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/Errors.Clients.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/Errors.Clients.cs @@ -3,6 +3,7 @@ using System; using System.Globalization; +using System.IO; using System.Linq; using System.Security.Authentication; using System.Xml.Serialization; @@ -105,9 +106,18 @@ public static ArgumentException VersionNotSupported(string paramName) public static RequestFailedException ClientRequestIdMismatch(Response response, string echo, string original) => new RequestFailedException(response.Status, $"Response x-ms-client-request-id '{echo}' does not match the original expected request id, '{original}'.", null); + public static InvalidDataException StructuredMessageNotAcknowledgedGET(Response response) + => new InvalidDataException($"Response does not acknowledge structured message was requested. Unknown data structure in response body."); + + public static InvalidDataException StructuredMessageNotAcknowledgedPUT(Response response) + => new InvalidDataException($"Response does not acknowledge structured message was sent. 
Unexpected data may have been persisted to storage."); + public static ArgumentException TransactionalHashingNotSupportedWithClientSideEncryption() => new ArgumentException("Client-side encryption and transactional hashing are not supported at the same time."); + public static InvalidDataException ExpectedStructuredMessage() + => new InvalidDataException($"Expected {Constants.StructuredMessage.StructuredMessageHeader} in response, but found none."); + public static void VerifyHttpsTokenAuth(Uri uri) { if (uri.Scheme != Constants.Https) diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/Errors.cs b/sdk/storage/Azure.Storage.Common/src/Shared/Errors.cs index 6b89a59011d51..e3372665928c1 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/Errors.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/Errors.cs @@ -72,6 +72,9 @@ public static ArgumentException CannotDeferTransactionalHashVerification() public static ArgumentException CannotInitializeWriteStreamWithData() => new ArgumentException("Initialized buffer for StorageWriteStream must be empty."); + public static InvalidDataException InvalidStructuredMessage(string optionalMessage = default) + => new InvalidDataException(("Invalid structured message data. " + optionalMessage ?? "").Trim()); + internal static void VerifyStreamPosition(Stream stream, string streamName) { if (stream != null && stream.CanSeek && stream.Length > 0 && stream.Position >= stream.Length) @@ -80,6 +83,22 @@ internal static void VerifyStreamPosition(Stream stream, string streamName) } } + internal static void AssertBufferMinimumSize(ReadOnlySpan buffer, int minSize, string paramName) + { + if (buffer.Length < minSize) + { + throw new ArgumentException($"Expected buffer Length of at least {minSize} bytes. 
Got {buffer.Length}.", paramName); + } + } + + internal static void AssertBufferExactSize(ReadOnlySpan buffer, int size, string paramName) + { + if (buffer.Length != size) + { + throw new ArgumentException($"Expected buffer Length of exactly {size} bytes. Got {buffer.Length}.", paramName); + } + } + public static void ThrowIfParamNull(object obj, string paramName) { if (obj == null) diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/LazyLoadingReadOnlyStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/LazyLoadingReadOnlyStream.cs index c3e9c641c3fea..fe2db427bef02 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/LazyLoadingReadOnlyStream.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/LazyLoadingReadOnlyStream.cs @@ -249,41 +249,9 @@ private async Task DownloadInternal(bool async, CancellationToken cancellat response = await _downloadInternalFunc(range, _validationOptions, async, cancellationToken).ConfigureAwait(false); using Stream networkStream = response.Value.Content; - - // The number of bytes we just downloaded. - long downloadSize = GetResponseRange(response.GetRawResponse()).Length.Value; - - // The number of bytes we copied in the last loop. - int copiedBytes; - - // Bytes we have copied so far. - int totalCopiedBytes = 0; - - // Bytes remaining to copy. It is save to truncate the long because we asked for a max of int _buffer size bytes. - int remainingBytes = (int)downloadSize; - - do - { - if (async) - { - copiedBytes = await networkStream.ReadAsync( - buffer: _buffer, - offset: totalCopiedBytes, - count: remainingBytes, - cancellationToken: cancellationToken).ConfigureAwait(false); - } - else - { - copiedBytes = networkStream.Read( - buffer: _buffer, - offset: totalCopiedBytes, - count: remainingBytes); - } - - totalCopiedBytes += copiedBytes; - remainingBytes -= copiedBytes; - } - while (copiedBytes != 0); + // use stream copy to ensure consumption of any trailing metadata (e.g. 
structured message) + // allow buffer limits to catch the error of data size mismatch + int totalCopiedBytes = (int) await networkStream.CopyToInternal(new MemoryStream(_buffer), async, cancellationToken).ConfigureAwait((false)); _bufferPosition = 0; _bufferLength = totalCopiedBytes; @@ -291,7 +259,7 @@ private async Task DownloadInternal(bool async, CancellationToken cancellat // if we deferred transactional hash validation on download, validate now // currently we always defer but that may change - if (_validationOptions != default && _validationOptions.ChecksumAlgorithm != StorageChecksumAlgorithm.None && !_validationOptions.AutoValidateChecksum) + if (_validationOptions != default && _validationOptions.ChecksumAlgorithm == StorageChecksumAlgorithm.MD5 && !_validationOptions.AutoValidateChecksum) // TODO better condition { ContentHasher.AssertResponseHashMatch(_buffer, _bufferPosition, _bufferLength, _validationOptions.ChecksumAlgorithm, response.GetRawResponse()); } diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/PooledMemoryStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/PooledMemoryStream.cs index 3e218d18a90af..6070329d10d3d 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/PooledMemoryStream.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/PooledMemoryStream.cs @@ -251,7 +251,7 @@ public override int Read(byte[] buffer, int offset, int count) Length - Position, bufferCount - (Position - offsetOfBuffer), count - read); - Array.Copy(currentBuffer, Position - offsetOfBuffer, buffer, read, toCopy); + Array.Copy(currentBuffer, Position - offsetOfBuffer, buffer, offset + read, toCopy); read += toCopy; Position += toCopy; } diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StorageCrc64Composer.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StorageCrc64Composer.cs index ab6b76d78a87e..307ff23b21144 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StorageCrc64Composer.cs +++ 
b/sdk/storage/Azure.Storage.Common/src/Shared/StorageCrc64Composer.cs @@ -12,22 +12,52 @@ namespace Azure.Storage /// internal static class StorageCrc64Composer { - public static Memory Compose(params (byte[] Crc64, long OriginalDataLength)[] partitions) + public static byte[] Compose(params (byte[] Crc64, long OriginalDataLength)[] partitions) + => Compose(partitions.AsEnumerable()); + + public static byte[] Compose(IEnumerable<(byte[] Crc64, long OriginalDataLength)> partitions) { - return Compose(partitions.AsEnumerable()); + ulong result = Compose(partitions.Select(tup => (BitConverter.ToUInt64(tup.Crc64, 0), tup.OriginalDataLength))); + return BitConverter.GetBytes(result); } - public static Memory Compose(IEnumerable<(byte[] Crc64, long OriginalDataLength)> partitions) + public static byte[] Compose(params (ReadOnlyMemory Crc64, long OriginalDataLength)[] partitions) + => Compose(partitions.AsEnumerable()); + + public static byte[] Compose(IEnumerable<(ReadOnlyMemory Crc64, long OriginalDataLength)> partitions) { - ulong result = Compose(partitions.Select(tup => (BitConverter.ToUInt64(tup.Crc64, 0), tup.OriginalDataLength))); - return new Memory(BitConverter.GetBytes(result)); +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + ulong result = Compose(partitions.Select(tup => (BitConverter.ToUInt64(tup.Crc64.Span), tup.OriginalDataLength))); +#else + ulong result = Compose(partitions.Select(tup => (System.BitConverter.ToUInt64(tup.Crc64.ToArray(), 0), tup.OriginalDataLength))); +#endif + return BitConverter.GetBytes(result); } + public static byte[] Compose( + ReadOnlySpan leftCrc64, long leftOriginalDataLength, + ReadOnlySpan rightCrc64, long rightOriginalDataLength) + { +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + ulong result = Compose( + (BitConverter.ToUInt64(leftCrc64), leftOriginalDataLength), + (BitConverter.ToUInt64(rightCrc64), rightOriginalDataLength)); +#else + ulong result = Compose( + 
(BitConverter.ToUInt64(leftCrc64.ToArray(), 0), leftOriginalDataLength), + (BitConverter.ToUInt64(rightCrc64.ToArray(), 0), rightOriginalDataLength)); +#endif + return BitConverter.GetBytes(result); + } + + public static ulong Compose(params (ulong Crc64, long OriginalDataLength)[] partitions) + => Compose(partitions.AsEnumerable()); + public static ulong Compose(IEnumerable<(ulong Crc64, long OriginalDataLength)> partitions) { ulong composedCrc = 0; long composedDataLength = 0; - foreach (var tup in partitions) + foreach ((ulong crc64, long originalDataLength) in partitions) { composedCrc = StorageCrc64Calculator.Concatenate( uInitialCrcAB: 0, @@ -35,9 +65,9 @@ public static ulong Compose(IEnumerable<(ulong Crc64, long OriginalDataLength)> uFinalCrcA: composedCrc, uSizeA: (ulong) composedDataLength, uInitialCrcB: 0, - uFinalCrcB: tup.Crc64, - uSizeB: (ulong)tup.OriginalDataLength); - composedDataLength += tup.OriginalDataLength; + uFinalCrcB: crc64, + uSizeB: (ulong)originalDataLength); + composedDataLength += originalDataLength; } return composedCrc; } diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StorageRequestValidationPipelinePolicy.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StorageRequestValidationPipelinePolicy.cs index 0cef4f4d8d4ed..9f4ddb5249e82 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StorageRequestValidationPipelinePolicy.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StorageRequestValidationPipelinePolicy.cs @@ -33,6 +33,35 @@ public override void OnReceivedResponse(HttpMessage message) { throw Errors.ClientRequestIdMismatch(message.Response, echo.First(), original); } + + if (message.Request.Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader) && + message.Request.Headers.Contains(Constants.StructuredMessage.StructuredContentLength)) + { + AssertStructuredMessageAcknowledgedPUT(message); + } + else if (message.Request.Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)) + { + 
AssertStructuredMessageAcknowledgedGET(message); + } + } + + private static void AssertStructuredMessageAcknowledgedPUT(HttpMessage message) + { + if (!message.Response.IsError && + !message.Response.Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)) + { + throw Errors.StructuredMessageNotAcknowledgedPUT(message.Response); + } + } + + private static void AssertStructuredMessageAcknowledgedGET(HttpMessage message) + { + if (!message.Response.IsError && + !(message.Response.Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader) && + message.Response.Headers.Contains(Constants.StructuredMessage.StructuredContentLength))) + { + throw Errors.StructuredMessageNotAcknowledgedGET(message.Response); + } } } } diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StorageVersionExtensions.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StorageVersionExtensions.cs index 2a7bd90fb82a1..44c0973ea9be1 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StorageVersionExtensions.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StorageVersionExtensions.cs @@ -46,7 +46,7 @@ internal static class StorageVersionExtensions /// public const ServiceVersion LatestVersion = #if BlobSDK || QueueSDK || FileSDK || DataLakeSDK || ChangeFeedSDK || DataMovementSDK || BlobDataMovementSDK || ShareDataMovementSDK - ServiceVersion.V2024_11_04; + ServiceVersion.V2025_01_05; #else ERROR_STORAGE_SERVICE_NOT_DEFINED; #endif diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StreamExtensions.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StreamExtensions.cs index 31f121d414ea4..c8803ecf421e7 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StreamExtensions.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StreamExtensions.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+using System; +using System.Buffers; using System.IO; using System.Threading; using System.Threading.Tasks; @@ -48,7 +50,7 @@ public static async Task WriteInternal( } } - public static Task CopyToInternal( + public static Task CopyToInternal( this Stream src, Stream dest, bool async, @@ -79,21 +81,33 @@ public static Task CopyToInternal( /// Cancellation token for the operation. /// /// - public static async Task CopyToInternal( + public static async Task CopyToInternal( this Stream src, Stream dest, int bufferSize, bool async, CancellationToken cancellationToken) { + using IDisposable _ = ArrayPool.Shared.RentDisposable(bufferSize, out byte[] buffer); + long totalRead = 0; + int read; if (async) { - await src.CopyToAsync(dest, bufferSize, cancellationToken).ConfigureAwait(false); + while (0 < (read = await src.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false))) + { + totalRead += read; + await dest.WriteAsync(buffer, 0, read, cancellationToken).ConfigureAwait(false); + } } else { - src.CopyTo(dest, bufferSize); + while (0 < (read = src.Read(buffer, 0, buffer.Length))) + { + totalRead += read; + dest.Write(buffer, 0, read); + } } + return totalRead; } } } diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessage.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessage.cs new file mode 100644 index 0000000000000..a0a46837797b9 --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessage.cs @@ -0,0 +1,244 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +using System; +using System.Buffers; +using System.Buffers.Binary; +using System.IO; +using Azure.Storage.Common; + +namespace Azure.Storage.Shared; + +internal static class StructuredMessage +{ + public const int Crc64Length = 8; + + [Flags] + public enum Flags + { + None = 0, + StorageCrc64 = 1, + } + + public static class V1_0 + { + public const byte MessageVersionByte = 1; + + public const int StreamHeaderLength = 13; + public const int StreamHeaderVersionOffset = 0; + public const int StreamHeaderMessageLengthOffset = 1; + public const int StreamHeaderFlagsOffset = 9; + public const int StreamHeaderSegmentCountOffset = 11; + + public const int SegmentHeaderLength = 10; + public const int SegmentHeaderNumOffset = 0; + public const int SegmentHeaderContentLengthOffset = 2; + + #region Stream Header + public static void ReadStreamHeader( + ReadOnlySpan buffer, + out long messageLength, + out Flags flags, + out int totalSegments) + { + Errors.AssertBufferExactSize(buffer, 13, nameof(buffer)); + if (buffer[StreamHeaderVersionOffset] != 1) + { + throw new InvalidDataException("Unrecognized version of structured message."); + } + messageLength = (long)BinaryPrimitives.ReadUInt64LittleEndian(buffer.Slice(StreamHeaderMessageLengthOffset, 8)); + flags = (Flags)BinaryPrimitives.ReadUInt16LittleEndian(buffer.Slice(StreamHeaderFlagsOffset, 2)); + totalSegments = BinaryPrimitives.ReadUInt16LittleEndian(buffer.Slice(StreamHeaderSegmentCountOffset, 2)); + } + + public static int WriteStreamHeader( + Span buffer, + long messageLength, + Flags flags, + int totalSegments) + { + const int versionOffset = 0; + const int messageLengthOffset = 1; + const int flagsOffset = 9; + const int numSegmentsOffset = 11; + + Errors.AssertBufferMinimumSize(buffer, StreamHeaderLength, nameof(buffer)); + + buffer[versionOffset] = MessageVersionByte; + BinaryPrimitives.WriteUInt64LittleEndian(buffer.Slice(messageLengthOffset, 8), (ulong)messageLength); + 
BinaryPrimitives.WriteUInt16LittleEndian(buffer.Slice(flagsOffset, 2), (ushort)flags); + BinaryPrimitives.WriteUInt16LittleEndian(buffer.Slice(numSegmentsOffset, 2), (ushort)totalSegments); + + return StreamHeaderLength; + } + + /// + /// Gets stream header in a buffer rented from the provided ArrayPool. + /// + /// + /// Disposable to return the buffer to the pool. + /// + public static IDisposable GetStreamHeaderBytes( + ArrayPool pool, + out Memory bytes, + long messageLength, + Flags flags, + int totalSegments) + { + Argument.AssertNotNull(pool, nameof(pool)); + IDisposable disposable = pool.RentAsMemoryDisposable(StreamHeaderLength, out bytes); + WriteStreamHeader(bytes.Span, messageLength, flags, totalSegments); + return disposable; + } + #endregion + + #region StreamFooter + public static int GetStreamFooterSize(Flags flags) + => flags.HasFlag(Flags.StorageCrc64) ? Crc64Length : 0; + + public static void ReadStreamFooter( + ReadOnlySpan buffer, + Flags flags, + out ulong crc64) + { + int expectedBufferSize = GetSegmentFooterSize(flags); + Errors.AssertBufferExactSize(buffer, expectedBufferSize, nameof(buffer)); + + crc64 = flags.HasFlag(Flags.StorageCrc64) ? buffer.ReadCrc64() : default; + } + + public static int WriteStreamFooter(Span buffer, ReadOnlySpan crc64 = default) + { + int requiredSpace = 0; + if (!crc64.IsEmpty) + { + Errors.AssertBufferExactSize(crc64, Crc64Length, nameof(crc64)); + requiredSpace += Crc64Length; + } + + Errors.AssertBufferMinimumSize(buffer, requiredSpace, nameof(buffer)); + int offset = 0; + if (!crc64.IsEmpty) + { + crc64.CopyTo(buffer.Slice(offset, Crc64Length)); + offset += Crc64Length; + } + + return offset; + } + + /// + /// Gets stream header in a buffer rented from the provided ArrayPool. + /// + /// + /// Disposable to return the buffer to the pool. 
+ /// + public static IDisposable GetStreamFooterBytes( + ArrayPool pool, + out Memory bytes, + ReadOnlySpan crc64 = default) + { + Argument.AssertNotNull(pool, nameof(pool)); + IDisposable disposable = pool.RentAsMemoryDisposable(StreamHeaderLength, out bytes); + WriteStreamFooter(bytes.Span, crc64); + return disposable; + } + #endregion + + #region SegmentHeader + public static void ReadSegmentHeader( + ReadOnlySpan buffer, + out int segmentNum, + out long contentLength) + { + Errors.AssertBufferExactSize(buffer, 10, nameof(buffer)); + segmentNum = BinaryPrimitives.ReadUInt16LittleEndian(buffer.Slice(0, 2)); + contentLength = (long)BinaryPrimitives.ReadUInt64LittleEndian(buffer.Slice(2, 8)); + } + + public static int WriteSegmentHeader(Span buffer, int segmentNum, long segmentLength) + { + const int segmentNumOffset = 0; + const int segmentLengthOffset = 2; + + Errors.AssertBufferMinimumSize(buffer, SegmentHeaderLength, nameof(buffer)); + + BinaryPrimitives.WriteUInt16LittleEndian(buffer.Slice(segmentNumOffset, 2), (ushort)segmentNum); + BinaryPrimitives.WriteUInt64LittleEndian(buffer.Slice(segmentLengthOffset, 8), (ulong)segmentLength); + + return SegmentHeaderLength; + } + + /// + /// Gets segment header in a buffer rented from the provided ArrayPool. + /// + /// + /// Disposable to return the buffer to the pool. + /// + public static IDisposable GetSegmentHeaderBytes( + ArrayPool pool, + out Memory bytes, + int segmentNum, + long segmentLength) + { + Argument.AssertNotNull(pool, nameof(pool)); + IDisposable disposable = pool.RentAsMemoryDisposable(SegmentHeaderLength, out bytes); + WriteSegmentHeader(bytes.Span, segmentNum, segmentLength); + return disposable; + } + #endregion + + #region SegmentFooter + public static int GetSegmentFooterSize(Flags flags) + => flags.HasFlag(Flags.StorageCrc64) ? 
Crc64Length : 0; + + public static void ReadSegmentFooter( + ReadOnlySpan buffer, + Flags flags, + out ulong crc64) + { + int expectedBufferSize = GetSegmentFooterSize(flags); + Errors.AssertBufferExactSize(buffer, expectedBufferSize, nameof(buffer)); + + crc64 = flags.HasFlag(Flags.StorageCrc64) ? buffer.ReadCrc64() : default; + } + + public static int WriteSegmentFooter(Span buffer, ReadOnlySpan crc64 = default) + { + int requiredSpace = 0; + if (!crc64.IsEmpty) + { + Errors.AssertBufferExactSize(crc64, Crc64Length, nameof(crc64)); + requiredSpace += Crc64Length; + } + + Errors.AssertBufferMinimumSize(buffer, requiredSpace, nameof(buffer)); + int offset = 0; + if (!crc64.IsEmpty) + { + crc64.CopyTo(buffer.Slice(offset, Crc64Length)); + offset += Crc64Length; + } + + return offset; + } + + /// + /// Gets stream header in a buffer rented from the provided ArrayPool. + /// + /// + /// Disposable to return the buffer to the pool. + /// + public static IDisposable GetSegmentFooterBytes( + ArrayPool pool, + out Memory bytes, + ReadOnlySpan crc64 = default) + { + Argument.AssertNotNull(pool, nameof(pool)); + IDisposable disposable = pool.RentAsMemoryDisposable(StreamHeaderLength, out bytes); + WriteSegmentFooter(bytes.Span, crc64); + return disposable; + } + #endregion + } +} diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs new file mode 100644 index 0000000000000..22dfaef259972 --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs @@ -0,0 +1,264 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
/// <summary>
/// Wraps a structured-message decoding stream in <see cref="RetriableStream"/> so that transport
/// faults mid-download can be retried. On retry, a fresh decoding stream is created at the raw
/// (encoded) offset corresponding to the decoded bytes already handed to the caller, and the
/// replacement stream is fast-forwarded past any decoded bytes the caller has already seen.
/// Accumulates a CRC64 over all decoded content (when the message flags request it) and validates
/// it against the service-reported CRC once the stream is fully consumed.
/// </summary>
internal class StructuredMessageDecodingRetriableStream : Stream
{
    /// <summary>
    /// Final data computed across the total decoded content, reported via the onComplete callback.
    /// </summary>
    public class DecodedData
    {
        public ulong Crc { get; set; }
    }

    private readonly Stream _innerRetriable;

    // Count of decoded bytes returned to the caller so far; used to fast-forward replacement
    // streams after a retry. Must therefore only count bytes actually delivered.
    private long _decodedBytesRead;

    private readonly StructuredMessage.Flags _expectedFlags;

    // One RawDecodedData per decoding stream instance (initial + one per retry). Segment CRCs
    // across all instances are composed to produce the reported total CRC.
    private readonly List<StructuredMessageDecodingStream.RawDecodedData> _decodedDatas;
    private readonly Action<DecodedData> _onComplete;

    // Non-null only when the message carries StorageCrc64; accumulates over decoded content.
    private StorageCrc64HashAlgorithm _totalContentCrc;

    private readonly Func<long, (Stream DecodingStream, StructuredMessageDecodingStream.RawDecodedData DecodedData)> _decodingStreamFactory;
    private readonly Func<long, ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.RawDecodedData DecodedData)>> _decodingAsyncStreamFactory;

    public StructuredMessageDecodingRetriableStream(
        Stream initialDecodingStream,
        StructuredMessageDecodingStream.RawDecodedData initialDecodedData,
        StructuredMessage.Flags expectedFlags,
        Func<long, (Stream DecodingStream, StructuredMessageDecodingStream.RawDecodedData DecodedData)> decodingStreamFactory,
        Func<long, ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.RawDecodedData DecodedData)>> decodingAsyncStreamFactory,
        Action<DecodedData> onComplete,
        ResponseClassifier responseClassifier,
        int maxRetries)
    {
        _decodingStreamFactory = decodingStreamFactory;
        _decodingAsyncStreamFactory = decodingAsyncStreamFactory;
        _innerRetriable = RetriableStream.Create(initialDecodingStream, StreamFactory, StreamFactoryAsync, responseClassifier, maxRetries);
        _decodedDatas = new() { initialDecodedData };
        _expectedFlags = expectedFlags;
        _onComplete = onComplete;

        if (expectedFlags.HasFlag(StructuredMessage.Flags.StorageCrc64))
        {
            _totalContentCrc = StorageCrc64HashAlgorithm.Create();
        }
    }

    private Stream StreamFactory(long _)
    {
        // Resume at the decoded offset covered by fully-validated segments; fast-forward the
        // remainder so the caller never sees duplicated bytes.
        long offset = _decodedDatas.SelectMany(d => d.SegmentCrcs).Select(s => s.SegmentLen).Sum();
        (Stream decodingStream, StructuredMessageDecodingStream.RawDecodedData decodedData) = _decodingStreamFactory(offset);
        _decodedDatas.Add(decodedData);
        FastForwardInternal(decodingStream, _decodedBytesRead - offset, false).EnsureCompleted();
        return decodingStream;
    }

    private async ValueTask<Stream> StreamFactoryAsync(long _)
    {
        long offset = _decodedDatas.SelectMany(d => d.SegmentCrcs).Select(s => s.SegmentLen).Sum();
        (Stream decodingStream, StructuredMessageDecodingStream.RawDecodedData decodedData) = await _decodingAsyncStreamFactory(offset).ConfigureAwait(false);
        _decodedDatas.Add(decodedData);
        await FastForwardInternal(decodingStream, _decodedBytesRead - offset, true).ConfigureAwait(false);
        return decodingStream;
    }

    /// <summary>
    /// Reads and discards <paramref name="bytes"/> bytes from <paramref name="stream"/>.
    /// </summary>
    /// <exception cref="EndOfStreamException">
    /// The stream ended before the requested bytes could be skipped.
    /// </exception>
    private static async ValueTask FastForwardInternal(Stream stream, long bytes, bool async)
    {
        using (ArrayPool<byte>.Shared.RentDisposable(4 * Constants.KB, out byte[] buffer))
        {
            while (bytes > 0)
            {
                int read = async
                    ? await stream.ReadAsync(buffer, 0, (int)Math.Min(bytes, buffer.Length)).ConfigureAwait(false)
                    : stream.Read(buffer, 0, (int)Math.Min(bytes, buffer.Length));
                if (read == 0)
                {
                    // FIX: previously a zero-length read left `bytes` unchanged, spinning forever
                    // on a prematurely-ended replacement stream. Fail fast instead.
                    throw new EndOfStreamException("Unexpected end of stream while fast forwarding.");
                }
                bytes -= read;
            }
        }
    }

    protected override void Dispose(bool disposing)
    {
        _decodedDatas.Clear();
        _innerRetriable.Dispose();
        base.Dispose(disposing);
    }

    private void OnCompleted()
    {
        DecodedData final = new();
        if (_totalContentCrc != null)
        {
            final.Crc = ValidateCrc();
        }
        _onComplete?.Invoke(final);
    }

    /// <summary>
    /// Compares the locally computed CRC of decoded content against the service-reported value.
    /// With no retries the single stream's total CRC is used directly; otherwise the reported
    /// value is composed from per-segment CRCs across all stream instances.
    /// </summary>
    private ulong ValidateCrc()
    {
        using IDisposable _ = ArrayPool<byte>.Shared.RentDisposable(StructuredMessage.Crc64Length * 2, out byte[] buf);
        Span<byte> calculatedBytes = new(buf, 0, StructuredMessage.Crc64Length);
        _totalContentCrc.GetCurrentHash(calculatedBytes);
        ulong calculated = BinaryPrimitives.ReadUInt64LittleEndian(calculatedBytes);

        ulong reported = _decodedDatas.Count == 1
            ? _decodedDatas.First().TotalCrc.Value
            : StorageCrc64Composer.Compose(_decodedDatas.SelectMany(d => d.SegmentCrcs));

        if (calculated != reported)
        {
            Span<byte> reportedBytes = new(buf, calculatedBytes.Length, StructuredMessage.Crc64Length);
            BinaryPrimitives.WriteUInt64LittleEndian(reportedBytes, reported);
            throw Errors.ChecksumMismatch(calculatedBytes, reportedBytes);
        }

        return calculated;
    }

    #region Read
    public override int Read(byte[] buffer, int offset, int count)
    {
        int read = _innerRetriable.Read(buffer, offset, count);
        _decodedBytesRead += read;
        if (read == 0)
        {
            OnCompleted();
        }
        else
        {
            _totalContentCrc?.Append(new ReadOnlySpan<byte>(buffer, offset, read));
        }
        return read;
    }

    public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
    {
        int read = await _innerRetriable.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
        _decodedBytesRead += read;
        if (read == 0)
        {
            OnCompleted();
        }
        else
        {
            _totalContentCrc?.Append(new ReadOnlySpan<byte>(buffer, offset, read));
        }
        return read;
    }

#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER
    public override int Read(Span<byte> buffer)
    {
        int read = _innerRetriable.Read(buffer);
        _decodedBytesRead += read;
        if (read == 0)
        {
            OnCompleted();
        }
        else
        {
            _totalContentCrc?.Append(buffer.Slice(0, read));
        }
        return read;
    }

    public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
    {
        int read = await _innerRetriable.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
        _decodedBytesRead += read;
        if (read == 0)
        {
            OnCompleted();
        }
        else
        {
            _totalContentCrc?.Append(buffer.Span.Slice(0, read));
        }
        return read;
    }
#endif

    public override int ReadByte()
    {
        int val = _innerRetriable.ReadByte();
        if (val == -1)
        {
            OnCompleted();
        }
        else
        {
            // FIX: only count a byte as delivered when one was actually read; the previous
            // unconditional increment skewed the retry fast-forward offset by one per EOF probe.
            _decodedBytesRead += 1;
            // FIX: fold the byte into the running CRC like every other read path, otherwise any
            // ReadByte consumption guaranteed a checksum mismatch on completion.
            Span<byte> single = stackalloc byte[1];
            single[0] = (byte)val;
            _totalContentCrc?.Append(single);
        }
        return val;
    }

    public override int EndRead(IAsyncResult asyncResult)
    {
        // NOTE(review): bytes completed via BeginRead/EndRead are counted but cannot be appended
        // to _totalContentCrc here (the destination buffer is not available) — confirm callers
        // using APM reads do not also request CRC validation.
        int read = _innerRetriable.EndRead(asyncResult);
        _decodedBytesRead += read;
        if (read == 0)
        {
            OnCompleted();
        }
        return read;
    }
    #endregion

    #region Passthru
    public override bool CanRead => _innerRetriable.CanRead;

    public override bool CanSeek => _innerRetriable.CanSeek;

    public override bool CanWrite => _innerRetriable.CanWrite;

    public override bool CanTimeout => _innerRetriable.CanTimeout;

    public override long Length => _innerRetriable.Length;

    public override long Position { get => _innerRetriable.Position; set => _innerRetriable.Position = value; }

    public override void Flush() => _innerRetriable.Flush();

    public override Task FlushAsync(CancellationToken cancellationToken) => _innerRetriable.FlushAsync(cancellationToken);

    public override long Seek(long offset, SeekOrigin origin) => _innerRetriable.Seek(offset, origin);

    public override void SetLength(long value) => _innerRetriable.SetLength(value);

    public override void Write(byte[] buffer, int offset, int count) => _innerRetriable.Write(buffer, offset, count);

    public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => _innerRetriable.WriteAsync(buffer, offset, count, cancellationToken);

    public override void WriteByte(byte value) => _innerRetriable.WriteByte(value);

    public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) => _innerRetriable.BeginWrite(buffer, offset, count, callback, state);

    public override void EndWrite(IAsyncResult asyncResult) => _innerRetriable.EndWrite(asyncResult);

    public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) => _innerRetriable.BeginRead(buffer, offset, count, callback, state);

    public override int ReadTimeout { get => _innerRetriable.ReadTimeout; set => _innerRetriable.ReadTimeout = value; }

    public override int WriteTimeout { get => _innerRetriable.WriteTimeout; set => _innerRetriable.WriteTimeout = value; }

#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER
    public override void Write(ReadOnlySpan<byte> buffer) => _innerRetriable.Write(buffer);

    public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default) => _innerRetriable.WriteAsync(buffer, cancellationToken);
#endif
    #endregion
}
+/// +/// +/// Else: +/// Same as #1, but also the already-allocated stream buffer has been used to slightly improve +/// resource churn when reading inner stream. +/// +/// +/// +internal class StructuredMessageDecodingStream : Stream +{ + internal class RawDecodedData + { + public long? InnerStreamLength { get; set; } + public int? TotalSegments { get; set; } + public StructuredMessage.Flags? Flags { get; set; } + public List<(ulong SegmentCrc, long SegmentLen)> SegmentCrcs { get; } = new(); + public ulong? TotalCrc { get; set; } + public bool DecodeCompleted { get; set; } + } + + private enum SMRegion + { + StreamHeader, + StreamFooter, + SegmentHeader, + SegmentFooter, + SegmentContent, + } + + private readonly Stream _innerBufferedStream; + + private byte[] _metadataBuffer = ArrayPool.Shared.Rent(Constants.KB); + private int _metadataBufferOffset = 0; + private int _metadataBufferLength = 0; + + private int _streamHeaderLength; + private int _streamFooterLength; + private int _segmentHeaderLength; + private int _segmentFooterLength; + + private long? _expectedInnerStreamLength; + + private bool _disposed; + + private readonly RawDecodedData _decodedData; + private StorageCrc64HashAlgorithm _totalContentCrc; + private StorageCrc64HashAlgorithm _segmentCrc; + + private readonly bool _validateChecksums; + + public override bool CanRead => true; + + public override bool CanWrite => false; + + public override bool CanSeek => false; + + public override bool CanTimeout => _innerBufferedStream.CanTimeout; + + public override int ReadTimeout => _innerBufferedStream.ReadTimeout; + + public override int WriteTimeout => _innerBufferedStream.WriteTimeout; + + public override long Length => throw new NotSupportedException(); + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public static (Stream DecodedStream, RawDecodedData DecodedData) WrapStream( + Stream innerStream, + long? 
expextedStreamLength = default) + { + RawDecodedData data = new(); + return (new StructuredMessageDecodingStream(innerStream, data, expextedStreamLength), data); + } + + private StructuredMessageDecodingStream( + Stream innerStream, + RawDecodedData decodedData, + long? expectedStreamLength) + { + Argument.AssertNotNull(innerStream, nameof(innerStream)); + Argument.AssertNotNull(decodedData, nameof(decodedData)); + + _expectedInnerStreamLength = expectedStreamLength; + _innerBufferedStream = new BufferedStream(innerStream); + _decodedData = decodedData; + + // Assumes stream will be structured message 1.0. Will validate this when consuming stream. + _streamHeaderLength = StructuredMessage.V1_0.StreamHeaderLength; + _segmentHeaderLength = StructuredMessage.V1_0.SegmentHeaderLength; + + _validateChecksums = true; + } + + #region Write + public override void Flush() => throw new NotSupportedException(); + + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + + public override void SetLength(long value) => throw new NotSupportedException(); + #endregion + + #region Read + public override int Read(byte[] buf, int offset, int count) + { + int decodedRead; + int read; + do + { + read = _innerBufferedStream.Read(buf, offset, count); + _innerStreamConsumed += read; + decodedRead = Decode(new Span(buf, offset, read)); + } while (decodedRead <= 0 && read > 0); + + if (read <= 0) + { + AssertDecodeFinished(); + } + + return decodedRead; + } + + public override async Task ReadAsync(byte[] buf, int offset, int count, CancellationToken cancellationToken) + { + int decodedRead; + int read; + do + { + read = await _innerBufferedStream.ReadAsync(buf, offset, count, cancellationToken).ConfigureAwait(false); + _innerStreamConsumed += read; + decodedRead = Decode(new Span(buf, offset, read)); + } while (decodedRead <= 0 && read > 0); + + if (read <= 0) + { + AssertDecodeFinished(); + } + + return decodedRead; + } + +#if 
NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + public override int Read(Span buf) + { + int decodedRead; + int read; + do + { + read = _innerBufferedStream.Read(buf); + _innerStreamConsumed += read; + decodedRead = Decode(buf.Slice(0, read)); + } while (decodedRead <= 0 && read > 0); + + if (read <= 0) + { + AssertDecodeFinished(); + } + + return decodedRead; + } + + public override async ValueTask ReadAsync(Memory buf, CancellationToken cancellationToken = default) + { + int decodedRead; + int read; + do + { + read = await _innerBufferedStream.ReadAsync(buf).ConfigureAwait(false); + _innerStreamConsumed += read; + decodedRead = Decode(buf.Slice(0, read).Span); + } while (decodedRead <= 0 && read > 0); + + if (read <= 0) + { + AssertDecodeFinished(); + } + + return decodedRead; + } +#endif + + private void AssertDecodeFinished() + { + if (_streamFooterLength > 0 && !_decodedData.DecodeCompleted) + { + throw Errors.InvalidStructuredMessage("Premature end of stream."); + } + _decodedData.DecodeCompleted = true; + } + + private long _innerStreamConsumed = 0; + private long _decodedContentConsumed = 0; + private SMRegion _currentRegion = SMRegion.StreamHeader; + private int _currentSegmentNum = 0; + private long _currentSegmentContentLength; + private long _currentSegmentContentRemaining; + private long CurrentRegionLength => _currentRegion switch + { + SMRegion.StreamHeader => _streamHeaderLength, + SMRegion.StreamFooter => _streamFooterLength, + SMRegion.SegmentHeader => _segmentHeaderLength, + SMRegion.SegmentFooter => _segmentFooterLength, + SMRegion.SegmentContent => _currentSegmentContentLength, + _ => 0, + }; + + /// + /// Decodes given bytes in place. Decoding based on internal stream position info. + /// Decoded data size will be less than or equal to encoded data length. + /// + /// + /// Length of the decoded data in . 
+ /// + private int Decode(Span buffer) + { + if (buffer.IsEmpty) + { + return 0; + } + List<(int Offset, int Count)> gaps = new(); + + int bufferConsumed = ProcessMetadataBuffer(buffer); + + if (bufferConsumed > 0) + { + gaps.Add((0, bufferConsumed)); + } + + while (bufferConsumed < buffer.Length) + { + if (_currentRegion == SMRegion.SegmentContent) + { + int read = (int)Math.Min(buffer.Length - bufferConsumed, _currentSegmentContentRemaining); + _totalContentCrc?.Append(buffer.Slice(bufferConsumed, read)); + _segmentCrc?.Append(buffer.Slice(bufferConsumed, read)); + bufferConsumed += read; + _decodedContentConsumed += read; + _currentSegmentContentRemaining -= read; + if (_currentSegmentContentRemaining == 0) + { + _currentRegion = SMRegion.SegmentFooter; + } + } + else if (buffer.Length - bufferConsumed < CurrentRegionLength) + { + SavePartialMetadata(buffer.Slice(bufferConsumed)); + gaps.Add((bufferConsumed, buffer.Length - bufferConsumed)); + bufferConsumed = buffer.Length; + } + else + { + int processed = _currentRegion switch + { + SMRegion.StreamHeader => ProcessStreamHeader(buffer.Slice(bufferConsumed)), + SMRegion.StreamFooter => ProcessStreamFooter(buffer.Slice(bufferConsumed)), + SMRegion.SegmentHeader => ProcessSegmentHeader(buffer.Slice(bufferConsumed)), + SMRegion.SegmentFooter => ProcessSegmentFooter(buffer.Slice(bufferConsumed)), + _ => 0, + }; + // TODO surface error if processed is 0 + gaps.Add((bufferConsumed, processed)); + bufferConsumed += processed; + } + } + + if (gaps.Count == 0) + { + return buffer.Length; + } + + // gaps is already sorted by offset due to how it was assembled + int gap = 0; + for (int i = gaps.First().Offset; i < buffer.Length; i++) + { + if (gaps.Count > 0 && gaps.First().Offset == i) + { + int count = gaps.First().Count; + gap += count; + i += count - 1; + gaps.RemoveAt(0); + } + else + { + buffer[i - gap] = buffer[i]; + } + } + return buffer.Length - gap; + } + + /// + /// Processes metadata in the internal buffer, if 
any. Appends any necessary data + /// from the append buffer to complete metadata. + /// + /// + /// Bytes consumed from . + /// + private int ProcessMetadataBuffer(ReadOnlySpan append) + { + if (_metadataBufferLength == 0) + { + return 0; + } + if (_currentRegion == SMRegion.SegmentContent) + { + return 0; + } + int appended = 0; + if (_metadataBufferLength < CurrentRegionLength && append.Length > 0) + { + appended = Math.Min((int)CurrentRegionLength - _metadataBufferLength, append.Length); + SavePartialMetadata(append.Slice(0, appended)); + } + if (_metadataBufferLength == CurrentRegionLength) + { + Span metadata = new(_metadataBuffer, _metadataBufferOffset, (int)CurrentRegionLength); + switch (_currentRegion) + { + case SMRegion.StreamHeader: + ProcessStreamHeader(metadata); + break; + case SMRegion.StreamFooter: + ProcessStreamFooter(metadata); + break; + case SMRegion.SegmentHeader: + ProcessSegmentHeader(metadata); + break; + case SMRegion.SegmentFooter: + ProcessSegmentFooter(metadata); + break; + } + _metadataBufferOffset = 0; + _metadataBufferLength = 0; + } + return appended; + } + + private void SavePartialMetadata(ReadOnlySpan span) + { + // safety array resize w/ArrayPool + if (_metadataBufferLength + span.Length > _metadataBuffer.Length) + { + ResizeMetadataBuffer(2 * (_metadataBufferLength + span.Length)); + } + + // realign any existing content if necessary + if (_metadataBufferLength != 0 && _metadataBufferOffset != 0) + { + // don't use Array.Copy() to move elements in the same array + for (int i = 0; i < _metadataBufferLength; i++) + { + _metadataBuffer[i] = _metadataBuffer[i + _metadataBufferOffset]; + } + _metadataBufferOffset = 0; + } + + span.CopyTo(new Span(_metadataBuffer, _metadataBufferOffset + _metadataBufferLength, span.Length)); + _metadataBufferLength += span.Length; + } + + private int ProcessStreamHeader(ReadOnlySpan span) + { + StructuredMessage.V1_0.ReadStreamHeader( + span.Slice(0, _streamHeaderLength), + out long streamLength, + 
out StructuredMessage.Flags flags, + out int totalSegments); + + _decodedData.InnerStreamLength = streamLength; + _decodedData.Flags = flags; + _decodedData.TotalSegments = totalSegments; + + if (_expectedInnerStreamLength.HasValue && _expectedInnerStreamLength.Value != streamLength) + { + throw Errors.InvalidStructuredMessage("Unexpected message size."); + } + + if (_decodedData.Flags.Value.HasFlag(StructuredMessage.Flags.StorageCrc64)) + { + _segmentFooterLength = StructuredMessage.Crc64Length; + _streamFooterLength = StructuredMessage.Crc64Length; + if (_validateChecksums) + { + _segmentCrc = StorageCrc64HashAlgorithm.Create(); + _totalContentCrc = StorageCrc64HashAlgorithm.Create(); + } + } + _currentRegion = SMRegion.SegmentHeader; + return _streamHeaderLength; + } + + private int ProcessStreamFooter(ReadOnlySpan span) + { + int footerLen = StructuredMessage.V1_0.GetStreamFooterSize(_decodedData.Flags.Value); + StructuredMessage.V1_0.ReadStreamFooter( + span.Slice(0, footerLen), + _decodedData.Flags.Value, + out ulong reportedCrc); + if (_decodedData.Flags.Value.HasFlag(StructuredMessage.Flags.StorageCrc64)) + { + if (_validateChecksums) + { + ValidateCrc64(_totalContentCrc, reportedCrc); + } + _decodedData.TotalCrc = reportedCrc; + } + + if (_innerStreamConsumed != _decodedData.InnerStreamLength) + { + throw Errors.InvalidStructuredMessage("Unexpected message size."); + } + if (_currentSegmentNum != _decodedData.TotalSegments) + { + throw Errors.InvalidStructuredMessage("Missing expected message segments."); + } + + _decodedData.DecodeCompleted = true; + return footerLen; + } + + private int ProcessSegmentHeader(ReadOnlySpan span) + { + StructuredMessage.V1_0.ReadSegmentHeader( + span.Slice(0, _segmentHeaderLength), + out int newSegNum, + out _currentSegmentContentLength); + _currentSegmentContentRemaining = _currentSegmentContentLength; + if (newSegNum != _currentSegmentNum + 1) + { + throw Errors.InvalidStructuredMessage("Unexpected segment number in 
structured message."); + } + _currentSegmentNum = newSegNum; + _currentRegion = SMRegion.SegmentContent; + return _segmentHeaderLength; + } + + private int ProcessSegmentFooter(ReadOnlySpan span) + { + int footerLen = StructuredMessage.V1_0.GetSegmentFooterSize(_decodedData.Flags.Value); + StructuredMessage.V1_0.ReadSegmentFooter( + span.Slice(0, footerLen), + _decodedData.Flags.Value, + out ulong reportedCrc); + if (_decodedData.Flags.Value.HasFlag(StructuredMessage.Flags.StorageCrc64)) + { + if (_validateChecksums) + { + ValidateCrc64(_segmentCrc, reportedCrc); + _segmentCrc = StorageCrc64HashAlgorithm.Create(); + } + _decodedData.SegmentCrcs.Add((reportedCrc, _currentSegmentContentLength)); + } + _currentRegion = _currentSegmentNum == _decodedData.TotalSegments ? SMRegion.StreamFooter : SMRegion.SegmentHeader; + return footerLen; + } + + private static void ValidateCrc64(StorageCrc64HashAlgorithm calculation, ulong reported) + { + using IDisposable _ = ArrayPool.Shared.RentDisposable(StructuredMessage.Crc64Length * 2, out byte[] buf); + Span calculatedBytes = new(buf, 0, StructuredMessage.Crc64Length); + Span reportedBytes = new(buf, calculatedBytes.Length, StructuredMessage.Crc64Length); + calculation.GetCurrentHash(calculatedBytes); + reported.WriteCrc64(reportedBytes); + if (!calculatedBytes.SequenceEqual(reportedBytes)) + { + throw Errors.ChecksumMismatch(calculatedBytes, reportedBytes); + } + } + #endregion + + public override long Seek(long offset, SeekOrigin origin) + => throw new NotSupportedException(); + + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + + if (_disposed) + { + return; + } + + if (disposing) + { + _innerBufferedStream.Dispose(); + _disposed = true; + } + } + + private void ResizeMetadataBuffer(int newSize) + { + byte[] newBuf = ArrayPool.Shared.Rent(newSize); + Array.Copy(_metadataBuffer, _metadataBufferOffset, newBuf, 0, _metadataBufferLength); + ArrayPool.Shared.Return(_metadataBuffer); + 
_metadataBuffer = newBuf; + } + + private void AlignMetadataBuffer() + { + if (_metadataBufferOffset != 0 && _metadataBufferLength != 0) + { + for (int i = 0; i < _metadataBufferLength; i++) + { + _metadataBuffer[i] = _metadataBuffer[_metadataBufferOffset + i]; + } + _metadataBufferOffset = 0; + } + } +} diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageEncodingStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageEncodingStream.cs new file mode 100644 index 0000000000000..cb0ef340155ec --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageEncodingStream.cs @@ -0,0 +1,545 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.Pipeline; +using Azure.Storage.Common; + +namespace Azure.Storage.Shared; + +internal class StructuredMessageEncodingStream : Stream +{ + private readonly Stream _innerStream; + + private readonly int _streamHeaderLength; + private readonly int _streamFooterLength; + private readonly int _segmentHeaderLength; + private readonly int _segmentFooterLength; + private readonly int _segmentContentLength; + + private readonly StructuredMessage.Flags _flags; + private bool _disposed; + + private bool UseCrcSegment => _flags.HasFlag(StructuredMessage.Flags.StorageCrc64); + private readonly StorageCrc64HashAlgorithm _totalCrc; + private StorageCrc64HashAlgorithm _segmentCrc; + private readonly byte[] _segmentCrcs; + private int _latestSegmentCrcd = 0; + + #region Segments + /// + /// Gets the 1-indexed segment number the underlying stream is currently positioned in. + /// 1-indexed to match segment labelling as specified by SM spec. 
+ /// + private int CurrentInnerSegment => (int)Math.Floor(_innerStream.Position / (float)_segmentContentLength) + 1; + + /// + /// Gets the 1-indexed segment number the encoded data stream is currently positioned in. + /// 1-indexed to match segment labelling as specified by SM spec. + /// + private int CurrentEncodingSegment + { + get + { + // edge case: always on final segment when at end of inner stream + if (_innerStream.Position == _innerStream.Length) + { + return TotalSegments; + } + // when writing footer, inner stream is positioned at next segment, + // but this stream is still writing the previous one + if (_currentRegion == SMRegion.SegmentFooter) + { + return CurrentInnerSegment - 1; + } + return CurrentInnerSegment; + } + } + + /// + /// Segment length including header and footer. + /// + private int SegmentTotalLength => _segmentHeaderLength + _segmentContentLength + _segmentFooterLength; + + private int TotalSegments => GetTotalSegments(_innerStream, _segmentContentLength); + private static int GetTotalSegments(Stream innerStream, long segmentContentLength) + { + return (int)Math.Ceiling(innerStream.Length / (float)segmentContentLength); + } + #endregion + + public override bool CanRead => true; + + public override bool CanWrite => false; + + public override bool CanSeek => _innerStream.CanSeek; + + public override bool CanTimeout => _innerStream.CanTimeout; + + public override int ReadTimeout => _innerStream.ReadTimeout; + + public override int WriteTimeout => _innerStream.WriteTimeout; + + public override long Length => + _streamHeaderLength + _streamFooterLength + + (_segmentHeaderLength + _segmentFooterLength) * TotalSegments + + _innerStream.Length; + + #region Position + private enum SMRegion + { + StreamHeader, + StreamFooter, + SegmentHeader, + SegmentFooter, + SegmentContent, + } + + private SMRegion _currentRegion = SMRegion.StreamHeader; + private int _currentRegionPosition = 0; + + private long _maxSeekPosition = 0; + + public override 
long Position + { + get + { + return _currentRegion switch + { + SMRegion.StreamHeader => _currentRegionPosition, + SMRegion.StreamFooter => _streamHeaderLength + + TotalSegments * (_segmentHeaderLength + _segmentFooterLength) + + _innerStream.Length + + _currentRegionPosition, + SMRegion.SegmentHeader => _innerStream.Position + + _streamHeaderLength + + (CurrentEncodingSegment - 1) * (_segmentHeaderLength + _segmentFooterLength) + + _currentRegionPosition, + SMRegion.SegmentFooter => _innerStream.Position + + _streamHeaderLength + + // Inner stream has moved to next segment but we're still writing the previous segment footer + CurrentEncodingSegment * (_segmentHeaderLength + _segmentFooterLength) - + _segmentFooterLength + _currentRegionPosition, + SMRegion.SegmentContent => _innerStream.Position + + _streamHeaderLength + + CurrentEncodingSegment * (_segmentHeaderLength + _segmentFooterLength) - + _segmentFooterLength, + _ => throw new InvalidDataException($"{nameof(StructuredMessageEncodingStream)} invalid state."), + }; + } + set + { + Argument.AssertInRange(value, 0, _maxSeekPosition, nameof(value)); + if (value < _streamHeaderLength) + { + _currentRegion = SMRegion.StreamHeader; + _currentRegionPosition = (int)value; + _innerStream.Position = 0; + return; + } + if (value >= Length - _streamFooterLength) + { + _currentRegion = SMRegion.StreamFooter; + _currentRegionPosition = (int)(value - (Length - _streamFooterLength)); + _innerStream.Position = _innerStream.Length; + return; + } + int newSegmentNum = 1 + (int)Math.Floor((value - _streamHeaderLength) / (double)(_segmentHeaderLength + _segmentFooterLength + _segmentContentLength)); + int segmentPosition = (int)(value - _streamHeaderLength - + ((newSegmentNum - 1) * (_segmentHeaderLength + _segmentFooterLength + _segmentContentLength))); + + if (segmentPosition < _segmentHeaderLength) + { + _currentRegion = SMRegion.SegmentHeader; + _currentRegionPosition = (int)((value - _streamHeaderLength) % 
SegmentTotalLength); + _innerStream.Position = (newSegmentNum - 1) * _segmentContentLength; + return; + } + if (segmentPosition < _segmentHeaderLength + _segmentContentLength) + { + _currentRegion = SMRegion.SegmentContent; + _currentRegionPosition = (int)((value - _streamHeaderLength) % SegmentTotalLength) - + _segmentHeaderLength; + _innerStream.Position = (newSegmentNum - 1) * _segmentContentLength + _currentRegionPosition; + return; + } + + _currentRegion = SMRegion.SegmentFooter; + _currentRegionPosition = (int)((value - _streamHeaderLength) % SegmentTotalLength) - + _segmentHeaderLength - _segmentContentLength; + _innerStream.Position = newSegmentNum * _segmentContentLength; + } + } + #endregion + + public StructuredMessageEncodingStream( + Stream innerStream, + int segmentContentLength, + StructuredMessage.Flags flags) + { + Argument.AssertNotNull(innerStream, nameof(innerStream)); + if (innerStream.GetLengthOrDefault() == default) + { + throw new ArgumentException("Stream must have known length.", nameof(innerStream)); + } + if (innerStream.Position != 0) + { + throw new ArgumentException("Stream must be at starting position.", nameof(innerStream)); + } + // stream logic likely breaks down with segment length of 1; enforce >=2 rather than just positive number + // real world scenarios will probably use a minimum of tens of KB + Argument.AssertInRange(segmentContentLength, 2, int.MaxValue, nameof(segmentContentLength)); + + _flags = flags; + _segmentContentLength = segmentContentLength; + + _streamHeaderLength = StructuredMessage.V1_0.StreamHeaderLength; + _streamFooterLength = UseCrcSegment ? StructuredMessage.Crc64Length : 0; + _segmentHeaderLength = StructuredMessage.V1_0.SegmentHeaderLength; + _segmentFooterLength = UseCrcSegment ? 
StructuredMessage.Crc64Length : 0; + + if (UseCrcSegment) + { + _totalCrc = StorageCrc64HashAlgorithm.Create(); + _segmentCrc = StorageCrc64HashAlgorithm.Create(); + _segmentCrcs = ArrayPool.Shared.Rent( + GetTotalSegments(innerStream, segmentContentLength) * StructuredMessage.Crc64Length); + innerStream = ChecksumCalculatingStream.GetReadStream(innerStream, span => + { + _totalCrc.Append(span); + _segmentCrc.Append(span); + }); + } + + _innerStream = innerStream; + } + + #region Write + public override void Flush() => throw new NotSupportedException(); + + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + + public override void SetLength(long value) => throw new NotSupportedException(); + #endregion + + #region Read + public override int Read(byte[] buffer, int offset, int count) + => ReadInternal(buffer, offset, count, async: false, cancellationToken: default).EnsureCompleted(); + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => await ReadInternal(buffer, offset, count, async: true, cancellationToken).ConfigureAwait(false); + + private async ValueTask ReadInternal(byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken) + { + int totalRead = 0; + bool readInner = false; + while (totalRead < count && Position < Length) + { + int subreadOffset = offset + totalRead; + int subreadCount = count - totalRead; + switch (_currentRegion) + { + case SMRegion.StreamHeader: + totalRead += ReadFromStreamHeader(new Span(buffer, subreadOffset, subreadCount)); + break; + case SMRegion.StreamFooter: + totalRead += ReadFromStreamFooter(new Span(buffer, subreadOffset, subreadCount)); + break; + case SMRegion.SegmentHeader: + totalRead += ReadFromSegmentHeader(new Span(buffer, subreadOffset, subreadCount)); + break; + case SMRegion.SegmentFooter: + totalRead += ReadFromSegmentFooter(new Span(buffer, subreadOffset, subreadCount)); + 
break; + case SMRegion.SegmentContent: + // don't double read from stream. Allow caller to multi-read when desired. + if (readInner) + { + UpdateLatestPosition(); + return totalRead; + } + totalRead += await ReadFromInnerStreamInternal( + buffer, subreadOffset, subreadCount, async, cancellationToken).ConfigureAwait(false); + readInner = true; + break; + default: + break; + } + } + UpdateLatestPosition(); + return totalRead; + } + +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + public override int Read(Span buffer) + { + int totalRead = 0; + bool readInner = false; + while (totalRead < buffer.Length && Position < Length) + { + switch (_currentRegion) + { + case SMRegion.StreamHeader: + totalRead += ReadFromStreamHeader(buffer.Slice(totalRead)); + break; + case SMRegion.StreamFooter: + totalRead += ReadFromStreamFooter(buffer.Slice(totalRead)); + break; + case SMRegion.SegmentHeader: + totalRead += ReadFromSegmentHeader(buffer.Slice(totalRead)); + break; + case SMRegion.SegmentFooter: + totalRead += ReadFromSegmentFooter(buffer.Slice(totalRead)); + break; + case SMRegion.SegmentContent: + // don't double read from stream. Allow caller to multi-read when desired. 
+ if (readInner) + { + UpdateLatestPosition(); + return totalRead; + } + totalRead += ReadFromInnerStream(buffer.Slice(totalRead)); + readInner = true; + break; + default: + break; + } + } + UpdateLatestPosition(); + return totalRead; + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + int totalRead = 0; + bool readInner = false; + while (totalRead < buffer.Length && Position < Length) + { + switch (_currentRegion) + { + case SMRegion.StreamHeader: + totalRead += ReadFromStreamHeader(buffer.Slice(totalRead).Span); + break; + case SMRegion.StreamFooter: + totalRead += ReadFromStreamFooter(buffer.Slice(totalRead).Span); + break; + case SMRegion.SegmentHeader: + totalRead += ReadFromSegmentHeader(buffer.Slice(totalRead).Span); + break; + case SMRegion.SegmentFooter: + totalRead += ReadFromSegmentFooter(buffer.Slice(totalRead).Span); + break; + case SMRegion.SegmentContent: + // don't double read from stream. Allow caller to multi-read when desired. 
+ if (readInner) + { + UpdateLatestPosition(); + return totalRead; + } + totalRead += await ReadFromInnerStreamAsync(buffer.Slice(totalRead), cancellationToken).ConfigureAwait(false); + readInner = true; + break; + default: + break; + } + } + UpdateLatestPosition(); + return totalRead; + } +#endif + + #region Read Headers/Footers + private int ReadFromStreamHeader(Span buffer) + { + int read = Math.Min(buffer.Length, _streamHeaderLength - _currentRegionPosition); + using IDisposable _ = StructuredMessage.V1_0.GetStreamHeaderBytes( + ArrayPool.Shared, out Memory headerBytes, Length, _flags, TotalSegments); + headerBytes.Slice(_currentRegionPosition, read).Span.CopyTo(buffer); + _currentRegionPosition += read; + + if (_currentRegionPosition == _streamHeaderLength) + { + _currentRegion = SMRegion.SegmentHeader; + _currentRegionPosition = 0; + } + + return read; + } + + private int ReadFromStreamFooter(Span buffer) + { + int read = Math.Min(buffer.Length, _segmentFooterLength - _currentRegionPosition); + if (read <= 0) + { + return 0; + } + + using IDisposable _ = StructuredMessage.V1_0.GetStreamFooterBytes( + ArrayPool.Shared, + out Memory footerBytes, + crc64: UseCrcSegment + ? 
_totalCrc.GetCurrentHash() // TODO array pooling + : default); + footerBytes.Slice(_currentRegionPosition, read).Span.CopyTo(buffer); + _currentRegionPosition += read; + + return read; + } + + private int ReadFromSegmentHeader(Span buffer) + { + int read = Math.Min(buffer.Length, _segmentHeaderLength - _currentRegionPosition); + using IDisposable _ = StructuredMessage.V1_0.GetSegmentHeaderBytes( + ArrayPool.Shared, + out Memory headerBytes, + CurrentInnerSegment, + Math.Min(_segmentContentLength, _innerStream.Length - _innerStream.Position)); + headerBytes.Slice(_currentRegionPosition, read).Span.CopyTo(buffer); + _currentRegionPosition += read; + + if (_currentRegionPosition == _segmentHeaderLength) + { + _currentRegion = SMRegion.SegmentContent; + _currentRegionPosition = 0; + } + + return read; + } + + private int ReadFromSegmentFooter(Span buffer) + { + int read = Math.Min(buffer.Length, _segmentFooterLength - _currentRegionPosition); + if (read < 0) + { + return 0; + } + + using IDisposable _ = StructuredMessage.V1_0.GetSegmentFooterBytes( + ArrayPool.Shared, + out Memory headerBytes, + crc64: UseCrcSegment + ? new Span( + _segmentCrcs, + (CurrentEncodingSegment-1) * _totalCrc.HashLengthInBytes, + _totalCrc.HashLengthInBytes) + : default); + headerBytes.Slice(_currentRegionPosition, read).Span.CopyTo(buffer); + _currentRegionPosition += read; + + if (_currentRegionPosition == _segmentFooterLength) + { + _currentRegion = _innerStream.Position == _innerStream.Length + ? 
SMRegion.StreamFooter : SMRegion.SegmentHeader; + _currentRegionPosition = 0; + } + + return read; + } + #endregion + + #region ReadUnderlyingStream + private int MaxInnerStreamRead => _segmentContentLength - _currentRegionPosition; + + private void CleanupContentSegment() + { + if (_currentRegionPosition == _segmentContentLength || _innerStream.Position >= _innerStream.Length) + { + _currentRegion = SMRegion.SegmentFooter; + _currentRegionPosition = 0; + if (UseCrcSegment && CurrentEncodingSegment - 1 == _latestSegmentCrcd) + { + _segmentCrc.GetCurrentHash(new Span( + _segmentCrcs, + _latestSegmentCrcd * _segmentCrc.HashLengthInBytes, + _segmentCrc.HashLengthInBytes)); + _latestSegmentCrcd++; + _segmentCrc = StorageCrc64HashAlgorithm.Create(); + } + } + } + + private async ValueTask ReadFromInnerStreamInternal( + byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken) + { + int read = async + ? await _innerStream.ReadAsync(buffer, offset, Math.Min(count, MaxInnerStreamRead)).ConfigureAwait(false) + : _innerStream.Read(buffer, offset, Math.Min(count, MaxInnerStreamRead)); + _currentRegionPosition += read; + CleanupContentSegment(); + return read; + } + +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + private int ReadFromInnerStream(Span buffer) + { + if (MaxInnerStreamRead < buffer.Length) + { + buffer = buffer.Slice(0, MaxInnerStreamRead); + } + int read = _innerStream.Read(buffer); + _currentRegionPosition += read; + CleanupContentSegment(); + return read; + } + + private async ValueTask ReadFromInnerStreamAsync(Memory buffer, CancellationToken cancellationToken) + { + if (MaxInnerStreamRead < buffer.Length) + { + buffer = buffer.Slice(0, MaxInnerStreamRead); + } + int read = await _innerStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); + _currentRegionPosition += read; + CleanupContentSegment(); + return read; + } +#endif + #endregion + + // don't allow stream to seek too far forward. 
track how far the stream has been naturally read. + private void UpdateLatestPosition() + { + if (_maxSeekPosition < Position) + { + _maxSeekPosition = Position; + } + } + #endregion + + public override long Seek(long offset, SeekOrigin origin) + { + switch (origin) + { + case SeekOrigin.Begin: + Position = offset; + break; + case SeekOrigin.Current: + Position += offset; + break; + case SeekOrigin.End: + Position = Length + offset; + break; + } + return Position; + } + + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + + if (_disposed) + { + return; + } + + if (disposing) + { + _innerStream.Dispose(); + _disposed = true; + } + } +} diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessagePrecalculatedCrcWrapperStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessagePrecalculatedCrcWrapperStream.cs new file mode 100644 index 0000000000000..3569ef4339735 --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessagePrecalculatedCrcWrapperStream.cs @@ -0,0 +1,451 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +using System; +using System.Buffers; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.Pipeline; +using Azure.Storage.Common; + +namespace Azure.Storage.Shared; + +internal class StructuredMessagePrecalculatedCrcWrapperStream : Stream +{ + private readonly Stream _innerStream; + + private readonly int _streamHeaderLength; + private readonly int _streamFooterLength; + private readonly int _segmentHeaderLength; + private readonly int _segmentFooterLength; + + private bool _disposed; + + private readonly byte[] _crc; + + public override bool CanRead => true; + + public override bool CanWrite => false; + + public override bool CanSeek => _innerStream.CanSeek; + + public override bool CanTimeout => _innerStream.CanTimeout; + + public override int ReadTimeout => _innerStream.ReadTimeout; + + public override int WriteTimeout => _innerStream.WriteTimeout; + + public override long Length => + _streamHeaderLength + _streamFooterLength + + _segmentHeaderLength + _segmentFooterLength + + _innerStream.Length; + + #region Position + private enum SMRegion + { + StreamHeader, + StreamFooter, + SegmentHeader, + SegmentFooter, + SegmentContent, + } + + private SMRegion _currentRegion = SMRegion.StreamHeader; + private int _currentRegionPosition = 0; + + private long _maxSeekPosition = 0; + + public override long Position + { + get + { + return _currentRegion switch + { + SMRegion.StreamHeader => _currentRegionPosition, + SMRegion.SegmentHeader => _innerStream.Position + + _streamHeaderLength + + _currentRegionPosition, + SMRegion.SegmentContent => _streamHeaderLength + + _segmentHeaderLength + + _innerStream.Position, + SMRegion.SegmentFooter => _streamHeaderLength + + _segmentHeaderLength + + _innerStream.Length + + _currentRegionPosition, + SMRegion.StreamFooter => _streamHeaderLength + + _segmentHeaderLength + + _innerStream.Length + + _segmentFooterLength + + _currentRegionPosition, + _ => throw new 
InvalidDataException($"{nameof(StructuredMessageEncodingStream)} invalid state."), + }; + } + set + { + Argument.AssertInRange(value, 0, _maxSeekPosition, nameof(value)); + if (value < _streamHeaderLength) + { + _currentRegion = SMRegion.StreamHeader; + _currentRegionPosition = (int)value; + _innerStream.Position = 0; + return; + } + if (value < _streamHeaderLength + _segmentHeaderLength) + { + _currentRegion = SMRegion.SegmentHeader; + _currentRegionPosition = (int)(value - _streamHeaderLength); + _innerStream.Position = 0; + return; + } + if (value < _streamHeaderLength + _segmentHeaderLength + _innerStream.Length) + { + _currentRegion = SMRegion.SegmentContent; + _currentRegionPosition = (int)(value - _streamHeaderLength - _segmentHeaderLength); + _innerStream.Position = value - _streamHeaderLength - _segmentHeaderLength; + return; + } + if (value < _streamHeaderLength + _segmentHeaderLength + _innerStream.Length + _segmentFooterLength) + { + _currentRegion = SMRegion.SegmentFooter; + _currentRegionPosition = (int)(value - _streamHeaderLength - _segmentHeaderLength - _innerStream.Length); + _innerStream.Position = _innerStream.Length; + return; + } + + _currentRegion = SMRegion.StreamFooter; + _currentRegionPosition = (int)(value - _streamHeaderLength - _segmentHeaderLength - _innerStream.Length - _segmentFooterLength); + _innerStream.Position = _innerStream.Length; + } + } + #endregion + + public StructuredMessagePrecalculatedCrcWrapperStream( + Stream innerStream, + ReadOnlySpan precalculatedCrc) + { + Argument.AssertNotNull(innerStream, nameof(innerStream)); + if (innerStream.GetLengthOrDefault() == default) + { + throw new ArgumentException("Stream must have known length.", nameof(innerStream)); + } + if (innerStream.Position != 0) + { + throw new ArgumentException("Stream must be at starting position.", nameof(innerStream)); + } + + _streamHeaderLength = StructuredMessage.V1_0.StreamHeaderLength; + _streamFooterLength = StructuredMessage.Crc64Length; + 
_segmentHeaderLength = StructuredMessage.V1_0.SegmentHeaderLength; + _segmentFooterLength = StructuredMessage.Crc64Length; + + _crc = ArrayPool.Shared.Rent(StructuredMessage.Crc64Length); + precalculatedCrc.CopyTo(_crc); + + _innerStream = innerStream; + } + + #region Write + public override void Flush() => throw new NotSupportedException(); + + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + + public override void SetLength(long value) => throw new NotSupportedException(); + #endregion + + #region Read + public override int Read(byte[] buffer, int offset, int count) + => ReadInternal(buffer, offset, count, async: false, cancellationToken: default).EnsureCompleted(); + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => await ReadInternal(buffer, offset, count, async: true, cancellationToken).ConfigureAwait(false); + + private async ValueTask ReadInternal(byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken) + { + int totalRead = 0; + bool readInner = false; + while (totalRead < count && Position < Length) + { + int subreadOffset = offset + totalRead; + int subreadCount = count - totalRead; + switch (_currentRegion) + { + case SMRegion.StreamHeader: + totalRead += ReadFromStreamHeader(new Span(buffer, subreadOffset, subreadCount)); + break; + case SMRegion.StreamFooter: + totalRead += ReadFromStreamFooter(new Span(buffer, subreadOffset, subreadCount)); + break; + case SMRegion.SegmentHeader: + totalRead += ReadFromSegmentHeader(new Span(buffer, subreadOffset, subreadCount)); + break; + case SMRegion.SegmentFooter: + totalRead += ReadFromSegmentFooter(new Span(buffer, subreadOffset, subreadCount)); + break; + case SMRegion.SegmentContent: + // don't double read from stream. Allow caller to multi-read when desired. 
+ if (readInner) + { + UpdateLatestPosition(); + return totalRead; + } + totalRead += await ReadFromInnerStreamInternal( + buffer, subreadOffset, subreadCount, async, cancellationToken).ConfigureAwait(false); + readInner = true; + break; + default: + break; + } + } + UpdateLatestPosition(); + return totalRead; + } + +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + public override int Read(Span buffer) + { + int totalRead = 0; + bool readInner = false; + while (totalRead < buffer.Length && Position < Length) + { + switch (_currentRegion) + { + case SMRegion.StreamHeader: + totalRead += ReadFromStreamHeader(buffer.Slice(totalRead)); + break; + case SMRegion.StreamFooter: + totalRead += ReadFromStreamFooter(buffer.Slice(totalRead)); + break; + case SMRegion.SegmentHeader: + totalRead += ReadFromSegmentHeader(buffer.Slice(totalRead)); + break; + case SMRegion.SegmentFooter: + totalRead += ReadFromSegmentFooter(buffer.Slice(totalRead)); + break; + case SMRegion.SegmentContent: + // don't double read from stream. Allow caller to multi-read when desired. 
+ if (readInner) + { + UpdateLatestPosition(); + return totalRead; + } + totalRead += ReadFromInnerStream(buffer.Slice(totalRead)); + readInner = true; + break; + default: + break; + } + } + UpdateLatestPosition(); + return totalRead; + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + int totalRead = 0; + bool readInner = false; + while (totalRead < buffer.Length && Position < Length) + { + switch (_currentRegion) + { + case SMRegion.StreamHeader: + totalRead += ReadFromStreamHeader(buffer.Slice(totalRead).Span); + break; + case SMRegion.StreamFooter: + totalRead += ReadFromStreamFooter(buffer.Slice(totalRead).Span); + break; + case SMRegion.SegmentHeader: + totalRead += ReadFromSegmentHeader(buffer.Slice(totalRead).Span); + break; + case SMRegion.SegmentFooter: + totalRead += ReadFromSegmentFooter(buffer.Slice(totalRead).Span); + break; + case SMRegion.SegmentContent: + // don't double read from stream. Allow caller to multi-read when desired. 
+ if (readInner) + { + UpdateLatestPosition(); + return totalRead; + } + totalRead += await ReadFromInnerStreamAsync(buffer.Slice(totalRead), cancellationToken).ConfigureAwait(false); + readInner = true; + break; + default: + break; + } + } + UpdateLatestPosition(); + return totalRead; + } +#endif + + #region Read Headers/Footers + private int ReadFromStreamHeader(Span buffer) + { + int read = Math.Min(buffer.Length, _streamHeaderLength - _currentRegionPosition); + using IDisposable _ = StructuredMessage.V1_0.GetStreamHeaderBytes( + ArrayPool.Shared, + out Memory headerBytes, + Length, + StructuredMessage.Flags.StorageCrc64, + totalSegments: 1); + headerBytes.Slice(_currentRegionPosition, read).Span.CopyTo(buffer); + _currentRegionPosition += read; + + if (_currentRegionPosition == _streamHeaderLength) + { + _currentRegion = SMRegion.SegmentHeader; + _currentRegionPosition = 0; + } + + return read; + } + + private int ReadFromStreamFooter(Span buffer) + { + int read = Math.Min(buffer.Length, _segmentFooterLength - _currentRegionPosition); + if (read <= 0) + { + return 0; + } + + using IDisposable _ = StructuredMessage.V1_0.GetStreamFooterBytes( + ArrayPool.Shared, + out Memory footerBytes, + new ReadOnlySpan(_crc, 0, StructuredMessage.Crc64Length)); + footerBytes.Slice(_currentRegionPosition, read).Span.CopyTo(buffer); + _currentRegionPosition += read; + + return read; + } + + private int ReadFromSegmentHeader(Span buffer) + { + int read = Math.Min(buffer.Length, _segmentHeaderLength - _currentRegionPosition); + using IDisposable _ = StructuredMessage.V1_0.GetSegmentHeaderBytes( + ArrayPool.Shared, + out Memory headerBytes, + segmentNum: 1, + _innerStream.Length); + headerBytes.Slice(_currentRegionPosition, read).Span.CopyTo(buffer); + _currentRegionPosition += read; + + if (_currentRegionPosition == _segmentHeaderLength) + { + _currentRegion = SMRegion.SegmentContent; + _currentRegionPosition = 0; + } + + return read; + } + + private int ReadFromSegmentFooter(Span 
buffer) + { + int read = Math.Min(buffer.Length, _segmentFooterLength - _currentRegionPosition); + if (read < 0) + { + return 0; + } + + using IDisposable _ = StructuredMessage.V1_0.GetSegmentFooterBytes( + ArrayPool.Shared, + out Memory headerBytes, + new ReadOnlySpan(_crc, 0, StructuredMessage.Crc64Length)); + headerBytes.Slice(_currentRegionPosition, read).Span.CopyTo(buffer); + _currentRegionPosition += read; + + if (_currentRegionPosition == _segmentFooterLength) + { + _currentRegion = _innerStream.Position == _innerStream.Length + ? SMRegion.StreamFooter : SMRegion.SegmentHeader; + _currentRegionPosition = 0; + } + + return read; + } + #endregion + + #region ReadUnderlyingStream + private void CleanupContentSegment() + { + if (_innerStream.Position >= _innerStream.Length) + { + _currentRegion = SMRegion.SegmentFooter; + _currentRegionPosition = 0; + } + } + + private async ValueTask ReadFromInnerStreamInternal( + byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken) + { + int read = async + ? await _innerStream.ReadAsync(buffer, offset, count).ConfigureAwait(false) + : _innerStream.Read(buffer, offset, count); + _currentRegionPosition += read; + CleanupContentSegment(); + return read; + } + +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + private int ReadFromInnerStream(Span buffer) + { + int read = _innerStream.Read(buffer); + _currentRegionPosition += read; + CleanupContentSegment(); + return read; + } + + private async ValueTask ReadFromInnerStreamAsync(Memory buffer, CancellationToken cancellationToken) + { + int read = await _innerStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); + _currentRegionPosition += read; + CleanupContentSegment(); + return read; + } +#endif + #endregion + + // don't allow stream to seek too far forward. track how far the stream has been naturally read. 
+ private void UpdateLatestPosition() + { + if (_maxSeekPosition < Position) + { + _maxSeekPosition = Position; + } + } + #endregion + + public override long Seek(long offset, SeekOrigin origin) + { + switch (origin) + { + case SeekOrigin.Begin: + Position = offset; + break; + case SeekOrigin.Current: + Position += offset; + break; + case SeekOrigin.End: + Position = Length + offset; + break; + } + return Position; + } + + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + + if (_disposed) + { + return; + } + + if (disposing) + { + ArrayPool.Shared.Return(_crc); + _innerStream.Dispose(); + _disposed = true; + } + } +} diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/TransferValidationOptionsExtensions.cs b/sdk/storage/Azure.Storage.Common/src/Shared/TransferValidationOptionsExtensions.cs index af21588b4ae09..763d385240383 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/TransferValidationOptionsExtensions.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/TransferValidationOptionsExtensions.cs @@ -9,14 +9,7 @@ public static StorageChecksumAlgorithm ResolveAuto(this StorageChecksumAlgorithm { if (checksumAlgorithm == StorageChecksumAlgorithm.Auto) { -#if BlobSDK || DataLakeSDK || CommonSDK return StorageChecksumAlgorithm.StorageCrc64; -#elif FileSDK // file shares don't support crc64 - return StorageChecksumAlgorithm.MD5; -#else - throw new System.NotSupportedException( - $"{typeof(TransferValidationOptionsExtensions).FullName}.{nameof(ResolveAuto)} is not supported."); -#endif } return checksumAlgorithm; } diff --git a/sdk/storage/Azure.Storage.Common/tests/Azure.Storage.Common.Tests.csproj b/sdk/storage/Azure.Storage.Common/tests/Azure.Storage.Common.Tests.csproj index 5db86ebee984b..2863b85f6feb2 100644 --- a/sdk/storage/Azure.Storage.Common/tests/Azure.Storage.Common.Tests.csproj +++ b/sdk/storage/Azure.Storage.Common/tests/Azure.Storage.Common.Tests.csproj @@ -13,9 +13,12 @@ + + + @@ -28,6 +31,7 @@ + @@ -46,6 
+50,11 @@ + + + + + diff --git a/sdk/storage/Azure.Storage.Common/tests/Shared/FaultyStream.cs b/sdk/storage/Azure.Storage.Common/tests/Shared/FaultyStream.cs index 7411eb1499312..f4e4b92ed73c4 100644 --- a/sdk/storage/Azure.Storage.Common/tests/Shared/FaultyStream.cs +++ b/sdk/storage/Azure.Storage.Common/tests/Shared/FaultyStream.cs @@ -15,6 +15,7 @@ internal class FaultyStream : Stream private readonly Exception _exceptionToRaise; private int _remainingExceptions; private Action _onFault; + private long _position = 0; public FaultyStream( Stream innerStream, @@ -40,7 +41,7 @@ public FaultyStream( public override long Position { - get => _innerStream.Position; + get => CanSeek ? _innerStream.Position : _position; set => _innerStream.Position = value; } @@ -53,7 +54,9 @@ public override int Read(byte[] buffer, int offset, int count) { if (_remainingExceptions == 0 || Position + count <= _raiseExceptionAt || _raiseExceptionAt >= _innerStream.Length) { - return _innerStream.Read(buffer, offset, count); + int read = _innerStream.Read(buffer, offset, count); + _position += read; + return read; } else { @@ -61,11 +64,13 @@ public override int Read(byte[] buffer, int offset, int count) } } - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { if (_remainingExceptions == 0 || Position + count <= _raiseExceptionAt || _raiseExceptionAt >= _innerStream.Length) { - return _innerStream.ReadAsync(buffer, offset, count, cancellationToken); + int read = await _innerStream.ReadAsync(buffer, offset, count, cancellationToken); + _position += read; + return read; } else { diff --git a/sdk/storage/Azure.Storage.Common/tests/Shared/ObserveStructuredMessagePolicy.cs b/sdk/storage/Azure.Storage.Common/tests/Shared/ObserveStructuredMessagePolicy.cs new file mode 100644 index 0000000000000..828c41179bba3 --- /dev/null +++ 
b/sdk/storage/Azure.Storage.Common/tests/Shared/ObserveStructuredMessagePolicy.cs @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using Azure.Core; +using Azure.Core.Pipeline; +using Azure.Storage.Shared; + +namespace Azure.Storage.Test.Shared +{ + internal class ObserveStructuredMessagePolicy : HttpPipelineSynchronousPolicy + { + private readonly HashSet _requestScopes = new(); + + private readonly HashSet _responseScopes = new(); + + public ObserveStructuredMessagePolicy() + { + } + + public override void OnSendingRequest(HttpMessage message) + { + if (_requestScopes.Count > 0) + { + byte[] encodedContent; + byte[] underlyingContent; + StructuredMessageDecodingStream.RawDecodedData decodedData; + using (MemoryStream ms = new()) + { + message.Request.Content.WriteTo(ms, default); + encodedContent = ms.ToArray(); + using (MemoryStream ms2 = new()) + { + (Stream s, decodedData) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedContent)); + s.CopyTo(ms2); + underlyingContent = ms2.ToArray(); + } + } + } + } + + public override void OnReceivedResponse(HttpMessage message) + { + } + + public IDisposable CheckRequestScope() => CheckMessageScope.CheckRequestScope(this); + + public IDisposable CheckResponseScope() => CheckMessageScope.CheckResponseScope(this); + + private class CheckMessageScope : IDisposable + { + private bool _isRequestScope; + private ObserveStructuredMessagePolicy _policy; + + public static CheckMessageScope CheckRequestScope(ObserveStructuredMessagePolicy policy) + { + CheckMessageScope result = new() + { + _isRequestScope = true, + _policy = policy + }; + result._policy._requestScopes.Add(result); + return result; + } + + public static CheckMessageScope CheckResponseScope(ObserveStructuredMessagePolicy policy) + { + CheckMessageScope result = new() + { + _isRequestScope = false, + _policy = 
policy + }; + result._policy._responseScopes.Add(result); + return result; + } + + public void Dispose() + { + (_isRequestScope ? _policy._requestScopes : _policy._responseScopes).Remove(this); + } + } + } +} diff --git a/sdk/storage/Azure.Storage.Common/tests/Shared/RequestExtensions.cs b/sdk/storage/Azure.Storage.Common/tests/Shared/RequestExtensions.cs new file mode 100644 index 0000000000000..ad395e862f827 --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/tests/Shared/RequestExtensions.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Linq; +using System.Text; +using Azure.Core; +using NUnit.Framework; + +namespace Azure.Storage; + +public static partial class RequestExtensions +{ + public static string AssertHeaderPresent(this Request request, string headerName) + { + if (request.Headers.TryGetValue(headerName, out string value)) + { + return headerName == Constants.StructuredMessage.StructuredMessageHeader ? null : value; + } + StringBuilder sb = new StringBuilder() + .AppendLine($"`{headerName}` expected on request but was not found.") + .AppendLine($"{request.Method} {request.Uri}") + .AppendLine(string.Join("\n", request.Headers.Select(h => $"{h.Name}: {h.Value}s"))) + ; + Assert.Fail(sb.ToString()); + return null; + } +} diff --git a/sdk/storage/Azure.Storage.Common/tests/Shared/TamperStreamContentsPolicy.cs b/sdk/storage/Azure.Storage.Common/tests/Shared/TamperStreamContentsPolicy.cs index f4198e9dfd532..7e6c78117f53b 100644 --- a/sdk/storage/Azure.Storage.Common/tests/Shared/TamperStreamContentsPolicy.cs +++ b/sdk/storage/Azure.Storage.Common/tests/Shared/TamperStreamContentsPolicy.cs @@ -14,7 +14,7 @@ internal class TamperStreamContentsPolicy : HttpPipelineSynchronousPolicy /// /// Default tampering that changes the first byte of the stream. 
/// - private static readonly Func _defaultStreamTransform = stream => + private static Func GetTamperByteStreamTransform(long position) => stream => { if (stream is not MemoryStream) { @@ -23,10 +23,10 @@ internal class TamperStreamContentsPolicy : HttpPipelineSynchronousPolicy stream = buffer; } - stream.Position = 0; + stream.Position = position; var firstByte = stream.ReadByte(); - stream.Position = 0; + stream.Position = position; stream.WriteByte((byte)((firstByte + 1) % byte.MaxValue)); stream.Position = 0; @@ -37,9 +37,12 @@ internal class TamperStreamContentsPolicy : HttpPipelineSynchronousPolicy public TamperStreamContentsPolicy(Func streamTransform = default) { - _streamTransform = streamTransform ?? _defaultStreamTransform; + _streamTransform = streamTransform ?? GetTamperByteStreamTransform(0); } + public static TamperStreamContentsPolicy TamperByteAt(long position) + => new(GetTamperByteStreamTransform(position)); + public bool TransformRequestBody { get; set; } public bool TransformResponseBody { get; set; } diff --git a/sdk/storage/Azure.Storage.Common/tests/Shared/TransferValidationTestBase.cs b/sdk/storage/Azure.Storage.Common/tests/Shared/TransferValidationTestBase.cs index c18492d2fb4dd..248acf8811960 100644 --- a/sdk/storage/Azure.Storage.Common/tests/Shared/TransferValidationTestBase.cs +++ b/sdk/storage/Azure.Storage.Common/tests/Shared/TransferValidationTestBase.cs @@ -5,10 +5,13 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Security.Cryptography; using System.Threading.Tasks; using Azure.Core; +using Azure.Core.Diagnostics; +using Azure.Core.Pipeline; using Azure.Core.TestFramework; -using FastSerialization; +using Azure.Storage.Shared; using NUnit.Framework; namespace Azure.Storage.Test.Shared @@ -190,21 +193,15 @@ protected string GetNewResourceName() /// The actual checksum value expected to be on the request, if known. Defaults to no specific value expected or checked. 
/// /// An assertion to put into a pipeline policy. - internal static Action GetRequestChecksumAssertion(StorageChecksumAlgorithm algorithm, Func isChecksumExpected = default, byte[] expectedChecksum = default) + internal static Action GetRequestChecksumHeaderAssertion(StorageChecksumAlgorithm algorithm, Func isChecksumExpected = default, byte[] expectedChecksum = default) { // action to assert a request header is as expected - void AssertChecksum(RequestHeaders headers, string headerName) + void AssertChecksum(Request req, string headerName) { - if (headers.TryGetValue(headerName, out string checksum)) + string checksum = req.AssertHeaderPresent(headerName); + if (expectedChecksum != default) { - if (expectedChecksum != default) - { - Assert.AreEqual(Convert.ToBase64String(expectedChecksum), checksum); - } - } - else - { - Assert.Fail($"{headerName} expected on request but was not found."); + Assert.AreEqual(Convert.ToBase64String(expectedChecksum), checksum); } }; @@ -219,14 +216,39 @@ void AssertChecksum(RequestHeaders headers, string headerName) switch (algorithm.ResolveAuto()) { case StorageChecksumAlgorithm.MD5: - AssertChecksum(request.Headers, "Content-MD5"); + AssertChecksum(request, "Content-MD5"); break; case StorageChecksumAlgorithm.StorageCrc64: - AssertChecksum(request.Headers, "x-ms-content-crc64"); + AssertChecksum(request, Constants.StructuredMessage.StructuredMessageHeader); break; default: - throw new Exception($"Bad {nameof(StorageChecksumAlgorithm)} provided to {nameof(GetRequestChecksumAssertion)}."); + throw new Exception($"Bad {nameof(StorageChecksumAlgorithm)} provided to {nameof(GetRequestChecksumHeaderAssertion)}."); + } + }; + } + + internal static Action GetRequestStructuredMessageAssertion( + StructuredMessage.Flags flags, + Func isStructuredMessageExpected = default, + long? 
structuredContentSegmentLength = default) + { + return request => + { + // filter some requests out with predicate + if (isStructuredMessageExpected != default && !isStructuredMessageExpected(request)) + { + return; } + + Assert.That(request.Headers.TryGetValue("x-ms-structured-body", out string structuredBody)); + Assert.That(structuredBody, Does.Contain("XSM/1.0")); + if (flags.HasFlag(StructuredMessage.Flags.StorageCrc64)) + { + Assert.That(structuredBody, Does.Contain("crc64")); + } + + Assert.That(request.Headers.TryGetValue("Content-Length", out string contentLength)); + Assert.That(request.Headers.TryGetValue("x-ms-structured-content-length", out string structuredContentLength)); }; } @@ -278,32 +300,66 @@ void AssertChecksum(ResponseHeaders headers, string headerName) AssertChecksum(response.Headers, "Content-MD5"); break; case StorageChecksumAlgorithm.StorageCrc64: - AssertChecksum(response.Headers, "x-ms-content-crc64"); + AssertChecksum(response.Headers, Constants.StructuredMessage.StructuredMessageHeader); break; default: - throw new Exception($"Bad {nameof(StorageChecksumAlgorithm)} provided to {nameof(GetRequestChecksumAssertion)}."); + throw new Exception($"Bad {nameof(StorageChecksumAlgorithm)} provided to {nameof(GetRequestChecksumHeaderAssertion)}."); } }; } + internal static Action GetResponseStructuredMessageAssertion( + StructuredMessage.Flags flags, + Func isStructuredMessageExpected = default) + { + return response => + { + // filter some requests out with predicate + if (isStructuredMessageExpected != default && !isStructuredMessageExpected(response)) + { + return; + } + + Assert.That(response.Headers.TryGetValue("x-ms-structured-body", out string structuredBody)); + Assert.That(structuredBody, Does.Contain("XSM/1.0")); + if (flags.HasFlag(StructuredMessage.Flags.StorageCrc64)) + { + Assert.That(structuredBody, Does.Contain("crc64")); + } + + Assert.That(response.Headers.TryGetValue("Content-Length", out string contentLength)); + 
Assert.That(response.Headers.TryGetValue("x-ms-structured-content-length", out string structuredContentLength)); + }; + } + /// /// Asserts the service returned an error that expected checksum did not match checksum on upload. /// /// Async action to upload data to service. /// Checksum algorithm used. - internal static void AssertWriteChecksumMismatch(AsyncTestDelegate writeAction, StorageChecksumAlgorithm algorithm) + internal static void AssertWriteChecksumMismatch( + AsyncTestDelegate writeAction, + StorageChecksumAlgorithm algorithm, + bool expectStructuredMessage = false) { var exception = ThrowsOrInconclusiveAsync(writeAction); - switch (algorithm.ResolveAuto()) + if (expectStructuredMessage) { - case StorageChecksumAlgorithm.MD5: - Assert.AreEqual("Md5Mismatch", exception.ErrorCode); - break; - case StorageChecksumAlgorithm.StorageCrc64: - Assert.AreEqual("Crc64Mismatch", exception.ErrorCode); - break; - default: - throw new ArgumentException("Test arguments contain bad algorithm specifier."); + Assert.That(exception.ErrorCode, Is.EqualTo("Crc64Mismatch")); + } + else + { + switch (algorithm.ResolveAuto()) + { + case StorageChecksumAlgorithm.MD5: + Assert.That(exception.ErrorCode, Is.EqualTo("Md5Mismatch")); + break; + case StorageChecksumAlgorithm.StorageCrc64: + Assert.That(exception.ErrorCode, Is.EqualTo("Crc64Mismatch")); + break; + default: + throw new ArgumentException("Test arguments contain bad algorithm specifier."); + } } } #endregion @@ -348,6 +404,7 @@ public virtual async Task UploadPartitionSuccessfulHashComputation(StorageChecks await using IDisposingContainer disposingContainer = await GetDisposingContainerAsync(); // Arrange + bool expectStructuredMessage = algorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64; const int dataLength = Constants.KB; var data = GetRandomBuffer(dataLength); var validationOptions = new UploadTransferValidationOptions @@ -356,7 +413,10 @@ public virtual async Task 
UploadPartitionSuccessfulHashComputation(StorageChecks }; // make pipeline assertion for checking checksum was present on upload - var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: GetRequestChecksumAssertion(algorithm)); + var assertion = algorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 + ? GetRequestStructuredMessageAssertion(StructuredMessage.Flags.StorageCrc64, null, dataLength) + : GetRequestChecksumHeaderAssertion(algorithm); + var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: assertion); var clientOptions = ClientBuilder.GetOptions(); clientOptions.AddPolicy(checksumPipelineAssertion, HttpPipelinePosition.PerCall); @@ -406,7 +466,11 @@ public virtual async Task UploadPartitionUsePrecalculatedHash(StorageChecksumAlg }; // make pipeline assertion for checking precalculated checksum was present on upload - var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: GetRequestChecksumAssertion(algorithm, expectedChecksum: precalculatedChecksum)); + // precalculated partition upload will never use structured message. always check header + var assertion = GetRequestChecksumHeaderAssertion( + algorithm, + expectedChecksum: algorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 ? 
default : precalculatedChecksum); + var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: assertion); var clientOptions = ClientBuilder.GetOptions(); clientOptions.AddPolicy(checksumPipelineAssertion, HttpPipelinePosition.PerCall); @@ -423,12 +487,12 @@ public virtual async Task UploadPartitionUsePrecalculatedHash(StorageChecksumAlg AsyncTestDelegate operation = async () => await UploadPartitionAsync(client, stream, validationOptions); // Assert - AssertWriteChecksumMismatch(operation, algorithm); + AssertWriteChecksumMismatch(operation, algorithm, algorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64); } } [TestCaseSource(nameof(GetValidationAlgorithms))] - public virtual async Task UploadPartitionMismatchedHashThrows(StorageChecksumAlgorithm algorithm) + public virtual async Task UploadPartitionTamperedStreamThrows(StorageChecksumAlgorithm algorithm) { await using IDisposingContainer disposingContainer = await GetDisposingContainerAsync(); @@ -441,7 +505,7 @@ public virtual async Task UploadPartitionMismatchedHashThrows(StorageChecksumAlg }; // Tamper with stream contents in the pipeline to simulate silent failure in the transit layer - var streamTamperPolicy = new TamperStreamContentsPolicy(); + var streamTamperPolicy = TamperStreamContentsPolicy.TamperByteAt(100); var clientOptions = ClientBuilder.GetOptions(); clientOptions.AddPolicy(streamTamperPolicy, HttpPipelinePosition.PerCall); @@ -456,9 +520,10 @@ public virtual async Task UploadPartitionMismatchedHashThrows(StorageChecksumAlg // Act streamTamperPolicy.TransformRequestBody = true; AsyncTestDelegate operation = async () => await UploadPartitionAsync(client, stream, validationOptions); - + using var listener = AzureEventSourceListener.CreateConsoleLogger(); // Assert - AssertWriteChecksumMismatch(operation, algorithm); + AssertWriteChecksumMismatch(operation, algorithm, + expectStructuredMessage: algorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64); } } @@ 
-473,7 +538,10 @@ public virtual async Task UploadPartitionUsesDefaultClientValidationOptions( var data = GetRandomBuffer(dataLength); // make pipeline assertion for checking checksum was present on upload - var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: GetRequestChecksumAssertion(clientAlgorithm)); + var assertion = clientAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 + ? GetRequestStructuredMessageAssertion(StructuredMessage.Flags.StorageCrc64, null, dataLength) + : GetRequestChecksumHeaderAssertion(clientAlgorithm); + var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: assertion); var clientOptions = ClientBuilder.GetOptions(); clientOptions.AddPolicy(checksumPipelineAssertion, HttpPipelinePosition.PerCall); @@ -512,7 +580,10 @@ public virtual async Task UploadPartitionOverwritesDefaultClientValidationOption }; // make pipeline assertion for checking checksum was present on upload - var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: GetRequestChecksumAssertion(overrideAlgorithm)); + var assertion = overrideAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 + ? 
GetRequestStructuredMessageAssertion(StructuredMessage.Flags.StorageCrc64, null, dataLength) + : GetRequestChecksumHeaderAssertion(overrideAlgorithm); + var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: assertion); var clientOptions = ClientBuilder.GetOptions(); clientOptions.AddPolicy(checksumPipelineAssertion, HttpPipelinePosition.PerCall); @@ -555,10 +626,14 @@ public virtual async Task UploadPartitionDisablesDefaultClientValidationOptions( { Assert.Fail($"Hash found when none expected."); } - if (request.Headers.Contains("x-ms-content-crc64")) + if (request.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessage)) { Assert.Fail($"Hash found when none expected."); } + if (request.Headers.Contains("x-ms-structured-body")) + { + Assert.Fail($"Structured body used when none expected."); + } }); var clientOptions = ClientBuilder.GetOptions(); clientOptions.AddPolicy(checksumPipelineAssertion, HttpPipelinePosition.PerCall); @@ -601,9 +676,11 @@ public virtual async Task OpenWriteSuccessfulHashComputation( }; // make pipeline assertion for checking checksum was present on upload - var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: GetRequestChecksumAssertion(algorithm)); + var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: GetRequestChecksumHeaderAssertion(algorithm)); var clientOptions = ClientBuilder.GetOptions(); + //ObserveStructuredMessagePolicy observe = new(); clientOptions.AddPolicy(checksumPipelineAssertion, HttpPipelinePosition.PerCall); + //clientOptions.AddPolicy(observe, HttpPipelinePosition.BeforeTransport); var client = await GetResourceClientAsync( disposingContainer.Container, @@ -616,6 +693,7 @@ public virtual async Task OpenWriteSuccessfulHashComputation( using var writeStream = await OpenWriteAsync(client, validationOptions, streamBufferSize); // Assert + //using var obsv = observe.CheckRequestScope(); using (checksumPipelineAssertion.CheckRequestScope()) { 
foreach (var _ in Enumerable.Range(0, streamWrites)) @@ -644,7 +722,7 @@ public virtual async Task OpenWriteMismatchedHashThrows(StorageChecksumAlgorithm // Tamper with stream contents in the pipeline to simulate silent failure in the transit layer var clientOptions = ClientBuilder.GetOptions(); - var tamperPolicy = new TamperStreamContentsPolicy(); + var tamperPolicy = TamperStreamContentsPolicy.TamperByteAt(100); clientOptions.AddPolicy(tamperPolicy, HttpPipelinePosition.PerCall); var client = await GetResourceClientAsync( @@ -682,7 +760,7 @@ public virtual async Task OpenWriteUsesDefaultClientValidationOptions( var data = GetRandomBuffer(dataLength); // make pipeline assertion for checking checksum was present on upload - var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: GetRequestChecksumAssertion(clientAlgorithm)); + var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: GetRequestChecksumHeaderAssertion(clientAlgorithm)); var clientOptions = ClientBuilder.GetOptions(); clientOptions.AddPolicy(checksumPipelineAssertion, HttpPipelinePosition.PerCall); @@ -726,7 +804,7 @@ public virtual async Task OpenWriteOverwritesDefaultClientValidationOptions( }; // make pipeline assertion for checking checksum was present on upload - var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: GetRequestChecksumAssertion(overrideAlgorithm)); + var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: GetRequestChecksumHeaderAssertion(overrideAlgorithm)); var clientOptions = ClientBuilder.GetOptions(); clientOptions.AddPolicy(checksumPipelineAssertion, HttpPipelinePosition.PerCall); @@ -774,7 +852,7 @@ public virtual async Task OpenWriteDisablesDefaultClientValidationOptions( { Assert.Fail($"Hash found when none expected."); } - if (request.Headers.Contains("x-ms-content-crc64")) + if (request.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessage)) { Assert.Fail($"Hash found 
when none expected."); } @@ -886,7 +964,7 @@ public virtual async Task ParallelUploadSplitSuccessfulHashComputation(StorageCh // make pipeline assertion for checking checksum was present on upload var checksumPipelineAssertion = new AssertMessageContentsPolicy( - checkRequest: GetRequestChecksumAssertion(algorithm, isChecksumExpected: ParallelUploadIsChecksumExpected)); + checkRequest: GetRequestChecksumHeaderAssertion(algorithm, isChecksumExpected: ParallelUploadIsChecksumExpected)); var clientOptions = ClientBuilder.GetOptions(); clientOptions.AddPolicy(checksumPipelineAssertion, HttpPipelinePosition.PerCall); @@ -923,8 +1001,10 @@ public virtual async Task ParallelUploadOneShotSuccessfulHashComputation(Storage }; // make pipeline assertion for checking checksum was present on upload - var checksumPipelineAssertion = new AssertMessageContentsPolicy( - checkRequest: GetRequestChecksumAssertion(algorithm, isChecksumExpected: ParallelUploadIsChecksumExpected)); + var assertion = algorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 + ? 
GetRequestStructuredMessageAssertion(StructuredMessage.Flags.StorageCrc64, ParallelUploadIsChecksumExpected, dataLength) + : GetRequestChecksumHeaderAssertion(algorithm, isChecksumExpected: ParallelUploadIsChecksumExpected); + var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: assertion); var clientOptions = ClientBuilder.GetOptions(); clientOptions.AddPolicy(checksumPipelineAssertion, HttpPipelinePosition.PerCall); @@ -981,7 +1061,7 @@ public virtual async Task ParallelUploadPrecalculatedComposableHashAccepted(Stor PrecalculatedChecksum = hash }; - var client = await GetResourceClientAsync(disposingContainer.Container, dataLength); + var client = await GetResourceClientAsync(disposingContainer.Container, dataLength, createResource: true); // Act await DoesNotThrowOrInconclusiveAsync( @@ -1011,8 +1091,10 @@ public virtual async Task ParallelUploadUsesDefaultClientValidationOptions( }; // make pipeline assertion for checking checksum was present on upload - var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: GetRequestChecksumAssertion( - clientAlgorithm, isChecksumExpected: ParallelUploadIsChecksumExpected)); + var assertion = clientAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 && !split + ? 
GetRequestStructuredMessageAssertion(StructuredMessage.Flags.StorageCrc64, ParallelUploadIsChecksumExpected, dataLength) + : GetRequestChecksumHeaderAssertion(clientAlgorithm, isChecksumExpected: ParallelUploadIsChecksumExpected); + var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: assertion); var clientOptions = ClientBuilder.GetOptions(); clientOptions.AddPolicy(checksumPipelineAssertion, HttpPipelinePosition.PerCall); @@ -1063,8 +1145,10 @@ public virtual async Task ParallelUploadOverwritesDefaultClientValidationOptions }; // make pipeline assertion for checking checksum was present on upload - var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: GetRequestChecksumAssertion( - overrideAlgorithm, isChecksumExpected: ParallelUploadIsChecksumExpected)); + var assertion = overrideAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 && !split + ? GetRequestStructuredMessageAssertion(StructuredMessage.Flags.StorageCrc64, ParallelUploadIsChecksumExpected, dataLength) + : GetRequestChecksumHeaderAssertion(overrideAlgorithm, isChecksumExpected: ParallelUploadIsChecksumExpected); + var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: assertion); var clientOptions = ClientBuilder.GetOptions(); clientOptions.AddPolicy(checksumPipelineAssertion, HttpPipelinePosition.PerCall); @@ -1119,7 +1203,7 @@ public virtual async Task ParallelUploadDisablesDefaultClientValidationOptions( { Assert.Fail($"Hash found when none expected."); } - if (request.Headers.Contains("x-ms-content-crc64")) + if (request.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessage)) { Assert.Fail($"Hash found when none expected."); } @@ -1184,15 +1268,17 @@ public virtual async Task ParallelDownloadSuccessfulHashVerification( }; // Act - var dest = new MemoryStream(); + byte[] dest; + using (MemoryStream ms = new()) using (checksumPipelineAssertion.CheckRequestScope()) { - await ParallelDownloadAsync(client, 
dest, validationOptions, transferOptions); + await ParallelDownloadAsync(client, ms, validationOptions, transferOptions); + dest = ms.ToArray(); } // Assert // Assertion was in the pipeline and the SDK not throwing means the checksum was validated - Assert.IsTrue(dest.ToArray().SequenceEqual(data)); + Assert.IsTrue(dest.SequenceEqual(data)); } [Test] @@ -1357,7 +1443,7 @@ public virtual async Task ParallelDownloadDisablesDefaultClientValidationOptions { Assert.Fail($"Hash found when none expected."); } - if (response.Headers.Contains("x-ms-content-crc64")) + if (response.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessage)) { Assert.Fail($"Hash found when none expected."); } @@ -1565,7 +1651,7 @@ public virtual async Task OpenReadDisablesDefaultClientValidationOptions( { Assert.Fail($"Hash found when none expected."); } - if (response.Headers.Contains("x-ms-content-crc64")) + if (response.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessage)) { Assert.Fail($"Hash found when none expected."); } @@ -1615,7 +1701,7 @@ public virtual async Task DownloadSuccessfulHashVerification(StorageChecksumAlgo var validationOptions = new DownloadTransferValidationOptions { ChecksumAlgorithm = algorithm }; // Act - var dest = new MemoryStream(); + using var dest = new MemoryStream(); var response = await DownloadPartitionAsync(client, dest, validationOptions, new HttpRange(length: data.Length)); // Assert @@ -1626,13 +1712,71 @@ public virtual async Task DownloadSuccessfulHashVerification(StorageChecksumAlgo Assert.True(response.Headers.Contains("Content-MD5")); break; case StorageChecksumAlgorithm.StorageCrc64: - Assert.True(response.Headers.Contains("x-ms-content-crc64")); + Assert.True(response.Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)); break; default: Assert.Fail("Test can't validate given algorithm type."); break; } - Assert.IsTrue(dest.ToArray().SequenceEqual(data)); + var result = dest.ToArray(); + 
Assert.IsTrue(result.SequenceEqual(data)); + } + + [TestCase(StorageChecksumAlgorithm.StorageCrc64, Constants.StructuredMessage.MaxDownloadCrcWithHeader, false, false)] + [TestCase(StorageChecksumAlgorithm.StorageCrc64, Constants.StructuredMessage.MaxDownloadCrcWithHeader-1, false, false)] + [TestCase(StorageChecksumAlgorithm.StorageCrc64, Constants.StructuredMessage.MaxDownloadCrcWithHeader+1, true, false)] + [TestCase(StorageChecksumAlgorithm.MD5, Constants.StructuredMessage.MaxDownloadCrcWithHeader+1, false, true)] + public virtual async Task DownloadApporpriatelyUsesStructuredMessage( + StorageChecksumAlgorithm algorithm, + int? downloadLen, + bool expectStructuredMessage, + bool expectThrow) + { + await using IDisposingContainer disposingContainer = await GetDisposingContainerAsync(); + + // Arrange + const int dataLength = Constants.KB; + var data = GetRandomBuffer(dataLength); + + var resourceName = GetNewResourceName(); + var client = await GetResourceClientAsync( + disposingContainer.Container, + resourceLength: dataLength, + createResource: true, + resourceName: resourceName); + await SetupDataAsync(client, new MemoryStream(data)); + + // make pipeline assertion for checking checksum was present on download + HttpPipelinePolicy checksumPipelineAssertion = new AssertMessageContentsPolicy(checkResponse: expectStructuredMessage + ? 
GetResponseStructuredMessageAssertion(StructuredMessage.Flags.StorageCrc64) + : GetResponseChecksumAssertion(algorithm)); + TClientOptions clientOptions = ClientBuilder.GetOptions(); + clientOptions.AddPolicy(checksumPipelineAssertion, HttpPipelinePosition.PerCall); + + client = await GetResourceClientAsync( + disposingContainer.Container, + resourceLength: dataLength, + resourceName: resourceName, + createResource: false, + downloadAlgorithm: algorithm, + options: clientOptions); + + var validationOptions = new DownloadTransferValidationOptions { ChecksumAlgorithm = algorithm }; + + // Act + var dest = new MemoryStream(); + AsyncTestDelegate operation = async () => await DownloadPartitionAsync( + client, dest, validationOptions, downloadLen.HasValue ? new HttpRange(length: downloadLen.Value) : default); + // Assert (policies checked use of content validation) + if (expectThrow) + { + Assert.That(operation, Throws.TypeOf()); + } + else + { + Assert.That(operation, Throws.Nothing); + Assert.IsTrue(dest.ToArray().SequenceEqual(data)); + } } [Test, Combinatorial] @@ -1658,7 +1802,9 @@ public virtual async Task DownloadHashMismatchThrows( // alter response contents in pipeline, forcing a checksum mismatch on verification step var clientOptions = ClientBuilder.GetOptions(); - clientOptions.AddPolicy(new TamperStreamContentsPolicy() { TransformResponseBody = true }, HttpPipelinePosition.PerCall); + var tamperPolicy = TamperStreamContentsPolicy.TamperByteAt(50); + tamperPolicy.TransformResponseBody = true; + clientOptions.AddPolicy(tamperPolicy, HttpPipelinePosition.PerCall); client = await GetResourceClientAsync( disposingContainer.Container, createResource: false, @@ -1670,7 +1816,7 @@ public virtual async Task DownloadHashMismatchThrows( AsyncTestDelegate operation = async () => await DownloadPartitionAsync(client, dest, validationOptions, new HttpRange(length: data.Length)); // Assert - if (validate) + if (validate || algorithm.ResolveAuto() == 
StorageChecksumAlgorithm.StorageCrc64) { // SDK responsible for finding bad checksum. Throw. ThrowsOrInconclusiveAsync(operation); @@ -1728,7 +1874,7 @@ public virtual async Task DownloadUsesDefaultClientValidationOptions( Assert.True(response.Headers.Contains("Content-MD5")); break; case StorageChecksumAlgorithm.StorageCrc64: - Assert.True(response.Headers.Contains("x-ms-content-crc64")); + Assert.True(response.Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)); break; default: Assert.Fail("Test can't validate given algorithm type."); @@ -1788,7 +1934,7 @@ public virtual async Task DownloadOverwritesDefaultClientValidationOptions( Assert.True(response.Headers.Contains("Content-MD5")); break; case StorageChecksumAlgorithm.StorageCrc64: - Assert.True(response.Headers.Contains("x-ms-content-crc64")); + Assert.True(response.Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)); break; default: Assert.Fail("Test can't validate given algorithm type."); @@ -1827,7 +1973,7 @@ public virtual async Task DownloadDisablesDefaultClientValidationOptions( { Assert.Fail($"Hash found when none expected."); } - if (response.Headers.Contains("x-ms-content-crc64")) + if (response.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessage)) { Assert.Fail($"Hash found when none expected."); } @@ -1850,7 +1996,54 @@ public virtual async Task DownloadDisablesDefaultClientValidationOptions( // Assert // no policies this time; just check response headers Assert.False(response.Headers.Contains("Content-MD5")); - Assert.False(response.Headers.Contains("x-ms-content-crc64")); + Assert.False(response.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessage)); + Assert.IsTrue(dest.ToArray().SequenceEqual(data)); + } + + [Test] + public virtual async Task DownloadRecoversFromInterruptWithValidation( + [ValueSource(nameof(GetValidationAlgorithms))] StorageChecksumAlgorithm algorithm) + { + using var _ = 
AzureEventSourceListener.CreateConsoleLogger(); + int dataLen = algorithm.ResolveAuto() switch { + StorageChecksumAlgorithm.StorageCrc64 => 5 * Constants.MB, // >4MB for multisegment + _ => Constants.KB, + }; + + await using IDisposingContainer disposingContainer = await GetDisposingContainerAsync(); + + // Arrange + var data = GetRandomBuffer(dataLen); + + TClientOptions options = ClientBuilder.GetOptions(); + options.AddPolicy(new FaultyDownloadPipelinePolicy(dataLen - 512, new IOException(), () => { }), HttpPipelinePosition.BeforeTransport); + var client = await GetResourceClientAsync( + disposingContainer.Container, + resourceLength: dataLen, + createResource: true, + options: options); + await SetupDataAsync(client, new MemoryStream(data)); + + var validationOptions = new DownloadTransferValidationOptions { ChecksumAlgorithm = algorithm }; + + // Act + var dest = new MemoryStream(); + var response = await DownloadPartitionAsync(client, dest, validationOptions, new HttpRange(length: data.Length)); + + // Assert + // no policies this time; just check response headers + switch (algorithm.ResolveAuto()) + { + case StorageChecksumAlgorithm.MD5: + Assert.True(response.Headers.Contains("Content-MD5")); + break; + case StorageChecksumAlgorithm.StorageCrc64: + Assert.True(response.Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)); + break; + default: + Assert.Fail("Test can't validate given algorithm type."); + break; + } Assert.IsTrue(dest.ToArray().SequenceEqual(data)); } #endregion @@ -1891,7 +2084,7 @@ public async Task RoundtripWIthDefaults() // make pipeline assertion for checking checksum was present on upload AND download var checksumPipelineAssertion = new AssertMessageContentsPolicy( - checkRequest: GetRequestChecksumAssertion(expectedAlgorithm, isChecksumExpected: ParallelUploadIsChecksumExpected), + checkRequest: GetRequestChecksumHeaderAssertion(expectedAlgorithm, isChecksumExpected: ParallelUploadIsChecksumExpected), checkResponse: 
GetResponseChecksumAssertion(expectedAlgorithm)); clientOptions.AddPolicy(checksumPipelineAssertion, HttpPipelinePosition.PerCall); diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs new file mode 100644 index 0000000000000..a0f9158040b11 --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs @@ -0,0 +1,246 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers.Binary; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core; +using Azure.Storage.Shared; +using Azure.Storage.Test.Shared; +using Microsoft.Diagnostics.Tracing.Parsers.AspNet; +using Moq; +using NUnit.Framework; + +namespace Azure.Storage.Tests; + +[TestFixture(true)] +[TestFixture(false)] +public class StructuredMessageDecodingRetriableStreamTests +{ + public bool Async { get; } + + public StructuredMessageDecodingRetriableStreamTests(bool async) + { + Async = async; + } + + private Mock AllExceptionsRetry() + { + Mock mock = new(MockBehavior.Strict); + mock.Setup(rc => rc.IsRetriableException(It.IsAny())).Returns(true); + return mock; + } + + [Test] + public async ValueTask UninterruptedStream() + { + byte[] data = new Random().NextBytesInline(4 * Constants.KB).ToArray(); + byte[] dest = new byte[data.Length]; + + // mock with a simple MemoryStream rather than an actual StructuredMessageDecodingStream + using (Stream src = new MemoryStream(data)) + using (Stream retriableSrc = new StructuredMessageDecodingRetriableStream(src, new(), default, default, default, default, default, 1)) + using (Stream dst = new MemoryStream(dest)) + { + await retriableSrc.CopyToInternal(dst, Async, default); + } + + Assert.AreEqual(data, dest); + } + + [Test] + public async Task Interrupt_DataIntact([Values(true, 
false)] bool multipleInterrupts) + { + const int segments = 4; + const int segmentLen = Constants.KB; + const int readLen = 128; + const int interruptPos = segmentLen + (3 * readLen) + 10; + + Random r = new(); + byte[] data = r.NextBytesInline(segments * Constants.KB).ToArray(); + byte[] dest = new byte[data.Length]; + + // Mock a decoded data for the mocked StructuredMessageDecodingStream + StructuredMessageDecodingStream.RawDecodedData initialDecodedData = new() + { + TotalSegments = segments, + InnerStreamLength = data.Length, + Flags = StructuredMessage.Flags.StorageCrc64 + }; + // for test purposes, initialize a DecodedData, since we are not actively decoding in this test + initialDecodedData.SegmentCrcs.Add((BinaryPrimitives.ReadUInt64LittleEndian(r.NextBytesInline(StructuredMessage.Crc64Length)), segmentLen)); + + (Stream DecodingStream, StructuredMessageDecodingStream.RawDecodedData DecodedData) Factory(long offset, bool faulty) + { + Stream stream = new MemoryStream(data, (int)offset, data.Length - (int)offset); + if (faulty) + { + stream = new FaultyStream(stream, interruptPos, 1, new Exception(), () => { }); + } + // Mock a decoded data for the mocked StructuredMessageDecodingStream + StructuredMessageDecodingStream.RawDecodedData decodedData = new() + { + TotalSegments = segments, + InnerStreamLength = data.Length, + Flags = StructuredMessage.Flags.StorageCrc64, + }; + // for test purposes, initialize a DecodedData, since we are not actively decoding in this test + initialDecodedData.SegmentCrcs.Add((BinaryPrimitives.ReadUInt64LittleEndian(r.NextBytesInline(StructuredMessage.Crc64Length)), segmentLen)); + return (stream, decodedData); + } + + // mock with a simple MemoryStream rather than an actual StructuredMessageDecodingStream + using (Stream src = new MemoryStream(data)) + using (Stream faultySrc = new FaultyStream(src, interruptPos, 1, new Exception(), () => { })) + using (Stream retriableSrc = new StructuredMessageDecodingRetriableStream( + 
faultySrc, + initialDecodedData, + default, + offset => Factory(offset, multipleInterrupts), + offset => new ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.RawDecodedData DecodedData)>(Factory(offset, multipleInterrupts)), + null, + AllExceptionsRetry().Object, + int.MaxValue)) + using (Stream dst = new MemoryStream(dest)) + { + await retriableSrc.CopyToInternal(dst, readLen, Async, default); + } + + Assert.AreEqual(data, dest); + } + + [Test] + public async Task Interrupt_AppropriateRewind() + { + const int segments = 2; + const int segmentLen = Constants.KB; + const int dataLen = segments * segmentLen; + const int readLen = segmentLen / 4; + const int interruptOffset = 10; + const int interruptPos = segmentLen + (2 * readLen) + interruptOffset; + Random r = new(); + + // Mock a decoded data for the mocked StructuredMessageDecodingStream + StructuredMessageDecodingStream.RawDecodedData initialDecodedData = new() + { + TotalSegments = segments, + InnerStreamLength = segments * segmentLen, + Flags = StructuredMessage.Flags.StorageCrc64, + }; + // By the time of interrupt, there will be one segment reported + initialDecodedData.SegmentCrcs.Add((BinaryPrimitives.ReadUInt64LittleEndian(r.NextBytesInline(StructuredMessage.Crc64Length)), segmentLen)); + + Mock mock = new(MockBehavior.Strict); + mock.SetupGet(s => s.CanRead).Returns(true); + mock.SetupGet(s => s.CanSeek).Returns(false); + if (Async) + { + mock.SetupSequence(s => s.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), default)) + .Returns(Task.FromResult(readLen)) // start first segment + .Returns(Task.FromResult(readLen)) + .Returns(Task.FromResult(readLen)) + .Returns(Task.FromResult(readLen)) // finish first segment + .Returns(Task.FromResult(readLen)) // start second segment + .Returns(Task.FromResult(readLen)) + // faulty stream interrupt + .Returns(Task.FromResult(readLen * 2)) // restart second segment. 
fast-forward uses an internal 4KB buffer, so it will leap the 512 byte catchup all at once + .Returns(Task.FromResult(readLen)) + .Returns(Task.FromResult(readLen)) // end second segment + .Returns(Task.FromResult(0)) // signal end of stream + .Returns(Task.FromResult(0)) // second signal needed for stream wrapping reasons + ; + } + else + { + mock.SetupSequence(s => s.Read(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(readLen) // start first segment + .Returns(readLen) + .Returns(readLen) + .Returns(readLen) // finish first segment + .Returns(readLen) // start second segment + .Returns(readLen) + // faulty stream interrupt + .Returns(readLen * 2) // restart second segment. fast-forward uses an internal 4KB buffer, so it will leap the 512 byte catchup all at once + .Returns(readLen) + .Returns(readLen) // end second segment + .Returns(0) // signal end of stream + .Returns(0) // second signal needed for stream wrapping reasons + ; + } + Stream faultySrc = new FaultyStream(mock.Object, interruptPos, 1, new Exception(), default); + Stream retriableSrc = new StructuredMessageDecodingRetriableStream( + faultySrc, + initialDecodedData, + default, + offset => (mock.Object, new()), + offset => new(Task.FromResult((mock.Object, new StructuredMessageDecodingStream.RawDecodedData()))), + null, + AllExceptionsRetry().Object, + 1); + + int totalRead = 0; + int read = 0; + byte[] buf = new byte[readLen]; + if (Async) + { + while ((read = await retriableSrc.ReadAsync(buf, 0, buf.Length)) > 0) + { + totalRead += read; + } + } + else + { + while ((read = retriableSrc.Read(buf, 0, buf.Length)) > 0) + { + totalRead += read; + } + } + await retriableSrc.CopyToInternal(Stream.Null, readLen, Async, default); + + // Asserts we read exactly the data length, excluding the fastforward of the inner stream + Assert.That(totalRead, Is.EqualTo(dataLen)); + } + + [Test] + public async Task Interrupt_ProperDecode([Values(true, false)] bool multipleInterrupts) + { + // decoding stream inserts a 
buffered layer of 4 KB. use larger sizes to avoid interference from it. + const int segments = 4; + const int segmentLen = 128 * Constants.KB; + const int readLen = 8 * Constants.KB; + const int interruptPos = segmentLen + (3 * readLen) + 10; + + Random r = new(); + byte[] data = r.NextBytesInline(segments * Constants.KB).ToArray(); + byte[] dest = new byte[data.Length]; + + (Stream DecodingStream, StructuredMessageDecodingStream.RawDecodedData DecodedData) Factory(long offset, bool faulty) + { + Stream stream = new MemoryStream(data, (int)offset, data.Length - (int)offset); + stream = new StructuredMessageEncodingStream(stream, segmentLen, StructuredMessage.Flags.StorageCrc64); + if (faulty) + { + stream = new FaultyStream(stream, interruptPos, 1, new Exception(), () => { }); + } + return StructuredMessageDecodingStream.WrapStream(stream); + } + + (Stream decodingStream, StructuredMessageDecodingStream.RawDecodedData decodedData) = Factory(0, true); + using Stream retriableSrc = new StructuredMessageDecodingRetriableStream( + decodingStream, + decodedData, + default, + offset => Factory(offset, multipleInterrupts), + offset => new ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.RawDecodedData DecodedData)>(Factory(offset, multipleInterrupts)), + null, + AllExceptionsRetry().Object, + int.MaxValue); + using Stream dst = new MemoryStream(dest); + + await retriableSrc.CopyToInternal(dst, readLen, Async, default); + + Assert.AreEqual(data, dest); + } +} diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs new file mode 100644 index 0000000000000..2789672df4976 --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs @@ -0,0 +1,323 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +using System; +using System.Buffers.Binary; +using System.Dynamic; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Azure.Storage.Blobs.Tests; +using Azure.Storage.Shared; +using NUnit.Framework; +using static Azure.Storage.Shared.StructuredMessage; + +namespace Azure.Storage.Tests +{ + [TestFixture(ReadMethod.SyncArray)] + [TestFixture(ReadMethod.AsyncArray)] +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + [TestFixture(ReadMethod.SyncSpan)] + [TestFixture(ReadMethod.AsyncMemory)] +#endif + public class StructuredMessageDecodingStreamTests + { + // Cannot just implement as passthru in the stream + // Must test each one + public enum ReadMethod + { + SyncArray, + AsyncArray, +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + SyncSpan, + AsyncMemory +#endif + } + + public ReadMethod Method { get; } + + public StructuredMessageDecodingStreamTests(ReadMethod method) + { + Method = method; + } + + private class CopyStreamException : Exception + { + public long TotalCopied { get; } + + public CopyStreamException(Exception inner, long totalCopied) + : base($"Failed read after {totalCopied}-many bytes.", inner) + { + TotalCopied = totalCopied; + } + } + private async ValueTask CopyStream(Stream source, Stream destination, int bufferSize = 81920) // number default for CopyTo impl + { + byte[] buf = new byte[bufferSize]; + int read; + long totalRead = 0; + try + { + switch (Method) + { + case ReadMethod.SyncArray: + while ((read = source.Read(buf, 0, bufferSize)) > 0) + { + totalRead += read; + destination.Write(buf, 0, read); + } + break; + case ReadMethod.AsyncArray: + while ((read = await source.ReadAsync(buf, 0, bufferSize)) > 0) + { + totalRead += read; + await destination.WriteAsync(buf, 0, read); + } + break; +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + case ReadMethod.SyncSpan: + while ((read = source.Read(new Span(buf))) > 0) + { + totalRead += read; + destination.Write(new 
Span(buf, 0, read)); + } + break; + case ReadMethod.AsyncMemory: + while ((read = await source.ReadAsync(new Memory(buf))) > 0) + { + totalRead += read; + await destination.WriteAsync(new Memory(buf, 0, read)); + } + break; +#endif + } + destination.Flush(); + } + catch (Exception ex) + { + throw new CopyStreamException(ex, totalRead); + } + return totalRead; + } + + [Test] + [Pairwise] + public async Task DecodesData( + [Values(2048, 2005)] int dataLength, + [Values(default, 512)] int? seglen, + [Values(8*Constants.KB, 512, 530, 3)] int readLen, + [Values(true, false)] bool useCrc) + { + int segmentContentLength = seglen ?? int.MaxValue; + Flags flags = useCrc ? Flags.StorageCrc64 : Flags.None; + + byte[] originalData = new byte[dataLength]; + new Random().NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, segmentContentLength, flags); + + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData)); + byte[] decodedData; + using (MemoryStream dest = new()) + { + await CopyStream(decodingStream, dest, readLen); + decodedData = dest.ToArray(); + } + + Assert.That(new Span(decodedData).SequenceEqual(originalData)); + } + + [Test] + public void BadStreamBadVersion() + { + byte[] originalData = new byte[1024]; + new Random().NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, 256, Flags.StorageCrc64); + + encodedData[0] = byte.MaxValue; + + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData)); + Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); + } + + [Test] + public async Task BadSegmentCrcThrows() + { + const int segmentLength = 256; + Random r = new(); + + byte[] originalData = new byte[2048]; + r.NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, segmentLength, 
Flags.StorageCrc64); + + const int badBytePos = 1024; + encodedData[badBytePos] = (byte)~encodedData[badBytePos]; + + MemoryStream encodedDataStream = new(encodedData); + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(encodedDataStream); + + // manual try/catch to validate the process failed mid-stream rather than the end + const int copyBufferSize = 4; + bool caught = false; + try + { + await CopyStream(decodingStream, Stream.Null, copyBufferSize); + } + catch (CopyStreamException ex) + { + caught = true; + Assert.That(ex.TotalCopied, Is.LessThanOrEqualTo(badBytePos)); + } + Assert.That(caught); + } + + [Test] + public void BadStreamCrcThrows() + { + const int segmentLength = 256; + Random r = new(); + + byte[] originalData = new byte[2048]; + r.NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, segmentLength, Flags.StorageCrc64); + + encodedData[originalData.Length - 1] = (byte)~encodedData[originalData.Length - 1]; + + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData)); + Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); + } + + [Test] + public void BadStreamWrongContentLength() + { + byte[] originalData = new byte[1024]; + new Random().NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, 256, Flags.StorageCrc64); + + BinaryPrimitives.WriteInt64LittleEndian(new Span(encodedData, V1_0.StreamHeaderMessageLengthOffset, 8), 123456789L); + + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData)); + Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); + } + + [TestCase(-1)] + [TestCase(1)] + public void BadStreamWrongSegmentCount(int difference) + { + const int dataSize = 1024; + const int segmentSize = 256; + const int numSegments = 4; + + 
byte[] originalData = new byte[dataSize]; + new Random().NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, segmentSize, Flags.StorageCrc64); + + // rewrite the segment count to be different than the actual number of segments + BinaryPrimitives.WriteInt16LittleEndian( + new Span(encodedData, V1_0.StreamHeaderSegmentCountOffset, 2), (short)(numSegments + difference)); + + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData)); + Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); + } + + [Test] + public void BadStreamWrongSegmentNum() + { + byte[] originalData = new byte[1024]; + new Random().NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, 256, Flags.StorageCrc64); + + BinaryPrimitives.WriteInt16LittleEndian( + new Span(encodedData, V1_0.StreamHeaderLength + V1_0.SegmentHeaderNumOffset, 2), 123); + + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData)); + Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); + } + + [Test] + [Combinatorial] + public async Task BadStreamWrongContentLength( + [Values(-1, 1)] int difference, + [Values(true, false)] bool lengthProvided) + { + byte[] originalData = new byte[1024]; + new Random().NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, 256, Flags.StorageCrc64); + + BinaryPrimitives.WriteInt64LittleEndian( + new Span(encodedData, V1_0.StreamHeaderMessageLengthOffset, 8), + encodedData.Length + difference); + + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream( + new MemoryStream(encodedData), + lengthProvided ? 
(long?)encodedData.Length : default); + + // manual try/catch with tiny buffer to validate the process failed mid-stream rather than the end + const int copyBufferSize = 4; + bool caught = false; + try + { + await CopyStream(decodingStream, Stream.Null, copyBufferSize); + } + catch (CopyStreamException ex) + { + caught = true; + if (lengthProvided) + { + Assert.That(ex.TotalCopied, Is.EqualTo(0)); + } + else + { + Assert.That(ex.TotalCopied, Is.EqualTo(originalData.Length)); + } + } + Assert.That(caught); + } + + [Test] + public void BadStreamMissingExpectedStreamFooter() + { + byte[] originalData = new byte[1024]; + new Random().NextBytes(originalData); + byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, 256, Flags.StorageCrc64); + + byte[] brokenData = new byte[encodedData.Length - Crc64Length]; + new Span(encodedData, 0, encodedData.Length - Crc64Length).CopyTo(brokenData); + + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(brokenData)); + Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); + } + + [Test] + public void NoSeek() + { + (Stream stream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream()); + + Assert.That(stream.CanSeek, Is.False); + Assert.That(() => stream.Length, Throws.TypeOf()); + Assert.That(() => stream.Position, Throws.TypeOf()); + Assert.That(() => stream.Position = 0, Throws.TypeOf()); + Assert.That(() => stream.Seek(0, SeekOrigin.Begin), Throws.TypeOf()); + } + + [Test] + public void NoWrite() + { + (Stream stream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream()); + byte[] data = new byte[1024]; + new Random().NextBytes(data); + + Assert.That(stream.CanWrite, Is.False); + Assert.That(() => stream.Write(data, 0, data.Length), + Throws.TypeOf()); + Assert.That(async () => await stream.WriteAsync(data, 0, data.Length, CancellationToken.None), + Throws.TypeOf()); +#if 
NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + Assert.That(() => stream.Write(new Span(data)), + Throws.TypeOf()); + Assert.That(async () => await stream.WriteAsync(new Memory(data), CancellationToken.None), + Throws.TypeOf()); +#endif + } + } +} diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageEncodingStreamTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageEncodingStreamTests.cs new file mode 100644 index 0000000000000..e0f91dee7de3a --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageEncodingStreamTests.cs @@ -0,0 +1,271 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers.Binary; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using Azure.Storage.Blobs.Tests; +using Azure.Storage.Shared; +using NUnit.Framework; +using static Azure.Storage.Shared.StructuredMessage; + +namespace Azure.Storage.Tests +{ + [TestFixture(ReadMethod.SyncArray)] + [TestFixture(ReadMethod.AsyncArray)] +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + [TestFixture(ReadMethod.SyncSpan)] + [TestFixture(ReadMethod.AsyncMemory)] +#endif + public class StructuredMessageEncodingStreamTests + { + // Cannot just implement as passthru in the stream + // Must test each one + public enum ReadMethod + { + SyncArray, + AsyncArray, +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + SyncSpan, + AsyncMemory +#endif + } + + public ReadMethod Method { get; } + + public StructuredMessageEncodingStreamTests(ReadMethod method) + { + Method = method; + } + + private async ValueTask CopyStream(Stream source, Stream destination, int bufferSize = 81920) // number default for CopyTo impl + { + byte[] buf = new byte[bufferSize]; + int read; + switch (Method) + { + case ReadMethod.SyncArray: + while ((read = source.Read(buf, 0, bufferSize)) > 0) + { + destination.Write(buf, 0, read); + } + break; + case 
ReadMethod.AsyncArray: + while ((read = await source.ReadAsync(buf, 0, bufferSize)) > 0) + { + await destination.WriteAsync(buf, 0, read); + } + break; +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + case ReadMethod.SyncSpan: + while ((read = source.Read(new Span(buf))) > 0) + { + destination.Write(new Span(buf, 0, read)); + } + break; + case ReadMethod.AsyncMemory: + while ((read = await source.ReadAsync(new Memory(buf))) > 0) + { + await destination.WriteAsync(new Memory(buf, 0, read)); + } + break; +#endif + } + destination.Flush(); + } + + [Test] + [Pairwise] + public async Task EncodesData( + [Values(2048, 2005)] int dataLength, + [Values(default, 512)] int? seglen, + [Values(8 * Constants.KB, 512, 530, 3)] int readLen, + [Values(true, false)] bool useCrc) + { + int segmentContentLength = seglen ?? int.MaxValue; + Flags flags = useCrc ? Flags.StorageCrc64 : Flags.None; + + byte[] originalData = new byte[dataLength]; + new Random().NextBytes(originalData); + byte[] expectedEncodedData = StructuredMessageHelper.MakeEncodedData(originalData, segmentContentLength, flags); + + Stream encodingStream = new StructuredMessageEncodingStream(new MemoryStream(originalData), segmentContentLength, flags); + byte[] encodedData; + using (MemoryStream dest = new()) + { + await CopyStream(encodingStream, dest, readLen); + encodedData = dest.ToArray(); + } + + Assert.That(new Span(encodedData).SequenceEqual(expectedEncodedData)); + } + + [TestCase(0, 0)] // start + [TestCase(5, 0)] // partway through stream header + [TestCase(V1_0.StreamHeaderLength, 0)] // start of segment + [TestCase(V1_0.StreamHeaderLength + 3, 0)] // partway through segment header + [TestCase(V1_0.StreamHeaderLength + V1_0.SegmentHeaderLength, 0)] // start of segment content + [TestCase(V1_0.StreamHeaderLength + V1_0.SegmentHeaderLength + 123, 123)] // partway through segment content + [TestCase(V1_0.StreamHeaderLength + V1_0.SegmentHeaderLength + 512, 512)] // start of segment footer + 
[TestCase(V1_0.StreamHeaderLength + V1_0.SegmentHeaderLength + 515, 512)] // partway through segment footer + [TestCase(V1_0.StreamHeaderLength + 3*V1_0.SegmentHeaderLength + 2*Crc64Length + 1500, 1500)] // partway through not first segment content + public async Task Seek(int targetRewindOffset, int expectedInnerStreamPosition) + { + const int segmentLength = 512; + const int dataLength = 2055; + byte[] data = new byte[dataLength]; + new Random().NextBytes(data); + + MemoryStream dataStream = new(data); + StructuredMessageEncodingStream encodingStream = new(dataStream, segmentLength, Flags.StorageCrc64); + + // no support for seeking past existing read, need to consume whole stream before seeking + await CopyStream(encodingStream, Stream.Null); + + encodingStream.Position = targetRewindOffset; + Assert.That(encodingStream.Position, Is.EqualTo(targetRewindOffset)); + Assert.That(dataStream.Position, Is.EqualTo(expectedInnerStreamPosition)); + } + + [TestCase(0)] // start + [TestCase(5)] // partway through stream header + [TestCase(V1_0.StreamHeaderLength)] // start of segment + [TestCase(V1_0.StreamHeaderLength + 3)] // partway through segment header + [TestCase(V1_0.StreamHeaderLength + V1_0.SegmentHeaderLength)] // start of segment content + [TestCase(V1_0.StreamHeaderLength + V1_0.SegmentHeaderLength + 123)] // partway through segment content + [TestCase(V1_0.StreamHeaderLength + V1_0.SegmentHeaderLength + 512)] // start of segment footer + [TestCase(V1_0.StreamHeaderLength + V1_0.SegmentHeaderLength + 515)] // partway through segment footer + [TestCase(V1_0.StreamHeaderLength + 2 * V1_0.SegmentHeaderLength + Crc64Length + 1500)] // partway through not first segment content + public async Task SupportsRewind(int targetRewindOffset) + { + const int segmentLength = 512; + const int dataLength = 2055; + byte[] data = new byte[dataLength]; + new Random().NextBytes(data); + + Stream encodingStream = new StructuredMessageEncodingStream(new MemoryStream(data), 
segmentLength, Flags.StorageCrc64); + byte[] encodedData1; + using (MemoryStream dest = new()) + { + await CopyStream(encodingStream, dest); + encodedData1 = dest.ToArray(); + } + encodingStream.Position = targetRewindOffset; + byte[] encodedData2; + using (MemoryStream dest = new()) + { + await CopyStream(encodingStream, dest); + encodedData2 = dest.ToArray(); + } + + Assert.That(new Span(encodedData1).Slice(targetRewindOffset).SequenceEqual(encodedData2)); + } + + [Test] + public async Task SupportsFastForward() + { + const int segmentLength = 512; + const int dataLength = 2055; + byte[] data = new byte[dataLength]; + new Random().NextBytes(data); + + // must have read stream to fastforward. so read whole stream upfront & save result to check later + Stream encodingStream = new StructuredMessageEncodingStream(new MemoryStream(data), segmentLength, Flags.StorageCrc64); + byte[] encodedData; + using (MemoryStream dest = new()) + { + await CopyStream(encodingStream, dest); + encodedData = dest.ToArray(); + } + + encodingStream.Position = 0; + + bool skip = false; + const int increment = 499; + while (encodingStream.Position < encodingStream.Length) + { + if (skip) + { + encodingStream.Position = Math.Min(dataLength, encodingStream.Position + increment); + skip = !skip; + continue; + } + ReadOnlyMemory expected = new(encodedData, (int)encodingStream.Position, + (int)Math.Min(increment, encodedData.Length - encodingStream.Position)); + ReadOnlyMemory actual; + using (MemoryStream dest = new(increment)) + { + await CopyStream(WindowStream.GetWindow(encodingStream, increment), dest); + actual = dest.ToArray(); + } + Assert.That(expected.Span.SequenceEqual(actual.Span)); + skip = !skip; + } + } + + [Test] + public void NotSupportsFastForwardBeyondLatestRead() + { + const int segmentLength = 512; + const int dataLength = 2055; + byte[] data = new byte[dataLength]; + new Random().NextBytes(data); + + Stream encodingStream = new StructuredMessageEncodingStream(new 
MemoryStream(data), segmentLength, Flags.StorageCrc64); + + Assert.That(() => encodingStream.Position = 123, Throws.TypeOf()); + } + + [Test] + [Pairwise] + public async Task WrapperStreamCorrectData( + [Values(2048, 2005)] int dataLength, + [Values(8 * Constants.KB, 512, 530, 3)] int readLen) + { + int segmentContentLength = dataLength; + Flags flags = Flags.StorageCrc64; + + byte[] originalData = new byte[dataLength]; + new Random().NextBytes(originalData); + byte[] crc = CrcInline(originalData); + byte[] expectedEncodedData = StructuredMessageHelper.MakeEncodedData(originalData, segmentContentLength, flags); + + Stream encodingStream = new StructuredMessagePrecalculatedCrcWrapperStream(new MemoryStream(originalData), crc); + byte[] encodedData; + using (MemoryStream dest = new()) + { + await CopyStream(encodingStream, dest, readLen); + encodedData = dest.ToArray(); + } + + Assert.That(new Span(encodedData).SequenceEqual(expectedEncodedData)); + } + + private static void AssertExpectedStreamHeader(ReadOnlySpan actual, int originalDataLength, Flags flags, int expectedSegments) + { + int expectedFooterLen = flags.HasFlag(Flags.StorageCrc64) ? 
Crc64Length : 0; + + Assert.That(actual.Length, Is.EqualTo(V1_0.StreamHeaderLength)); + Assert.That(actual[0], Is.EqualTo(1)); + Assert.That(BinaryPrimitives.ReadInt64LittleEndian(actual.Slice(1, 8)), + Is.EqualTo(V1_0.StreamHeaderLength + expectedSegments * (V1_0.SegmentHeaderLength + expectedFooterLen) + originalDataLength)); + Assert.That(BinaryPrimitives.ReadInt16LittleEndian(actual.Slice(9, 2)), Is.EqualTo((short)flags)); + Assert.That(BinaryPrimitives.ReadInt16LittleEndian(actual.Slice(11, 2)), Is.EqualTo((short)expectedSegments)); + } + + private static void AssertExpectedSegmentHeader(ReadOnlySpan actual, int segmentNum, long contentLength) + { + Assert.That(BinaryPrimitives.ReadInt16LittleEndian(actual.Slice(0, 2)), Is.EqualTo((short) segmentNum)); + Assert.That(BinaryPrimitives.ReadInt64LittleEndian(actual.Slice(2, 8)), Is.EqualTo(contentLength)); + } + + private static byte[] CrcInline(ReadOnlySpan data) + { + var crc = StorageCrc64HashAlgorithm.Create(); + crc.Append(data); + return crc.GetCurrentHash(); + } + } +} diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageHelper.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageHelper.cs new file mode 100644 index 0000000000000..59e80320d96a0 --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageHelper.cs @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Azure.Storage.Shared; +using static Azure.Storage.Shared.StructuredMessage; + +namespace Azure.Storage.Blobs.Tests +{ + internal class StructuredMessageHelper + { + public static byte[] MakeEncodedData(ReadOnlySpan data, long segmentContentLength, Flags flags) + { + int segmentCount = (int) Math.Ceiling(data.Length / (double)segmentContentLength); + int segmentFooterLen = flags.HasFlag(Flags.StorageCrc64) ? 
8 : 0; + int streamFooterLen = flags.HasFlag(Flags.StorageCrc64) ? 8 : 0; + + byte[] encodedData = new byte[ + V1_0.StreamHeaderLength + + segmentCount*(V1_0.SegmentHeaderLength + segmentFooterLen) + + streamFooterLen + + data.Length]; + V1_0.WriteStreamHeader( + new Span(encodedData, 0, V1_0.StreamHeaderLength), + encodedData.Length, + flags, + segmentCount); + + int i = V1_0.StreamHeaderLength; + int j = 0; + foreach (int seg in Enumerable.Range(1, segmentCount)) + { + int segContentLen = Math.Min((int)segmentContentLength, data.Length - j); + V1_0.WriteSegmentHeader( + new Span(encodedData, i, V1_0.SegmentHeaderLength), + seg, + segContentLen); + i += V1_0.SegmentHeaderLength; + + data.Slice(j, segContentLen) + .CopyTo(new Span(encodedData).Slice(i)); + i += segContentLen; + + if (flags.HasFlag(Flags.StorageCrc64)) + { + var crc = StorageCrc64HashAlgorithm.Create(); + crc.Append(data.Slice(j, segContentLen)); + crc.GetCurrentHash(new Span(encodedData, i, Crc64Length)); + i += Crc64Length; + } + j += segContentLen; + } + + if (flags.HasFlag(Flags.StorageCrc64)) + { + var crc = StorageCrc64HashAlgorithm.Create(); + crc.Append(data); + crc.GetCurrentHash(new Span(encodedData, i, Crc64Length)); + } + + return encodedData; + } + } +} diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageStreamRoundtripTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageStreamRoundtripTests.cs new file mode 100644 index 0000000000000..61583aa1ebe4e --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageStreamRoundtripTests.cs @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +using System; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using Azure.Storage.Shared; +using NUnit.Framework; +using static Azure.Storage.Shared.StructuredMessage; + +namespace Azure.Storage.Tests +{ + [TestFixture(ReadMethod.SyncArray)] + [TestFixture(ReadMethod.AsyncArray)] +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + [TestFixture(ReadMethod.SyncSpan)] + [TestFixture(ReadMethod.AsyncMemory)] +#endif + public class StructuredMessageStreamRoundtripTests + { + // Cannot just implement as passthru in the stream + // Must test each one + public enum ReadMethod + { + SyncArray, + AsyncArray, +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + SyncSpan, + AsyncMemory +#endif + } + + public ReadMethod Method { get; } + + public StructuredMessageStreamRoundtripTests(ReadMethod method) + { + Method = method; + } + + private class CopyStreamException : Exception + { + public long TotalCopied { get; } + + public CopyStreamException(Exception inner, long totalCopied) + : base($"Failed read after {totalCopied}-many bytes.", inner) + { + TotalCopied = totalCopied; + } + } + private async ValueTask CopyStream(Stream source, Stream destination, int bufferSize = 81920) // number default for CopyTo impl + { + byte[] buf = new byte[bufferSize]; + int read; + long totalRead = 0; + try + { + switch (Method) + { + case ReadMethod.SyncArray: + while ((read = source.Read(buf, 0, bufferSize)) > 0) + { + totalRead += read; + destination.Write(buf, 0, read); + } + break; + case ReadMethod.AsyncArray: + while ((read = await source.ReadAsync(buf, 0, bufferSize)) > 0) + { + totalRead += read; + await destination.WriteAsync(buf, 0, read); + } + break; +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + case ReadMethod.SyncSpan: + while ((read = source.Read(new Span(buf))) > 0) + { + totalRead += read; + destination.Write(new Span(buf, 0, read)); + } + break; + case ReadMethod.AsyncMemory: + while ((read = await source.ReadAsync(new 
Memory(buf))) > 0) + { + totalRead += read; + await destination.WriteAsync(new Memory(buf, 0, read)); + } + break; +#endif + } + destination.Flush(); + } + catch (Exception ex) + { + throw new CopyStreamException(ex, totalRead); + } + return totalRead; + } + + [Test] + [Pairwise] + public async Task RoundTrip( + [Values(2048, 2005)] int dataLength, + [Values(default, 512)] int? seglen, + [Values(8 * Constants.KB, 512, 530, 3)] int readLen, + [Values(true, false)] bool useCrc) + { + int segmentLength = seglen ?? int.MaxValue; + Flags flags = useCrc ? Flags.StorageCrc64 : Flags.None; + + byte[] originalData = new byte[dataLength]; + new Random().NextBytes(originalData); + + byte[] roundtripData; + using (MemoryStream source = new(originalData)) + using (Stream encode = new StructuredMessageEncodingStream(source, segmentLength, flags)) + using (Stream decode = StructuredMessageDecodingStream.WrapStream(encode).DecodedStream) + using (MemoryStream dest = new()) + { + await CopyStream(source, dest, readLen); + roundtripData = dest.ToArray(); + } + + Assert.That(originalData.SequenceEqual(roundtripData)); + } + } +} diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageTests.cs new file mode 100644 index 0000000000000..b4f1dfe178246 --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageTests.cs @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +using System; +using System.Buffers.Binary; +using System.Collections.Generic; +using NUnit.Framework; +using static Azure.Storage.Shared.StructuredMessage; + +namespace Azure.Storage.Tests +{ + public class StructuredMessageTests + { + [TestCase(1024, Flags.None, 2)] + [TestCase(2000, Flags.StorageCrc64, 4)] + public void EncodeStreamHeader(int messageLength, int flags, int numSegments) + { + Span encoding = new(new byte[V1_0.StreamHeaderLength]); + V1_0.WriteStreamHeader(encoding, messageLength, (Flags)flags, numSegments); + + Assert.That(encoding[0], Is.EqualTo((byte)1)); + Assert.That(BinaryPrimitives.ReadUInt64LittleEndian(encoding.Slice(1, 8)), Is.EqualTo(messageLength)); + Assert.That(BinaryPrimitives.ReadUInt16LittleEndian(encoding.Slice(9, 2)), Is.EqualTo(flags)); + Assert.That(BinaryPrimitives.ReadUInt16LittleEndian(encoding.Slice(11, 2)), Is.EqualTo(numSegments)); + } + + [TestCase(V1_0.StreamHeaderLength)] + [TestCase(V1_0.StreamHeaderLength + 1)] + [TestCase(V1_0.StreamHeaderLength - 1)] + public void EncodeStreamHeaderRejectBadBufferSize(int bufferSize) + { + Random r = new(); + byte[] encoding = new byte[bufferSize]; + + void Action() => V1_0.WriteStreamHeader(encoding, r.Next(2, int.MaxValue), Flags.StorageCrc64, r.Next(2, int.MaxValue)); + if (bufferSize < V1_0.StreamHeaderLength) + { + Assert.That(Action, Throws.ArgumentException); + } + else + { + Assert.That(Action, Throws.Nothing); + } + } + + [TestCase(1, 1024)] + [TestCase(5, 39578)] + public void EncodeSegmentHeader(int segmentNum, int contentLength) + { + Span encoding = new(new byte[V1_0.SegmentHeaderLength]); + V1_0.WriteSegmentHeader(encoding, segmentNum, contentLength); + + Assert.That(BinaryPrimitives.ReadUInt16LittleEndian(encoding.Slice(0, 2)), Is.EqualTo(segmentNum)); + Assert.That(BinaryPrimitives.ReadUInt64LittleEndian(encoding.Slice(2, 8)), Is.EqualTo(contentLength)); + } + + [TestCase(V1_0.SegmentHeaderLength)] + [TestCase(V1_0.SegmentHeaderLength + 1)] + 
[TestCase(V1_0.SegmentHeaderLength - 1)] + public void EncodeSegmentHeaderRejectBadBufferSize(int bufferSize) + { + Random r = new(); + byte[] encoding = new byte[bufferSize]; + + void Action() => V1_0.WriteSegmentHeader(encoding, r.Next(1, int.MaxValue), r.Next(2, int.MaxValue)); + if (bufferSize < V1_0.SegmentHeaderLength) + { + Assert.That(Action, Throws.ArgumentException); + } + else + { + Assert.That(Action, Throws.Nothing); + } + } + + [TestCase(true)] + [TestCase(false)] + public void EncodeSegmentFooter(bool useCrc) + { + Span encoding = new(new byte[Crc64Length]); + Span crc = useCrc ? new Random().NextBytesInline(Crc64Length) : default; + V1_0.WriteSegmentFooter(encoding, crc); + + if (useCrc) + { + Assert.That(encoding.SequenceEqual(crc), Is.True); + } + else + { + Assert.That(encoding.SequenceEqual(new Span(new byte[Crc64Length])), Is.True); + } + } + + [TestCase(Crc64Length)] + [TestCase(Crc64Length + 1)] + [TestCase(Crc64Length - 1)] + public void EncodeSegmentFooterRejectBadBufferSize(int bufferSize) + { + byte[] encoding = new byte[bufferSize]; + byte[] crc = new byte[Crc64Length]; + new Random().NextBytes(crc); + + void Action() => V1_0.WriteSegmentFooter(encoding, crc); + if (bufferSize < Crc64Length) + { + Assert.That(Action, Throws.ArgumentException); + } + else + { + Assert.That(Action, Throws.Nothing); + } + } + } +} diff --git a/sdk/storage/Azure.Storage.DataMovement.Blobs/samples/Azure.Storage.DataMovement.Blobs.Samples.Tests.csproj b/sdk/storage/Azure.Storage.DataMovement.Blobs/samples/Azure.Storage.DataMovement.Blobs.Samples.Tests.csproj index 7ab901e963e03..30d4b1f79daaf 100644 --- a/sdk/storage/Azure.Storage.DataMovement.Blobs/samples/Azure.Storage.DataMovement.Blobs.Samples.Tests.csproj +++ b/sdk/storage/Azure.Storage.DataMovement.Blobs/samples/Azure.Storage.DataMovement.Blobs.Samples.Tests.csproj @@ -11,6 +11,7 @@ + diff --git a/sdk/storage/Azure.Storage.DataMovement.Blobs/src/Azure.Storage.DataMovement.Blobs.csproj 
b/sdk/storage/Azure.Storage.DataMovement.Blobs/src/Azure.Storage.DataMovement.Blobs.csproj index 6098dcd8ba33d..93e7432f186e3 100644 --- a/sdk/storage/Azure.Storage.DataMovement.Blobs/src/Azure.Storage.DataMovement.Blobs.csproj +++ b/sdk/storage/Azure.Storage.DataMovement.Blobs/src/Azure.Storage.DataMovement.Blobs.csproj @@ -37,6 +37,7 @@ + diff --git a/sdk/storage/Azure.Storage.DataMovement.Blobs/src/DataMovementBlobsExtensions.cs b/sdk/storage/Azure.Storage.DataMovement.Blobs/src/DataMovementBlobsExtensions.cs index 84d60b3bc37c4..2c6864f511571 100644 --- a/sdk/storage/Azure.Storage.DataMovement.Blobs/src/DataMovementBlobsExtensions.cs +++ b/sdk/storage/Azure.Storage.DataMovement.Blobs/src/DataMovementBlobsExtensions.cs @@ -99,7 +99,7 @@ internal static StorageResourceItemProperties ToStorageResourceItemProperties(th ContentRange contentRange = !string.IsNullOrWhiteSpace(result?.Details?.ContentRange) ? ContentRange.Parse(result.Details.ContentRange) : default; if (contentRange != default) { - size = contentRange.Size; + size = contentRange.TotalResourceLength; } return new StorageResourceItemProperties( @@ -151,7 +151,7 @@ internal static StorageResourceReadStreamResult ToReadStreamStorageResourceInfo( if (contentRange != default) { range = ContentRange.ToHttpRange(contentRange); - size = contentRange.Size; + size = contentRange.TotalResourceLength; } else if (result.Details.ContentLength > 0) { diff --git a/sdk/storage/Azure.Storage.DataMovement.Blobs/tests/Azure.Storage.DataMovement.Blobs.Tests.csproj b/sdk/storage/Azure.Storage.DataMovement.Blobs/tests/Azure.Storage.DataMovement.Blobs.Tests.csproj index f8b62d0b947e2..214903eb5f9c4 100644 --- a/sdk/storage/Azure.Storage.DataMovement.Blobs/tests/Azure.Storage.DataMovement.Blobs.Tests.csproj +++ b/sdk/storage/Azure.Storage.DataMovement.Blobs/tests/Azure.Storage.DataMovement.Blobs.Tests.csproj @@ -22,11 +22,15 @@ + + + + @@ -40,6 +44,7 @@ + diff --git 
a/sdk/storage/Azure.Storage.DataMovement.Files.Shares/BlobToFileSharesTests/Azure.Storage.DataMovement.Blobs.Files.Shares.Tests.csproj b/sdk/storage/Azure.Storage.DataMovement.Files.Shares/BlobToFileSharesTests/Azure.Storage.DataMovement.Blobs.Files.Shares.Tests.csproj index a6abde432473f..66a9fea0861a2 100644 --- a/sdk/storage/Azure.Storage.DataMovement.Files.Shares/BlobToFileSharesTests/Azure.Storage.DataMovement.Blobs.Files.Shares.Tests.csproj +++ b/sdk/storage/Azure.Storage.DataMovement.Files.Shares/BlobToFileSharesTests/Azure.Storage.DataMovement.Blobs.Files.Shares.Tests.csproj @@ -35,6 +35,7 @@ + diff --git a/sdk/storage/Azure.Storage.DataMovement.Files.Shares/samples/Azure.Storage.DataMovement.Files.Shares.Samples.Tests.csproj b/sdk/storage/Azure.Storage.DataMovement.Files.Shares/samples/Azure.Storage.DataMovement.Files.Shares.Samples.Tests.csproj index 9cde066f64eb7..6a472b9f74158 100644 --- a/sdk/storage/Azure.Storage.DataMovement.Files.Shares/samples/Azure.Storage.DataMovement.Files.Shares.Samples.Tests.csproj +++ b/sdk/storage/Azure.Storage.DataMovement.Files.Shares/samples/Azure.Storage.DataMovement.Files.Shares.Samples.Tests.csproj @@ -1,4 +1,4 @@ - + $(RequiredTargetFrameworks) Microsoft Azure.Storage.DataMovement.Files.Shares client library samples @@ -11,6 +11,7 @@ + diff --git a/sdk/storage/Azure.Storage.DataMovement.Files.Shares/src/DataMovementSharesExtensions.cs b/sdk/storage/Azure.Storage.DataMovement.Files.Shares/src/DataMovementSharesExtensions.cs index 9cb7d338fcb60..16a164f61b060 100644 --- a/sdk/storage/Azure.Storage.DataMovement.Files.Shares/src/DataMovementSharesExtensions.cs +++ b/sdk/storage/Azure.Storage.DataMovement.Files.Shares/src/DataMovementSharesExtensions.cs @@ -335,14 +335,14 @@ internal static StorageResourceReadStreamResult ToStorageResourceReadStreamResul ContentRange contentRange = !string.IsNullOrWhiteSpace(info?.Details?.ContentRange) ? 
ContentRange.Parse(info.Details.ContentRange) : default; if (contentRange != default) { - size = contentRange.Size; + size = contentRange.TotalResourceLength; } return new StorageResourceReadStreamResult( content: info?.Content, range: ContentRange.ToHttpRange(contentRange), properties: new StorageResourceItemProperties( - resourceLength: contentRange.Size, + resourceLength: contentRange.TotalResourceLength, eTag: info.Details.ETag, lastModifiedTime: info.Details.LastModified, properties: properties)); diff --git a/sdk/storage/Azure.Storage.DataMovement.Files.Shares/tests/Azure.Storage.DataMovement.Files.Shares.Tests.csproj b/sdk/storage/Azure.Storage.DataMovement.Files.Shares/tests/Azure.Storage.DataMovement.Files.Shares.Tests.csproj index 8e574bca36a48..d75775beceafd 100644 --- a/sdk/storage/Azure.Storage.DataMovement.Files.Shares/tests/Azure.Storage.DataMovement.Files.Shares.Tests.csproj +++ b/sdk/storage/Azure.Storage.DataMovement.Files.Shares/tests/Azure.Storage.DataMovement.Files.Shares.Tests.csproj @@ -27,6 +27,7 @@ + diff --git a/sdk/storage/Azure.Storage.DataMovement.Files.Shares/tests/Shared/DisposingShare.cs b/sdk/storage/Azure.Storage.DataMovement.Files.Shares/tests/Shared/DisposingShare.cs index ae3bc879f717e..577ee7bb9a480 100644 --- a/sdk/storage/Azure.Storage.DataMovement.Files.Shares/tests/Shared/DisposingShare.cs +++ b/sdk/storage/Azure.Storage.DataMovement.Files.Shares/tests/Shared/DisposingShare.cs @@ -17,7 +17,7 @@ public class DisposingShare : IDisposingContainer public static async Task CreateAsync(ShareClient share, IDictionary metadata) { - await share.CreateIfNotExistsAsync(metadata: metadata); + await share.CreateIfNotExistsAsync(new() { Metadata = metadata }); return new DisposingShare(share); } diff --git a/sdk/storage/Azure.Storage.DataMovement/src/Azure.Storage.DataMovement.csproj b/sdk/storage/Azure.Storage.DataMovement/src/Azure.Storage.DataMovement.csproj index 5aaf548493b15..dd30659cf0a5d 100644 --- 
a/sdk/storage/Azure.Storage.DataMovement/src/Azure.Storage.DataMovement.csproj +++ b/sdk/storage/Azure.Storage.DataMovement/src/Azure.Storage.DataMovement.csproj @@ -1,4 +1,4 @@ - + $(RequiredTargetFrameworks);net6.0 diff --git a/sdk/storage/Azure.Storage.DataMovement/tests/Azure.Storage.DataMovement.Tests.csproj b/sdk/storage/Azure.Storage.DataMovement/tests/Azure.Storage.DataMovement.Tests.csproj index b5e3c42359976..7a40eb8026443 100644 --- a/sdk/storage/Azure.Storage.DataMovement/tests/Azure.Storage.DataMovement.Tests.csproj +++ b/sdk/storage/Azure.Storage.DataMovement/tests/Azure.Storage.DataMovement.Tests.csproj @@ -34,6 +34,7 @@ + diff --git a/sdk/storage/Azure.Storage.Files.DataLake/api/Azure.Storage.Files.DataLake.net6.0.cs b/sdk/storage/Azure.Storage.Files.DataLake/api/Azure.Storage.Files.DataLake.net6.0.cs index a202d6300f50e..7f856db5829ac 100644 --- a/sdk/storage/Azure.Storage.Files.DataLake/api/Azure.Storage.Files.DataLake.net6.0.cs +++ b/sdk/storage/Azure.Storage.Files.DataLake/api/Azure.Storage.Files.DataLake.net6.0.cs @@ -2,7 +2,7 @@ namespace Azure.Storage.Files.DataLake { public partial class DataLakeClientOptions : Azure.Core.ClientOptions { - public DataLakeClientOptions(Azure.Storage.Files.DataLake.DataLakeClientOptions.ServiceVersion version = Azure.Storage.Files.DataLake.DataLakeClientOptions.ServiceVersion.V2024_11_04) { } + public DataLakeClientOptions(Azure.Storage.Files.DataLake.DataLakeClientOptions.ServiceVersion version = Azure.Storage.Files.DataLake.DataLakeClientOptions.ServiceVersion.V2025_01_05) { } public Azure.Storage.Files.DataLake.Models.DataLakeAudience? Audience { get { throw null; } set { } } public Azure.Storage.Files.DataLake.Models.DataLakeCustomerProvidedKey? 
CustomerProvidedKey { get { throw null; } set { } } public bool EnableTenantDiscovery { get { throw null; } set { } } diff --git a/sdk/storage/Azure.Storage.Files.DataLake/api/Azure.Storage.Files.DataLake.netstandard2.0.cs b/sdk/storage/Azure.Storage.Files.DataLake/api/Azure.Storage.Files.DataLake.netstandard2.0.cs index a202d6300f50e..7f856db5829ac 100644 --- a/sdk/storage/Azure.Storage.Files.DataLake/api/Azure.Storage.Files.DataLake.netstandard2.0.cs +++ b/sdk/storage/Azure.Storage.Files.DataLake/api/Azure.Storage.Files.DataLake.netstandard2.0.cs @@ -2,7 +2,7 @@ namespace Azure.Storage.Files.DataLake { public partial class DataLakeClientOptions : Azure.Core.ClientOptions { - public DataLakeClientOptions(Azure.Storage.Files.DataLake.DataLakeClientOptions.ServiceVersion version = Azure.Storage.Files.DataLake.DataLakeClientOptions.ServiceVersion.V2024_11_04) { } + public DataLakeClientOptions(Azure.Storage.Files.DataLake.DataLakeClientOptions.ServiceVersion version = Azure.Storage.Files.DataLake.DataLakeClientOptions.ServiceVersion.V2025_01_05) { } public Azure.Storage.Files.DataLake.Models.DataLakeAudience? Audience { get { throw null; } set { } } public Azure.Storage.Files.DataLake.Models.DataLakeCustomerProvidedKey? 
CustomerProvidedKey { get { throw null; } set { } } public bool EnableTenantDiscovery { get { throw null; } set { } } diff --git a/sdk/storage/Azure.Storage.Files.DataLake/assets.json b/sdk/storage/Azure.Storage.Files.DataLake/assets.json index 4a64b8398f656..5127ea7e0c4db 100644 --- a/sdk/storage/Azure.Storage.Files.DataLake/assets.json +++ b/sdk/storage/Azure.Storage.Files.DataLake/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "net", "TagPrefix": "net/storage/Azure.Storage.Files.DataLake", - "Tag": "net/storage/Azure.Storage.Files.DataLake_d74597f1e3" + "Tag": "net/storage/Azure.Storage.Files.DataLake_48a38da58a" } diff --git a/sdk/storage/Azure.Storage.Files.DataLake/samples/Azure.Storage.Files.DataLake.Samples.Tests.csproj b/sdk/storage/Azure.Storage.Files.DataLake/samples/Azure.Storage.Files.DataLake.Samples.Tests.csproj index c230f2ed8fa20..eecbe0543fe87 100644 --- a/sdk/storage/Azure.Storage.Files.DataLake/samples/Azure.Storage.Files.DataLake.Samples.Tests.csproj +++ b/sdk/storage/Azure.Storage.Files.DataLake/samples/Azure.Storage.Files.DataLake.Samples.Tests.csproj @@ -15,6 +15,7 @@ + diff --git a/sdk/storage/Azure.Storage.Files.DataLake/src/Azure.Storage.Files.DataLake.csproj b/sdk/storage/Azure.Storage.Files.DataLake/src/Azure.Storage.Files.DataLake.csproj index 3c551e05c24c2..f8652fd283e36 100644 --- a/sdk/storage/Azure.Storage.Files.DataLake/src/Azure.Storage.Files.DataLake.csproj +++ b/sdk/storage/Azure.Storage.Files.DataLake/src/Azure.Storage.Files.DataLake.csproj @@ -42,6 +42,7 @@ + @@ -81,6 +82,10 @@ + + + + diff --git a/sdk/storage/Azure.Storage.Files.DataLake/src/DataLakeFileClient.cs b/sdk/storage/Azure.Storage.Files.DataLake/src/DataLakeFileClient.cs index 2da5eb76349eb..aaa8f514c6e44 100644 --- a/sdk/storage/Azure.Storage.Files.DataLake/src/DataLakeFileClient.cs +++ b/sdk/storage/Azure.Storage.Files.DataLake/src/DataLakeFileClient.cs @@ -16,6 +16,7 @@ using Azure.Storage.Common; using 
Azure.Storage.Files.DataLake.Models; using Azure.Storage.Sas; +using Azure.Storage.Shared; using Metadata = System.Collections.Generic.IDictionary; namespace Azure.Storage.Files.DataLake @@ -2332,13 +2333,39 @@ internal virtual async Task AppendInternal( using (ClientConfiguration.Pipeline.BeginLoggingScope(nameof(DataLakeFileClient))) { // compute hash BEFORE attaching progress handler - ContentHasher.GetHashResult hashResult = await ContentHasher.GetHashOrDefaultInternal( - content, - validationOptions, - async, - cancellationToken).ConfigureAwait(false); + ContentHasher.GetHashResult hashResult = null; + long contentLength = (content?.Length - content?.Position) ?? 0; + long? structuredContentLength = default; + string structuredBodyType = null; + if (content != null && + validationOptions != null && + validationOptions.ChecksumAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64) + { + // report progress in terms of caller bytes, not encoded bytes + structuredContentLength = contentLength; + structuredBodyType = Constants.StructuredMessage.CrcStructuredMessage; + content = content.WithNoDispose().WithProgress(progressHandler); + content = validationOptions.PrecalculatedChecksum.IsEmpty + ? 
new StructuredMessageEncodingStream( + content, + Constants.StructuredMessage.DefaultSegmentContentLength, + StructuredMessage.Flags.StorageCrc64) + : new StructuredMessagePrecalculatedCrcWrapperStream( + content, + validationOptions.PrecalculatedChecksum.Span); + contentLength = content.Length - content.Position; + } + else + { + // compute hash BEFORE attaching progress handler + hashResult = await ContentHasher.GetHashOrDefaultInternal( + content, + validationOptions, + async, + cancellationToken).ConfigureAwait(false); + content = content?.WithNoDispose().WithProgress(progressHandler); + } - content = content?.WithNoDispose().WithProgress(progressHandler); ClientConfiguration.Pipeline.LogMethodEnter( nameof(DataLakeFileClient), message: @@ -2373,6 +2400,8 @@ internal virtual async Task AppendInternal( encryptionKey: ClientConfiguration.CustomerProvidedKey?.EncryptionKey, encryptionKeySha256: ClientConfiguration.CustomerProvidedKey?.EncryptionKeyHash, encryptionAlgorithm: ClientConfiguration.CustomerProvidedKey?.EncryptionAlgorithm == null ? null : EncryptionAlgorithmTypeInternal.AES256, + structuredBodyType: structuredBodyType, + structuredContentLength: structuredContentLength, leaseId: leaseId, leaseAction: leaseAction, leaseDuration: leaseDurationLong, @@ -2392,6 +2421,8 @@ internal virtual async Task AppendInternal( encryptionKey: ClientConfiguration.CustomerProvidedKey?.EncryptionKey, encryptionKeySha256: ClientConfiguration.CustomerProvidedKey?.EncryptionKeyHash, encryptionAlgorithm: ClientConfiguration.CustomerProvidedKey?.EncryptionAlgorithm == null ? 
null : EncryptionAlgorithmTypeInternal.AES256, + structuredBodyType: structuredBodyType, + structuredContentLength: structuredContentLength, leaseId: leaseId, leaseAction: leaseAction, leaseDuration: leaseDurationLong, diff --git a/sdk/storage/Azure.Storage.Files.DataLake/src/autorest.md b/sdk/storage/Azure.Storage.Files.DataLake/src/autorest.md index ec9675a014f70..a8340f1092bcb 100644 --- a/sdk/storage/Azure.Storage.Files.DataLake/src/autorest.md +++ b/sdk/storage/Azure.Storage.Files.DataLake/src/autorest.md @@ -23,7 +23,7 @@ directive: if (property.includes('/{filesystem}/{path}')) { $[property]["parameters"] = $[property]["parameters"].filter(function(param) { return (typeof param['$ref'] === "undefined") || (false == param['$ref'].endsWith("#/parameters/FileSystem") && false == param['$ref'].endsWith("#/parameters/Path"))}); - } + } else if (property.includes('/{filesystem}')) { $[property]["parameters"] = $[property]["parameters"].filter(function(param) { return (typeof param['$ref'] === "undefined") || (false == param['$ref'].endsWith("#/parameters/FileSystem"))}); @@ -127,7 +127,7 @@ directive: } $[newName] = $[oldName]; delete $[oldName]; - } + } else if (property.includes('/{filesystem}')) { var oldName = property; diff --git a/sdk/storage/Azure.Storage.Files.DataLake/tests/Azure.Storage.Files.DataLake.Tests.csproj b/sdk/storage/Azure.Storage.Files.DataLake/tests/Azure.Storage.Files.DataLake.Tests.csproj index bef13bb21a1c6..1fa78690077be 100644 --- a/sdk/storage/Azure.Storage.Files.DataLake/tests/Azure.Storage.Files.DataLake.Tests.csproj +++ b/sdk/storage/Azure.Storage.Files.DataLake/tests/Azure.Storage.Files.DataLake.Tests.csproj @@ -6,6 +6,9 @@ Microsoft Azure.Storage.Files.DataLake client library tests false + + DataLakeSDK + diff --git a/sdk/storage/Azure.Storage.Files.DataLake/tests/DataLakeFileClientTransferValidationTests.cs b/sdk/storage/Azure.Storage.Files.DataLake/tests/DataLakeFileClientTransferValidationTests.cs index 
4bdefdbf756cd..5067f98517bd2 100644 --- a/sdk/storage/Azure.Storage.Files.DataLake/tests/DataLakeFileClientTransferValidationTests.cs +++ b/sdk/storage/Azure.Storage.Files.DataLake/tests/DataLakeFileClientTransferValidationTests.cs @@ -34,7 +34,10 @@ protected override async Task> Get StorageChecksumAlgorithm uploadAlgorithm = StorageChecksumAlgorithm.None, StorageChecksumAlgorithm downloadAlgorithm = StorageChecksumAlgorithm.None) { - var disposingFileSystem = await ClientBuilder.GetNewFileSystem(service: service, fileSystemName: containerName); + var disposingFileSystem = await ClientBuilder.GetNewFileSystem( + service: service, + fileSystemName: containerName, + publicAccessType: PublicAccessType.None); disposingFileSystem.FileSystem.ClientConfiguration.TransferValidation.Upload.ChecksumAlgorithm = uploadAlgorithm; disposingFileSystem.FileSystem.ClientConfiguration.TransferValidation.Download.ChecksumAlgorithm = downloadAlgorithm; diff --git a/sdk/storage/Azure.Storage.Files.Shares/api/Azure.Storage.Files.Shares.net6.0.cs b/sdk/storage/Azure.Storage.Files.Shares/api/Azure.Storage.Files.Shares.net6.0.cs index cf8ce32808d81..473ffb67af41f 100644 --- a/sdk/storage/Azure.Storage.Files.Shares/api/Azure.Storage.Files.Shares.net6.0.cs +++ b/sdk/storage/Azure.Storage.Files.Shares/api/Azure.Storage.Files.Shares.net6.0.cs @@ -115,7 +115,7 @@ public ShareClient(System.Uri shareUri, Azure.Storage.StorageSharedKeyCredential } public partial class ShareClientOptions : Azure.Core.ClientOptions { - public ShareClientOptions(Azure.Storage.Files.Shares.ShareClientOptions.ServiceVersion version = Azure.Storage.Files.Shares.ShareClientOptions.ServiceVersion.V2024_11_04) { } + public ShareClientOptions(Azure.Storage.Files.Shares.ShareClientOptions.ServiceVersion version = Azure.Storage.Files.Shares.ShareClientOptions.ServiceVersion.V2025_01_05) { } public bool? AllowSourceTrailingDot { get { throw null; } set { } } public bool? 
AllowTrailingDot { get { throw null; } set { } } public Azure.Storage.Files.Shares.Models.ShareAudience? Audience { get { throw null; } set { } } @@ -808,6 +808,7 @@ public partial class ShareFileDownloadInfo : System.IDisposable { internal ShareFileDownloadInfo() { } public System.IO.Stream Content { get { throw null; } } + public byte[] ContentCrc { get { throw null; } } public byte[] ContentHash { get { throw null; } } public long ContentLength { get { throw null; } } public string ContentType { get { throw null; } } diff --git a/sdk/storage/Azure.Storage.Files.Shares/api/Azure.Storage.Files.Shares.netstandard2.0.cs b/sdk/storage/Azure.Storage.Files.Shares/api/Azure.Storage.Files.Shares.netstandard2.0.cs index cf8ce32808d81..473ffb67af41f 100644 --- a/sdk/storage/Azure.Storage.Files.Shares/api/Azure.Storage.Files.Shares.netstandard2.0.cs +++ b/sdk/storage/Azure.Storage.Files.Shares/api/Azure.Storage.Files.Shares.netstandard2.0.cs @@ -115,7 +115,7 @@ public ShareClient(System.Uri shareUri, Azure.Storage.StorageSharedKeyCredential } public partial class ShareClientOptions : Azure.Core.ClientOptions { - public ShareClientOptions(Azure.Storage.Files.Shares.ShareClientOptions.ServiceVersion version = Azure.Storage.Files.Shares.ShareClientOptions.ServiceVersion.V2024_11_04) { } + public ShareClientOptions(Azure.Storage.Files.Shares.ShareClientOptions.ServiceVersion version = Azure.Storage.Files.Shares.ShareClientOptions.ServiceVersion.V2025_01_05) { } public bool? AllowSourceTrailingDot { get { throw null; } set { } } public bool? AllowTrailingDot { get { throw null; } set { } } public Azure.Storage.Files.Shares.Models.ShareAudience? 
Audience { get { throw null; } set { } } @@ -808,6 +808,7 @@ public partial class ShareFileDownloadInfo : System.IDisposable { internal ShareFileDownloadInfo() { } public System.IO.Stream Content { get { throw null; } } + public byte[] ContentCrc { get { throw null; } } public byte[] ContentHash { get { throw null; } } public long ContentLength { get { throw null; } } public string ContentType { get { throw null; } } diff --git a/sdk/storage/Azure.Storage.Files.Shares/assets.json b/sdk/storage/Azure.Storage.Files.Shares/assets.json index c2b5c3d31e6a2..c33c8bb335398 100644 --- a/sdk/storage/Azure.Storage.Files.Shares/assets.json +++ b/sdk/storage/Azure.Storage.Files.Shares/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "net", "TagPrefix": "net/storage/Azure.Storage.Files.Shares", - "Tag": "net/storage/Azure.Storage.Files.Shares_df67d82d59" + "Tag": "net/storage/Azure.Storage.Files.Shares_4b545ae555" } diff --git a/sdk/storage/Azure.Storage.Files.Shares/samples/Azure.Storage.Files.Shares.Samples.Tests.csproj b/sdk/storage/Azure.Storage.Files.Shares/samples/Azure.Storage.Files.Shares.Samples.Tests.csproj index 0bcec423c144d..d1efeca0c2da2 100644 --- a/sdk/storage/Azure.Storage.Files.Shares/samples/Azure.Storage.Files.Shares.Samples.Tests.csproj +++ b/sdk/storage/Azure.Storage.Files.Shares/samples/Azure.Storage.Files.Shares.Samples.Tests.csproj @@ -16,6 +16,7 @@ + PreserveNewest diff --git a/sdk/storage/Azure.Storage.Files.Shares/src/Azure.Storage.Files.Shares.csproj b/sdk/storage/Azure.Storage.Files.Shares/src/Azure.Storage.Files.Shares.csproj index 740160b155650..d136154f5d3d4 100644 --- a/sdk/storage/Azure.Storage.Files.Shares/src/Azure.Storage.Files.Shares.csproj +++ b/sdk/storage/Azure.Storage.Files.Shares/src/Azure.Storage.Files.Shares.csproj @@ -1,4 +1,4 @@ - + $(RequiredTargetFrameworks);net6.0 @@ -42,6 +42,7 @@ + @@ -85,6 +86,11 @@ + + + + + diff --git 
a/sdk/storage/Azure.Storage.Files.Shares/src/Models/ShareFileDownloadInfo.cs b/sdk/storage/Azure.Storage.Files.Shares/src/Models/ShareFileDownloadInfo.cs index 0165af94435a0..4037cbdfd875e 100644 --- a/sdk/storage/Azure.Storage.Files.Shares/src/Models/ShareFileDownloadInfo.cs +++ b/sdk/storage/Azure.Storage.Files.Shares/src/Models/ShareFileDownloadInfo.cs @@ -38,6 +38,12 @@ public partial class ShareFileDownloadInfo : IDisposable, IDownloadedContent public byte[] ContentHash { get; internal set; } #pragma warning restore CA1819 // Properties should not return arrays + /// + /// When requested using , this value contains the CRC for the download blob range. + /// This value may only become populated once the network stream is fully consumed. + /// + public byte[] ContentCrc { get; internal set; } + /// /// Details returned when downloading a file /// diff --git a/sdk/storage/Azure.Storage.Files.Shares/src/ShareErrors.cs b/sdk/storage/Azure.Storage.Files.Shares/src/ShareErrors.cs index f776384d06add..0b27510aaa6c4 100644 --- a/sdk/storage/Azure.Storage.Files.Shares/src/ShareErrors.cs +++ b/sdk/storage/Azure.Storage.Files.Shares/src/ShareErrors.cs @@ -17,20 +17,5 @@ public static InvalidOperationException FileOrShareMissing( string fileClient, string shareClient) => new InvalidOperationException($"{leaseClient} requires either a {fileClient} or {shareClient}"); - - public static void AssertAlgorithmSupport(StorageChecksumAlgorithm? algorithm) - { - StorageChecksumAlgorithm resolved = (algorithm ?? StorageChecksumAlgorithm.None).ResolveAuto(); - switch (resolved) - { - case StorageChecksumAlgorithm.None: - case StorageChecksumAlgorithm.MD5: - return; - case StorageChecksumAlgorithm.StorageCrc64: - throw new ArgumentException("Azure File Shares do not support CRC-64."); - default: - throw new ArgumentException($"{nameof(StorageChecksumAlgorithm)} does not support value {Enum.GetName(typeof(StorageChecksumAlgorithm), resolved) ?? 
((int)resolved).ToString(CultureInfo.InvariantCulture)}."); - } - } } } diff --git a/sdk/storage/Azure.Storage.Files.Shares/src/ShareFileClient.cs b/sdk/storage/Azure.Storage.Files.Shares/src/ShareFileClient.cs index f713200a524de..ea3f8554b944d 100644 --- a/sdk/storage/Azure.Storage.Files.Shares/src/ShareFileClient.cs +++ b/sdk/storage/Azure.Storage.Files.Shares/src/ShareFileClient.cs @@ -2397,51 +2397,70 @@ private async Task> DownloadInternal( // Wrap the response Content in a RetriableStream so we // can return it before it's finished downloading, but still // allow retrying if it fails. - initialResponse.Value.Content = RetriableStream.Create( - stream, - startOffset => - { - (Response Response, Stream ContentStream) = StartDownloadAsync( - range, - validationOptions, - conditions, - startOffset, - async, - cancellationToken) - .EnsureCompleted(); - if (etag != Response.GetRawResponse().Headers.ETag) - { - throw new ShareFileModifiedException( - "File has been modified concurrently", - Uri, etag, Response.GetRawResponse().Headers.ETag.GetValueOrDefault(), range); - } - return ContentStream; - }, - async startOffset => + async ValueTask> Factory(long offset, bool async, CancellationToken cancellationToken) + { + (Response response, Stream contentStream) = await StartDownloadAsync( + range, + validationOptions, + conditions, + offset, + async, + cancellationToken).ConfigureAwait(false); + if (etag != response.GetRawResponse().Headers.ETag) { - (Response Response, Stream ContentStream) = await StartDownloadAsync( - range, - validationOptions, - conditions, - startOffset, - async, - cancellationToken) - .ConfigureAwait(false); - if (etag != Response.GetRawResponse().Headers.ETag) + throw new ShareFileModifiedException( + "File has been modified concurrently", + Uri, etag, response.GetRawResponse().Headers.ETag.GetValueOrDefault(), range); + } + return response; + } + async ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.RawDecodedData 
DecodedData)> StructuredMessageFactory( + long offset, bool async, CancellationToken cancellationToken) + { + Response result = await Factory(offset, async, cancellationToken).ConfigureAwait(false); + return StructuredMessageDecodingStream.WrapStream(result.Value.Content, result.Value.ContentLength); + } + + if (initialResponse.GetRawResponse().Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)) + { + (Stream decodingStream, StructuredMessageDecodingStream.RawDecodedData decodedData) = StructuredMessageDecodingStream.WrapStream( + initialResponse.Value.Content, initialResponse.Value.ContentLength); + initialResponse.Value.Content = new StructuredMessageDecodingRetriableStream( + decodingStream, + decodedData, + StructuredMessage.Flags.StorageCrc64, + startOffset => StructuredMessageFactory(startOffset, async: false, cancellationToken) + .EnsureCompleted(), + async startOffset => await StructuredMessageFactory(startOffset, async: true, cancellationToken) + .ConfigureAwait(false), + decodedData => { - throw new ShareFileModifiedException( - "File has been modified concurrently", - Uri, etag, Response.GetRawResponse().Headers.ETag.GetValueOrDefault(), range); - } - return ContentStream; - }, - ClientConfiguration.Pipeline.ResponseClassifier, - Constants.MaxReliabilityRetries); + initialResponse.Value.ContentCrc = new byte[StructuredMessage.Crc64Length]; + decodedData.Crc.WriteCrc64(initialResponse.Value.ContentCrc); + }, + ClientConfiguration.Pipeline.ResponseClassifier, + Constants.MaxReliabilityRetries); + } + else + { + initialResponse.Value.Content = RetriableStream.Create( + initialResponse.Value.Content, + startOffset => Factory(startOffset, async: false, cancellationToken) + .EnsureCompleted().Value.Content, + async startOffset => (await Factory(startOffset, async: true, cancellationToken) + .ConfigureAwait(false)).Value.Content, + ClientConfiguration.Pipeline.ResponseClassifier, + Constants.MaxReliabilityRetries); + } // buffer response 
stream and ensure it matches the transactional hash if any // Storage will not return a hash for payload >4MB, so this buffer is capped similarly // hashing is opt-in, so this buffer is part of that opt-in - if (validationOptions != default && validationOptions.ChecksumAlgorithm != StorageChecksumAlgorithm.None && validationOptions.AutoValidateChecksum) + if (validationOptions != default && + validationOptions.ChecksumAlgorithm != StorageChecksumAlgorithm.None && + validationOptions.AutoValidateChecksum && + // structured message decoding does the validation for us + !initialResponse.GetRawResponse().Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)) { // safe-buffer; transactional hash download limit well below maxInt var readDestStream = new MemoryStream((int)initialResponse.Value.ContentLength); @@ -2524,8 +2543,6 @@ await ContentHasher.AssertResponseHashMatchInternal( bool async = true, CancellationToken cancellationToken = default) { - ShareErrors.AssertAlgorithmSupport(transferValidationOverride?.ChecksumAlgorithm); - // calculation gets illegible with null coalesce; just pre-initialize var pageRange = range; pageRange = new HttpRange( @@ -2535,13 +2552,27 @@ await ContentHasher.AssertResponseHashMatchInternal( (long?)null); ClientConfiguration.Pipeline.LogTrace($"Download {Uri} with range: {pageRange}"); - ResponseWithHeaders response; + bool? rangeGetContentMD5 = null; + string structuredBodyType = null; + switch (transferValidationOverride?.ChecksumAlgorithm.ResolveAuto()) + { + case StorageChecksumAlgorithm.MD5: + rangeGetContentMD5 = true; + break; + case StorageChecksumAlgorithm.StorageCrc64: + structuredBodyType = Constants.StructuredMessage.CrcStructuredMessage; + break; + default: + break; + } + ResponseWithHeaders response; if (async) { response = await FileRestClient.DownloadAsync( range: pageRange == default ? 
null : pageRange.ToString(), - rangeGetContentMD5: transferValidationOverride?.ChecksumAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.MD5 ? true : null, + rangeGetContentMD5: rangeGetContentMD5, + structuredBodyType: structuredBodyType, shareFileRequestConditions: conditions, cancellationToken: cancellationToken) .ConfigureAwait(false); @@ -2550,7 +2581,8 @@ await ContentHasher.AssertResponseHashMatchInternal( { response = FileRestClient.Download( range: pageRange == default ? null : pageRange.ToString(), - rangeGetContentMD5: transferValidationOverride?.ChecksumAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.MD5 ? true : null, + rangeGetContentMD5: rangeGetContentMD5, + structuredBodyType: structuredBodyType, shareFileRequestConditions: conditions, cancellationToken: cancellationToken); } @@ -4630,7 +4662,6 @@ internal async Task> UploadRangeInternal( CancellationToken cancellationToken) { UploadTransferValidationOptions validationOptions = transferValidationOverride ?? ClientConfiguration.TransferValidation.Upload; - ShareErrors.AssertAlgorithmSupport(validationOptions?.ChecksumAlgorithm); using (ClientConfiguration.Pipeline.BeginLoggingScope(nameof(ShareFileClient))) { @@ -4646,14 +4677,38 @@ internal async Task> UploadRangeInternal( scope.Start(); Errors.VerifyStreamPosition(content, nameof(content)); - // compute hash BEFORE attaching progress handler - ContentHasher.GetHashResult hashResult = await ContentHasher.GetHashOrDefaultInternal( - content, - validationOptions, - async, - cancellationToken).ConfigureAwait(false); - - content = content.WithNoDispose().WithProgress(progressHandler); + ContentHasher.GetHashResult hashResult = null; + long contentLength = (content?.Length - content?.Position) ?? 0; + long? 
structuredContentLength = default; + string structuredBodyType = null; + if (validationOptions != null && + validationOptions.ChecksumAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64) + { + // report progress in terms of caller bytes, not encoded bytes + structuredContentLength = contentLength; + contentLength = (content?.Length - content?.Position) ?? 0; + structuredBodyType = Constants.StructuredMessage.CrcStructuredMessage; + content = content.WithNoDispose().WithProgress(progressHandler); + content = validationOptions.PrecalculatedChecksum.IsEmpty + ? new StructuredMessageEncodingStream( + content, + Constants.StructuredMessage.DefaultSegmentContentLength, + StructuredMessage.Flags.StorageCrc64) + : new StructuredMessagePrecalculatedCrcWrapperStream( + content, + validationOptions.PrecalculatedChecksum.Span); + contentLength = (content?.Length - content?.Position) ?? 0; + } + else + { + // compute hash BEFORE attaching progress handler + hashResult = await ContentHasher.GetHashOrDefaultInternal( + content, + validationOptions, + async, + cancellationToken).ConfigureAwait(false); + content = content.WithNoDispose().WithProgress(progressHandler); + } ResponseWithHeaders response; @@ -4666,6 +4721,8 @@ internal async Task> UploadRangeInternal( fileLastWrittenMode: fileLastWrittenMode, optionalbody: content, contentMD5: hashResult?.MD5AsArray, + structuredBodyType: structuredBodyType, + structuredContentLength: structuredContentLength, shareFileRequestConditions: conditions, cancellationToken: cancellationToken) .ConfigureAwait(false); @@ -4679,6 +4736,8 @@ internal async Task> UploadRangeInternal( fileLastWrittenMode: fileLastWrittenMode, optionalbody: content, contentMD5: hashResult?.MD5AsArray, + structuredBodyType: structuredBodyType, + structuredContentLength: structuredContentLength, shareFileRequestConditions: conditions, cancellationToken: cancellationToken); } diff --git a/sdk/storage/Azure.Storage.Files.Shares/src/autorest.md 
b/sdk/storage/Azure.Storage.Files.Shares/src/autorest.md index 2bcc0e37ee65a..ca0e5ae4c9160 100644 --- a/sdk/storage/Azure.Storage.Files.Shares/src/autorest.md +++ b/sdk/storage/Azure.Storage.Files.Shares/src/autorest.md @@ -25,7 +25,7 @@ directive: if (property.includes('/{shareName}/{directory}/{fileName}')) { $[property]["parameters"] = $[property]["parameters"].filter(function(param) { return (typeof param['$ref'] === "undefined") || (false == param['$ref'].endsWith("#/parameters/ShareName") && false == param['$ref'].endsWith("#/parameters/DirectoryPath") && false == param['$ref'].endsWith("#/parameters/FilePath"))}); - } + } else if (property.includes('/{shareName}/{directory}')) { $[property]["parameters"] = $[property]["parameters"].filter(function(param) { return (typeof param['$ref'] === "undefined") || (false == param['$ref'].endsWith("#/parameters/ShareName") && false == param['$ref'].endsWith("#/parameters/DirectoryPath"))}); @@ -46,7 +46,7 @@ directive: $.Metrics.type = "object"; ``` -### Times aren't required +### Times aren't required ``` yaml directive: - from: swagger-document diff --git a/sdk/storage/Azure.Storage.Files.Shares/tests/Azure.Storage.Files.Shares.Tests.csproj b/sdk/storage/Azure.Storage.Files.Shares/tests/Azure.Storage.Files.Shares.Tests.csproj index 398a4b6367489..d09dd8fe8949f 100644 --- a/sdk/storage/Azure.Storage.Files.Shares/tests/Azure.Storage.Files.Shares.Tests.csproj +++ b/sdk/storage/Azure.Storage.Files.Shares/tests/Azure.Storage.Files.Shares.Tests.csproj @@ -17,6 +17,7 @@ + PreserveNewest diff --git a/sdk/storage/Azure.Storage.Files.Shares/tests/ShareFileClientTransferValidationTests.cs b/sdk/storage/Azure.Storage.Files.Shares/tests/ShareFileClientTransferValidationTests.cs index 3dcdb21f27b36..9fd8905e388b1 100644 --- a/sdk/storage/Azure.Storage.Files.Shares/tests/ShareFileClientTransferValidationTests.cs +++ b/sdk/storage/Azure.Storage.Files.Shares/tests/ShareFileClientTransferValidationTests.cs @@ -64,10 +64,6 @@ 
protected override async Task GetResourceClientAsync( private void AssertSupportsHashAlgorithm(StorageChecksumAlgorithm algorithm) { - if (algorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64) - { - TestHelper.AssertInconclusiveRecordingFriendly(Recording.Mode, "Azure File Share does not support CRC64."); - } } protected override async Task UploadPartitionAsync(ShareFileClient client, Stream source, UploadTransferValidationOptions transferValidation) @@ -147,8 +143,44 @@ protected override async Task SetupDataAsync(ShareFileClient client, Stream data public override void TestAutoResolve() { Assert.AreEqual( - StorageChecksumAlgorithm.MD5, + StorageChecksumAlgorithm.StorageCrc64, TransferValidationOptionsExtensions.ResolveAuto(StorageChecksumAlgorithm.Auto)); } + + [Test] + public async Task StructuredMessagePopulatesCrcDownloadStreaming() + { + await using DisposingShare disposingContainer = await ClientBuilder.GetTestShareAsync(); + + const int dataLength = Constants.KB; + byte[] data = GetRandomBuffer(dataLength); + byte[] dataCrc = new byte[8]; + StorageCrc64Calculator.ComputeSlicedSafe(data, 0L).WriteCrc64(dataCrc); + + ShareFileClient file = disposingContainer.Container.GetRootDirectoryClient().GetFileClient(GetNewResourceName()); + await file.CreateAsync(data.Length); + await file.UploadAsync(new MemoryStream(data)); + + Response response = await file.DownloadAsync(new ShareFileDownloadOptions() + { + TransferValidation = new DownloadTransferValidationOptions + { + ChecksumAlgorithm = StorageChecksumAlgorithm.StorageCrc64 + } + }); + + // crc is not present until response stream is consumed + Assert.That(response.Value.ContentCrc, Is.Null); + + byte[] downloadedData; + using (MemoryStream ms = new()) + { + await response.Value.Content.CopyToAsync(ms); + downloadedData = ms.ToArray(); + } + + Assert.That(response.Value.ContentCrc, Is.EqualTo(dataCrc)); + Assert.That(downloadedData, Is.EqualTo(data)); + } } } diff --git 
a/sdk/storage/Azure.Storage.Queues/api/Azure.Storage.Queues.net6.0.cs b/sdk/storage/Azure.Storage.Queues/api/Azure.Storage.Queues.net6.0.cs index 96bc919c7a719..9f440eb3639d7 100644 --- a/sdk/storage/Azure.Storage.Queues/api/Azure.Storage.Queues.net6.0.cs +++ b/sdk/storage/Azure.Storage.Queues/api/Azure.Storage.Queues.net6.0.cs @@ -74,7 +74,7 @@ public QueueClient(System.Uri queueUri, Azure.Storage.StorageSharedKeyCredential } public partial class QueueClientOptions : Azure.Core.ClientOptions { - public QueueClientOptions(Azure.Storage.Queues.QueueClientOptions.ServiceVersion version = Azure.Storage.Queues.QueueClientOptions.ServiceVersion.V2024_11_04) { } + public QueueClientOptions(Azure.Storage.Queues.QueueClientOptions.ServiceVersion version = Azure.Storage.Queues.QueueClientOptions.ServiceVersion.V2025_01_05) { } public Azure.Storage.Queues.Models.QueueAudience? Audience { get { throw null; } set { } } public bool EnableTenantDiscovery { get { throw null; } set { } } public System.Uri GeoRedundantSecondaryUri { get { throw null; } set { } } @@ -426,7 +426,7 @@ public event System.EventHandler + PreserveNewest diff --git a/sdk/storage/Azure.Storage.Queues/tests/Azure.Storage.Queues.Tests.csproj b/sdk/storage/Azure.Storage.Queues/tests/Azure.Storage.Queues.Tests.csproj index e0a6fab3c753b..4d0334255f041 100644 --- a/sdk/storage/Azure.Storage.Queues/tests/Azure.Storage.Queues.Tests.csproj +++ b/sdk/storage/Azure.Storage.Queues/tests/Azure.Storage.Queues.Tests.csproj @@ -21,6 +21,7 @@ +