From 5ede910f9150f05efcab62c3c0a212ffb3e08c5a Mon Sep 17 00:00:00 2001 From: Anne Thompson Date: Tue, 13 Feb 2024 10:13:40 -0800 Subject: [PATCH] Core 2.0 Prototype: Merge "move buffering" and "response.Content" to a single PR (#41907) * initial rework * continued refactor * WIP * nits * fix and add some tests * update * update * fix * fix * fix * fix * Add messages to Debug.Asserts; update test to opt-out of buffering * plumb through network timeout from clientoptions * more test fixes * move template call into Azure.Core base type * fix and address tests that rely on transport not buffering * set NetworkTimeout in Pipeline.Send * add functional tests for buffering * update * refactor * reorg * refactor * nits * nit * rewwork a bit * rewwork a bit * updates * fix * initial move * clientmodel tests * Azure.Core implementation * export API * bug fix * updates * fix internal timeout property issue * bug fix * fix where dispose is called before buffer * bug fix * fix * updates from clientmodel PR * instrumentation for debugging async issue * remove instrumentation * bug fix and remove more instrumentation * updates * more cleanup * update * updates * keep in sync with PR to main * updates * add test for async exception factory --- sdk/core/Azure.Core/api/Azure.Core.net461.cs | 5 +- sdk/core/Azure.Core/api/Azure.Core.net472.cs | 5 +- sdk/core/Azure.Core/api/Azure.Core.net6.0.cs | 5 +- .../api/Azure.Core.netstandard2.0.cs | 5 +- sdk/core/Azure.Core/src/ClientOptions.cs | 2 + sdk/core/Azure.Core/src/HttpMessage.cs | 1 + .../Internal/AzureBaseBuffersExtensions.cs | 80 ++++ .../src/Pipeline/DisposableHttpPipeline.cs | 5 +- .../Pipeline/HttpClientTransport.Response.cs | 34 ++ .../src/Pipeline/HttpClientTransport.cs | 4 +- .../Azure.Core/src/Pipeline/HttpPipeline.cs | 30 +- .../src/Pipeline/HttpPipelineBuilder.cs | 16 +- .../Pipeline/HttpPipelineTransport.Adapter.cs | 49 +++ .../src/Pipeline/HttpPipelineTransport.cs | 24 +- .../src/Pipeline/HttpWebRequestTransport.cs | 2 + .../Internal/HttpPipelineTransportPolicy.cs | 6 +- .../src/Pipeline/Internal/LoggingPolicy.cs | 2 +- .../Pipeline/Internal/ResponseBodyPolicy.cs | 94 ---- sdk/core/Azure.Core/src/Request.cs | 2 + .../Azure.Core/src/RequestFailedException.cs | 55 ++- sdk/core/Azure.Core/src/Response.cs | 120 ++++- sdk/core/Azure.Core/src/RetryOptions.cs | 2 +- .../src/Shared/CancellationHelper.cs | 22 +- .../tests/DisposableHttpPipelineTests.cs | 4 +- sdk/core/Azure.Core/tests/EventSourceTests.cs | 7 +- .../tests/HttpPipelineFunctionalTests.cs | 8 +- .../tests/ResponseBodyPolicyTests.cs | 402 ----------------- .../tests/ResponseBufferingTests.cs | 415 ++++++++++++++++++ .../Azure.Core/tests/RetriableStreamTests.cs | 31 +- .../tests/TransportFunctionalTests.cs | 9 +- .../api/System.ClientModel.net6.0.cs | 11 +- .../api/System.ClientModel.netstandard2.0.cs | 11 +- .../src/Convenience/ClientResultException.cs | 22 +- .../src/Internal/CancellationHelper.cs | 27 +- .../src/Internal/ReadTimeoutStream.cs | 17 +- .../src/Internal/StreamExtensions.cs | 44 ++ .../src/Message/PipelineResponse.cs | 128 +----- .../src/Pipeline/ClientPipeline.cs | 4 - .../HttpClientPipelineTransport.Response.cs | 146 ++++-- .../Pipeline/HttpClientPipelineTransport.cs | 2 +- .../src/Pipeline/PipelineTransport.cs | 155 +++++-- .../src/Pipeline/ResponseBufferingPolicy.cs | 139 ------ .../ClientRequestExceptionTests.cs | 15 + .../Mocks/MockPipelineResponse.cs | 85 +++- .../Mocks/MockPipelineTransport.cs | 13 + .../Mocks/ObservableTransport.cs | 13 + .../tests/client/MapsClient/MapsClient.cs | 35 ++ .../tests/client/MapsClientTests.cs | 12 + .../nullableenabledclient/MapsClientTests.cs | 12 + .../src/TextAnalyticsFailedDetailsParser.cs | 115 +++-- 50 files changed, 1454 insertions(+), 998 deletions(-) create mode 100644 sdk/core/Azure.Core/src/Pipeline/HttpPipelineTransport.Adapter.cs delete mode 100644 sdk/core/Azure.Core/src/Pipeline/Internal/ResponseBodyPolicy.cs delete mode 100644 sdk/core/Azure.Core/tests/ResponseBodyPolicyTests.cs create mode 100644 sdk/core/Azure.Core/tests/ResponseBufferingTests.cs delete mode 100644 sdk/core/System.ClientModel/src/Pipeline/ResponseBufferingPolicy.cs diff --git a/sdk/core/Azure.Core/api/Azure.Core.net461.cs b/sdk/core/Azure.Core/api/Azure.Core.net461.cs index f13ca2da9aee5..0d825ec9eed92 100644 --- a/sdk/core/Azure.Core/api/Azure.Core.net461.cs +++ b/sdk/core/Azure.Core/api/Azure.Core.net461.cs @@ -232,6 +232,7 @@ public partial class RequestFailedException : System.ClientModel.ClientResultExc public RequestFailedException(string message) : base (default(System.ClientModel.Primitives.PipelineResponse), default(System.Exception)) { } public RequestFailedException(string message, System.Exception? innerException) : base (default(System.ClientModel.Primitives.PipelineResponse), default(System.Exception)) { } public string? ErrorCode { get { throw null; } } + public static System.Threading.Tasks.ValueTask CreateAsync(Azure.Response response, Azure.Core.RequestFailedDetailsParser? parser = null, System.Exception? innerException = null) { throw null; } public override void GetObjectData(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { } public new Azure.Response? GetRawResponse() { throw null; } } @@ -239,13 +240,15 @@ public abstract partial class Response : System.ClientModel.Primitives.PipelineR { protected Response() { } public abstract string ClientRequestId { get; set; } - public virtual new System.BinaryData Content { get { throw null; } } + public override System.BinaryData Content { get { throw null; } } public virtual new Azure.Core.ResponseHeaders Headers { get { throw null; } } protected internal abstract bool ContainsHeader(string name); protected internal abstract System.Collections.Generic.IEnumerable EnumerateHeaders(); public static Azure.Response FromValue(T value, Azure.Response response) { throw null; } [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] protected override System.ClientModel.Primitives.PipelineResponseHeaders GetHeadersCore() { throw null; } + public override System.BinaryData ReadContent(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public override System.Threading.Tasks.ValueTask ReadContentAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] protected sealed override void SetIsErrorCore(bool isError) { } public override string ToString() { throw null; } diff --git a/sdk/core/Azure.Core/api/Azure.Core.net472.cs b/sdk/core/Azure.Core/api/Azure.Core.net472.cs index f13ca2da9aee5..0d825ec9eed92 100644 --- a/sdk/core/Azure.Core/api/Azure.Core.net472.cs +++ b/sdk/core/Azure.Core/api/Azure.Core.net472.cs @@ -232,6 +232,7 @@ public partial class RequestFailedException : System.ClientModel.ClientResultExc public RequestFailedException(string message) : base (default(System.ClientModel.Primitives.PipelineResponse), default(System.Exception)) { } public RequestFailedException(string message, System.Exception? innerException) : base (default(System.ClientModel.Primitives.PipelineResponse), default(System.Exception)) { } public string? ErrorCode { get { throw null; } } + public static System.Threading.Tasks.ValueTask CreateAsync(Azure.Response response, Azure.Core.RequestFailedDetailsParser? parser = null, System.Exception? innerException = null) { throw null; } public override void GetObjectData(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { } public new Azure.Response? GetRawResponse() { throw null; } } @@ -239,13 +240,15 @@ public abstract partial class Response : System.ClientModel.Primitives.PipelineR { protected Response() { } public abstract string ClientRequestId { get; set; } - public virtual new System.BinaryData Content { get { throw null; } } + public override System.BinaryData Content { get { throw null; } } public virtual new Azure.Core.ResponseHeaders Headers { get { throw null; } } protected internal abstract bool ContainsHeader(string name); protected internal abstract System.Collections.Generic.IEnumerable EnumerateHeaders(); public static Azure.Response FromValue(T value, Azure.Response response) { throw null; } [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] protected override System.ClientModel.Primitives.PipelineResponseHeaders GetHeadersCore() { throw null; } + public override System.BinaryData ReadContent(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public override System.Threading.Tasks.ValueTask ReadContentAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] protected sealed override void SetIsErrorCore(bool isError) { } public override string ToString() { throw null; } diff --git a/sdk/core/Azure.Core/api/Azure.Core.net6.0.cs b/sdk/core/Azure.Core/api/Azure.Core.net6.0.cs index cde0e1a8e97ee..36b25e3e367bc 100644 --- a/sdk/core/Azure.Core/api/Azure.Core.net6.0.cs +++ b/sdk/core/Azure.Core/api/Azure.Core.net6.0.cs @@ -232,6 +232,7 @@ public partial class RequestFailedException : System.ClientModel.ClientResultExc public RequestFailedException(string message) : base (default(System.ClientModel.Primitives.PipelineResponse), default(System.Exception)) { } public RequestFailedException(string message, System.Exception? innerException) : base (default(System.ClientModel.Primitives.PipelineResponse), default(System.Exception)) { } public string? ErrorCode { get { throw null; } } + public static System.Threading.Tasks.ValueTask CreateAsync(Azure.Response response, Azure.Core.RequestFailedDetailsParser? parser = null, System.Exception? innerException = null) { throw null; } public override void GetObjectData(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { } public new Azure.Response? GetRawResponse() { throw null; } } @@ -239,13 +240,15 @@ public abstract partial class Response : System.ClientModel.Primitives.PipelineR { protected Response() { } public abstract string ClientRequestId { get; set; } - public virtual new System.BinaryData Content { get { throw null; } } + public override System.BinaryData Content { get { throw null; } } public virtual new Azure.Core.ResponseHeaders Headers { get { throw null; } } protected internal abstract bool ContainsHeader(string name); protected internal abstract System.Collections.Generic.IEnumerable EnumerateHeaders(); public static Azure.Response FromValue(T value, Azure.Response response) { throw null; } [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] protected override System.ClientModel.Primitives.PipelineResponseHeaders GetHeadersCore() { throw null; } + public override System.BinaryData ReadContent(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public override System.Threading.Tasks.ValueTask ReadContentAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] protected sealed override void SetIsErrorCore(bool isError) { } public override string ToString() { throw null; } diff --git a/sdk/core/Azure.Core/api/Azure.Core.netstandard2.0.cs b/sdk/core/Azure.Core/api/Azure.Core.netstandard2.0.cs index f13ca2da9aee5..0d825ec9eed92 100644 --- a/sdk/core/Azure.Core/api/Azure.Core.netstandard2.0.cs +++ b/sdk/core/Azure.Core/api/Azure.Core.netstandard2.0.cs @@ -232,6 +232,7 @@ public partial class RequestFailedException : System.ClientModel.ClientResultExc public RequestFailedException(string message) : base (default(System.ClientModel.Primitives.PipelineResponse), default(System.Exception)) { } public RequestFailedException(string message, System.Exception? innerException) : base (default(System.ClientModel.Primitives.PipelineResponse), default(System.Exception)) { } public string? ErrorCode { get { throw null; } } + public static System.Threading.Tasks.ValueTask CreateAsync(Azure.Response response, Azure.Core.RequestFailedDetailsParser? parser = null, System.Exception? innerException = null) { throw null; } public override void GetObjectData(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { } public new Azure.Response? GetRawResponse() { throw null; } } @@ -239,13 +240,15 @@ public abstract partial class Response : System.ClientModel.Primitives.PipelineR { protected Response() { } public abstract string ClientRequestId { get; set; } - public virtual new System.BinaryData Content { get { throw null; } } + public override System.BinaryData Content { get { throw null; } } public virtual new Azure.Core.ResponseHeaders Headers { get { throw null; } } protected internal abstract bool ContainsHeader(string name); protected internal abstract System.Collections.Generic.IEnumerable EnumerateHeaders(); public static Azure.Response FromValue(T value, Azure.Response response) { throw null; } [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] protected override System.ClientModel.Primitives.PipelineResponseHeaders GetHeadersCore() { throw null; } + public override System.BinaryData ReadContent(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public override System.Threading.Tasks.ValueTask ReadContentAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] protected sealed override void SetIsErrorCore(bool isError) { } public override string ToString() { throw null; } diff --git a/sdk/core/Azure.Core/src/ClientOptions.cs b/sdk/core/Azure.Core/src/ClientOptions.cs index 5dc1067de00fe..03cf891b98a48 100644 --- a/sdk/core/Azure.Core/src/ClientOptions.cs +++ b/sdk/core/Azure.Core/src/ClientOptions.cs @@ -14,6 +14,8 @@ namespace Azure.Core /// public abstract class ClientOptions : ClientPipelineOptions { + internal static readonly TimeSpan DefaultNetworkTimeout = TimeSpan.FromSeconds(100); + private HttpPipelineTransport _transport; internal bool IsCustomTransportSet { get; private set; } diff --git a/sdk/core/Azure.Core/src/HttpMessage.cs b/sdk/core/Azure.Core/src/HttpMessage.cs index 3fdf81972837c..f5f7428dc87ae 100644 --- a/sdk/core/Azure.Core/src/HttpMessage.cs +++ b/sdk/core/Azure.Core/src/HttpMessage.cs @@ -26,6 +26,7 @@ public HttpMessage(Request request, ResponseClassifier responseClassifier) Argument.AssertNotNull(request, nameof(request)); ResponseClassifier = responseClassifier; + NetworkTimeout = request.NetworkTimeout ?? ClientOptions.DefaultNetworkTimeout; } /// diff --git a/sdk/core/Azure.Core/src/Internal/AzureBaseBuffersExtensions.cs b/sdk/core/Azure.Core/src/Internal/AzureBaseBuffersExtensions.cs index 50087757232ca..9d1410e65b11a 100644 --- a/sdk/core/Azure.Core/src/Internal/AzureBaseBuffersExtensions.cs +++ b/sdk/core/Azure.Core/src/Internal/AzureBaseBuffersExtensions.cs @@ -13,6 +13,9 @@ namespace Azure.Core.Buffers { internal static class AzureBaseBuffersExtensions { + // Same value as Stream.CopyTo uses by default + private const int DefaultCopyBufferSize = 81920; + public static async Task WriteAsync(this Stream stream, ReadOnlyMemory buffer, CancellationToken cancellation = default) { Argument.AssertNotNull(stream, nameof(stream)); @@ -87,5 +90,82 @@ public static async Task WriteAsync(this Stream stream, ReadOnlySequence b ArrayPool.Shared.Return(array); } } + + public static async Task CopyToAsync(this Stream source, Stream destination, CancellationToken cancellationToken) + { + //using CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + //cts.CancelAfter(timeout); + + //// If cancellation is possible (whether due to network timeout or a user cancellation token being passed), then + //// register callback to dispose the stream on cancellation. + //if (timeout != Timeout.InfiniteTimeSpan || cancellationToken.CanBeCanceled) + //{ + // cts.Token.Register(state => ((Stream?)state)?.Dispose(), source); + //} + + byte[] buffer = ArrayPool.Shared.Rent(DefaultCopyBufferSize); + + try + { + while (true) + { +#pragma warning disable CA1835 // ReadAsync(Memory<>) overload is not available in all targets + int bytesRead = await source.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); +#pragma warning restore // ReadAsync(Memory<>) overload is not available in all targets + if (bytesRead == 0) + break; + await destination.WriteAsync(new ReadOnlyMemory(buffer, 0, bytesRead), cancellationToken).ConfigureAwait(false); + } + } + //catch (Exception ex) when (ex is ObjectDisposedException + // or IOException + // or OperationCanceledException + // or NotSupportedException) + //{ + // CancellationHelper.ThrowIfCancellationRequestedOrTimeout(cancellationToken, cts.Token, ex, timeout); + // throw; + //} + finally + { + ArrayPool.Shared.Return(buffer); + } + } + + public static void CopyTo(this Stream source, Stream destination, CancellationToken cancellationToken) + { + //using CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + //cts.CancelAfter(timeout); + + //// If cancellation is possible (whether due to network timeout or a user cancellation token being passed), then + //// register callback to dispose the stream on cancellation. + //if (timeout != Timeout.InfiniteTimeSpan || cancellationToken.CanBeCanceled) + //{ + // cts.Token.Register(state => ((Stream?)state)?.Dispose(), source); + //} + + byte[] buffer = ArrayPool.Shared.Rent(DefaultCopyBufferSize); + + try + { + int read; + while ((read = source.Read(buffer, 0, buffer.Length)) != 0) + { + cancellationToken.ThrowIfCancellationRequested(); + destination.Write(buffer, 0, read); + } + } + //catch (Exception ex) when (ex is ObjectDisposedException + // or IOException + // or OperationCanceledException + // or NotSupportedException) + //{ + // CancellationHelper.ThrowIfCancellationRequestedOrTimeout(cancellationToken, cts.Token, ex, timeout); + // throw; + //} + finally + { + ArrayPool.Shared.Return(buffer); + } + } } } diff --git a/sdk/core/Azure.Core/src/Pipeline/DisposableHttpPipeline.cs b/sdk/core/Azure.Core/src/Pipeline/DisposableHttpPipeline.cs index e21aebdc4abc1..d18ffa921ff1f 100644 --- a/sdk/core/Azure.Core/src/Pipeline/DisposableHttpPipeline.cs +++ b/sdk/core/Azure.Core/src/Pipeline/DisposableHttpPipeline.cs @@ -24,8 +24,9 @@ public sealed class DisposableHttpPipeline : HttpPipeline, IDisposable /// Policies to be invoked as part of the pipeline in order. /// The response classifier to be used in invocations. /// - internal DisposableHttpPipeline(HttpPipelineTransport transport, int perCallIndex, int perRetryIndex, HttpPipelinePolicy[] policies, ResponseClassifier responseClassifier, bool isTransportOwnedInternally) - : base(transport, perCallIndex, perRetryIndex, policies, responseClassifier) + /// + internal DisposableHttpPipeline(HttpPipelineTransport transport, int perCallIndex, int perRetryIndex, HttpPipelinePolicy[] policies, ResponseClassifier responseClassifier, bool isTransportOwnedInternally, TimeSpan networkTimeout) + : base(transport, perCallIndex, perRetryIndex, policies, responseClassifier, networkTimeout) { this.isTransportOwnedInternally = isTransportOwnedInternally; } diff --git a/sdk/core/Azure.Core/src/Pipeline/HttpClientTransport.Response.cs b/sdk/core/Azure.Core/src/Pipeline/HttpClientTransport.Response.cs index 604df526b3e39..337a87892436a 100644 --- a/sdk/core/Azure.Core/src/Pipeline/HttpClientTransport.Response.cs +++ b/sdk/core/Azure.Core/src/Pipeline/HttpClientTransport.Response.cs @@ -1,11 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; using System.ClientModel.Primitives; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.IO; using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; namespace Azure.Core.Pipeline { @@ -41,6 +44,21 @@ public override Stream? ContentStream set => _pipelineResponse.ContentStream = value; } + public override BinaryData Content + { + get + { + ResetContentStreamPosition(_pipelineResponse); + return _pipelineResponse.Content; + } + } + + public override BinaryData ReadContent(CancellationToken cancellationToken = default) + => _pipelineResponse.ReadContent(cancellationToken); + + public override async ValueTask ReadContentAsync(CancellationToken cancellationToken = default) + => await base.ReadContentAsync(cancellationToken).ConfigureAwait(false); + protected internal override bool ContainsHeader(string name) => _pipelineResponse.Headers.TryGetValue(name, out _); @@ -61,8 +79,24 @@ protected internal override bool TryGetHeaderValues(string name, [NotNullWhen(tr public override void Dispose() { PipelineResponse response = _pipelineResponse; + ResetContentStreamPosition(response); response?.Dispose(); } + + private void ResetContentStreamPosition(PipelineResponse response) + { + if (response.ContentStream is MemoryStream stream && stream.Position != 0) + { + // Azure.Core Response has a contract that ContentStream can be read + // without setting position back to 0. This means if ReadContent is + // called after such a read, the buffer will contain empty BinaryData. + + // So that the ClientModel response implementations don't throw, + // set the position back to 0 if Azure.Core Response default + // ReadContent was called. + stream.Position = 0; + } + } } } } diff --git a/sdk/core/Azure.Core/src/Pipeline/HttpClientTransport.cs b/sdk/core/Azure.Core/src/Pipeline/HttpClientTransport.cs index 4bb6797727257..18c1bc678b329 100644 --- a/sdk/core/Azure.Core/src/Pipeline/HttpClientTransport.cs +++ b/sdk/core/Azure.Core/src/Pipeline/HttpClientTransport.cs @@ -99,11 +99,11 @@ public override async ValueTask ProcessAsync(HttpMessage message) { if (message.HasResponse) { - throw new RequestFailedException(message.Response, e.InnerException); + throw await RequestFailedException.CreateAsync(message.Response, innerException: e.InnerException).ConfigureAwait(false); } else { - throw new RequestFailedException(e.Message, e.InnerException); + throw new RequestFailedException(e.Message, innerException: e.InnerException); } } } diff --git a/sdk/core/Azure.Core/src/Pipeline/HttpPipeline.cs b/sdk/core/Azure.Core/src/Pipeline/HttpPipeline.cs index 221270684e8a3..61bb9850a61f5 100644 --- a/sdk/core/Azure.Core/src/Pipeline/HttpPipeline.cs +++ b/sdk/core/Azure.Core/src/Pipeline/HttpPipeline.cs @@ -40,6 +40,8 @@ public class HttpPipeline /// private readonly int _perRetryIndex; + private readonly TimeSpan _networkTimeout; + /// /// Creates a new instance of with the provided transport, policies and response classifier. /// @@ -59,6 +61,7 @@ public HttpPipeline(HttpPipelineTransport transport, HttpPipelinePolicy[]? polic policies.CopyTo(all, 0); _pipeline = all; + _networkTimeout = ClientOptions.DefaultNetworkTimeout; } internal HttpPipeline( @@ -66,7 +69,8 @@ internal HttpPipeline( int perCallIndex, int perRetryIndex, HttpPipelinePolicy[] pipeline, - ResponseClassifier responseClassifier) + ResponseClassifier responseClassifier, + TimeSpan? networkTimeout) { ResponseClassifier = responseClassifier ?? throw new ArgumentNullException(nameof(responseClassifier)); @@ -77,6 +81,9 @@ internal HttpPipeline( _perCallIndex = perCallIndex; _perRetryIndex = perRetryIndex; + + _networkTimeout = networkTimeout ?? ClientOptions.DefaultNetworkTimeout; + _internallyConstructed = true; } @@ -85,14 +92,22 @@ internal HttpPipeline( /// /// The request. public Request CreateRequest() - => _transport.CreateRequest(); + { + Request request = _transport.CreateRequest(); + request.NetworkTimeout = _networkTimeout; + return request; + } /// /// Creates a new instance. /// /// The message. public HttpMessage CreateMessage() - => new(CreateRequest(), ResponseClassifier); + { + Request request = CreateRequest(); + HttpMessage message = new(request, ResponseClassifier); + return message; + } /// /// @@ -109,7 +124,8 @@ public HttpMessage CreateMessage(RequestContext? context) /// The message. public HttpMessage CreateMessage(RequestContext? context, ResponseClassifier? classifier = default) { - HttpMessage message = new(CreateRequest(), classifier ?? ResponseClassifier); + Request request = CreateRequest(); + HttpMessage message = new(request, classifier ?? ResponseClassifier); if (context != null) { @@ -134,6 +150,11 @@ public ValueTask SendAsync(HttpMessage message, CancellationToken cancellationTo { message.SetCancellationToken(cancellationToken); message.ProcessingStartTime = DateTimeOffset.UtcNow; + + // We must set NetworkTimeout here because the documentation for + // HttpMessage states that if a user sets this value to null, the + // pipeline will use the value set on ClientOptions. + message.NetworkTimeout ??= _networkTimeout; AddHttpMessageProperties(message); if (message.Policies == null || message.Policies.Count == 0) @@ -169,6 +190,7 @@ public void Send(HttpMessage message, CancellationToken cancellationToken) { message.SetCancellationToken(cancellationToken); message.ProcessingStartTime = DateTimeOffset.UtcNow; + message.NetworkTimeout ??= _networkTimeout; AddHttpMessageProperties(message); if (message.Policies == null || message.Policies.Count == 0) diff --git a/sdk/core/Azure.Core/src/Pipeline/HttpPipelineBuilder.cs b/sdk/core/Azure.Core/src/Pipeline/HttpPipelineBuilder.cs index e84ee002ae01c..d218be0c7b82e 100644 --- a/sdk/core/Azure.Core/src/Pipeline/HttpPipelineBuilder.cs +++ b/sdk/core/Azure.Core/src/Pipeline/HttpPipelineBuilder.cs @@ -43,7 +43,7 @@ public static HttpPipeline Build( ((List)pipelineOptions.PerRetryPolicies).AddRange(perRetryPolicies); var result = BuildInternal(pipelineOptions, null); - return new HttpPipeline(result.Transport, result.PerCallIndex, result.PerRetryIndex, result.Policies, result.Classifier); + return new HttpPipeline(result.Transport, result.PerCallIndex, result.PerRetryIndex, result.Policies, result.Classifier, result.NetworkTimeout); } /// @@ -63,7 +63,7 @@ public static DisposableHttpPipeline Build(ClientOptions options, HttpPipelinePo ((List)pipelineOptions.PerCallPolicies).AddRange(perCallPolicies); ((List)pipelineOptions.PerRetryPolicies).AddRange(perRetryPolicies); var result = BuildInternal(pipelineOptions, transportOptions); - return new DisposableHttpPipeline(result.Transport, result.PerCallIndex, result.PerRetryIndex, result.Policies, result.Classifier, result.IsTransportOwned); + return new DisposableHttpPipeline(result.Transport, result.PerCallIndex, result.PerRetryIndex, result.Policies, result.Classifier, result.IsTransportOwned, result.NetworkTimeout); } /// @@ -74,7 +74,7 @@ public static DisposableHttpPipeline Build(ClientOptions options, HttpPipelinePo public static HttpPipeline Build(HttpPipelineOptions options) { var result = BuildInternal(options, null); - return new HttpPipeline(result.Transport, result.PerCallIndex, result.PerRetryIndex, result.Policies, result.Classifier); + return new HttpPipeline(result.Transport, result.PerCallIndex, result.PerRetryIndex, result.Policies, result.Classifier, result.NetworkTimeout); } /// @@ -87,17 +87,17 @@ public static DisposableHttpPipeline Build(HttpPipelineOptions options, HttpPipe { Argument.AssertNotNull(transportOptions, nameof(transportOptions)); var result = BuildInternal(options, transportOptions); - return new DisposableHttpPipeline(result.Transport, result.PerCallIndex, result.PerRetryIndex, result.Policies, result.Classifier, result.IsTransportOwned); + return new DisposableHttpPipeline(result.Transport, result.PerCallIndex, result.PerRetryIndex, result.Policies, result.Classifier, result.IsTransportOwned, result.NetworkTimeout); } - internal static (ResponseClassifier Classifier, HttpPipelineTransport Transport, int PerCallIndex, int PerRetryIndex, HttpPipelinePolicy[] Policies, bool IsTransportOwned) BuildInternal( + internal static (ResponseClassifier Classifier, HttpPipelineTransport Transport, int PerCallIndex, int PerRetryIndex, HttpPipelinePolicy[] Policies, bool IsTransportOwned, TimeSpan NetworkTimeout) BuildInternal( HttpPipelineOptions buildOptions, HttpPipelineTransportOptions? defaultTransportOptions) { Argument.AssertNotNull(buildOptions.PerCallPolicies, nameof(buildOptions.PerCallPolicies)); Argument.AssertNotNull(buildOptions.PerRetryPolicies, nameof(buildOptions.PerRetryPolicies)); - var policies = new List(8 + + var policies = new List(7 + (buildOptions.ClientOptions.Policies?.Count ?? 0) + buildOptions.PerCallPolicies.Count + buildOptions.PerRetryPolicies.Count); @@ -181,8 +181,6 @@ void AddNonNullPolicies(HttpPipelinePolicy[] policiesToAdd) policies.Add(new LoggingPolicy(diagnostics.IsLoggingContentEnabled, diagnostics.LoggedContentSizeLimit, sanitizer, assemblyName)); } - policies.Add(new ResponseBodyPolicy(buildOptions.ClientOptions.Retry.NetworkTimeout)); - policies.Add(new RequestActivityPolicy(isDistributedTracingEnabled, ClientDiagnostics.GetResourceProviderNamespace(buildOptions.ClientOptions.GetType().Assembly), sanitizer)); AddUserPolicies(HttpPipelinePosition.BeforeTransport); @@ -209,7 +207,7 @@ void AddNonNullPolicies(HttpPipelinePolicy[] policiesToAdd) buildOptions.ResponseClassifier ??= ResponseClassifier.Shared; - return (buildOptions.ResponseClassifier, transport, perCallIndex, perRetryIndex, policies.ToArray(), isTransportInternallyCreated); + return (buildOptions.ResponseClassifier, transport, perCallIndex, perRetryIndex, policies.ToArray(), isTransportInternallyCreated, buildOptions.ClientOptions.Retry.NetworkTimeout); } // internal for testing diff --git a/sdk/core/Azure.Core/src/Pipeline/HttpPipelineTransport.Adapter.cs b/sdk/core/Azure.Core/src/Pipeline/HttpPipelineTransport.Adapter.cs new file mode 100644 index 0000000000000..3a2a1ff5a4d18 --- /dev/null +++ b/sdk/core/Azure.Core/src/Pipeline/HttpPipelineTransport.Adapter.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel.Primitives; +using System.Threading.Tasks; + +namespace Azure.Core.Pipeline; + +public partial class HttpPipelineTransport +{ + private class AzureCorePipelineTransport : PipelineTransport + { + private readonly HttpPipelineTransport _transport; + + public AzureCorePipelineTransport(HttpPipelineTransport transport) + { + _transport = transport; + } + + protected override PipelineMessage CreateMessageCore() + { + Request request = _transport.CreateRequest(); + return new HttpMessage(request, ResponseClassifier.Shared); + } + + protected override void ProcessCore(PipelineMessage message) + { + HttpMessage httpMessage = AssertHttpMessage(message); + _transport.Process(httpMessage); + } + + protected override async ValueTask ProcessCoreAsync(PipelineMessage message) + { + HttpMessage httpMessage = AssertHttpMessage(message); + await _transport.ProcessAsync(httpMessage).ConfigureAwait(false); + } + + private static HttpMessage AssertHttpMessage(PipelineMessage message) + { + if (message is not HttpMessage httpMessage) + { + throw new InvalidOperationException($"Invalid type for PipelineMessage: '{message?.GetType()}'."); + } + + return httpMessage; + } + } +} diff --git a/sdk/core/Azure.Core/src/Pipeline/HttpPipelineTransport.cs b/sdk/core/Azure.Core/src/Pipeline/HttpPipelineTransport.cs index a9daa7f450077..c372a13d072ac 100644 --- a/sdk/core/Azure.Core/src/Pipeline/HttpPipelineTransport.cs +++ b/sdk/core/Azure.Core/src/Pipeline/HttpPipelineTransport.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -using System.Threading; +using System.ClientModel.Primitives; using System.Threading.Tasks; namespace Azure.Core.Pipeline @@ -9,8 +9,28 @@ namespace Azure.Core.Pipeline /// /// Represents an HTTP pipeline transport used to send HTTP requests and receive responses. /// - public abstract class HttpPipelineTransport + public abstract partial class HttpPipelineTransport { + private readonly PipelineTransport _transport; + + /// + /// TBD. + /// + protected HttpPipelineTransport() + { + _transport = new AzureCorePipelineTransport(this); + } + + internal void ProcessInternal(HttpMessage message) + { + _transport.Process(message); + } + + internal async ValueTask ProcessInternalAsync(HttpMessage message) + { + await _transport.ProcessAsync(message).ConfigureAwait(false); + } + /// /// Sends the request contained by the and sets the property to received response synchronously. /// diff --git a/sdk/core/Azure.Core/src/Pipeline/HttpWebRequestTransport.cs b/sdk/core/Azure.Core/src/Pipeline/HttpWebRequestTransport.cs index f3153bbe9011d..77b0501fc8fbb 100644 --- a/sdk/core/Azure.Core/src/Pipeline/HttpWebRequestTransport.cs +++ b/sdk/core/Azure.Core/src/Pipeline/HttpWebRequestTransport.cs @@ -328,6 +328,8 @@ public override Stream? ContentStream } } + // TODO: Implement Content and ReadContent + public override string ClientRequestId { get; set; } public override void Dispose() diff --git a/sdk/core/Azure.Core/src/Pipeline/Internal/HttpPipelineTransportPolicy.cs b/sdk/core/Azure.Core/src/Pipeline/Internal/HttpPipelineTransportPolicy.cs index 49f5dfce21f94..e1d3ecad8bf65 100644 --- a/sdk/core/Azure.Core/src/Pipeline/Internal/HttpPipelineTransportPolicy.cs +++ b/sdk/core/Azure.Core/src/Pipeline/Internal/HttpPipelineTransportPolicy.cs @@ -24,10 +24,12 @@ public override async ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory { Debug.Assert(pipeline.IsEmpty); - await _transport.ProcessAsync(message).ConfigureAwait(false); + await _transport.ProcessInternalAsync(message).ConfigureAwait(false); message.Response.RequestFailedDetailsParser = _errorParser; message.Response.Sanitizer = _sanitizer; + + // TODO: I think we can remove the call to below? message.Response.SetIsError(message.ResponseClassifier.IsErrorResponse(message)); } @@ -35,7 +37,7 @@ public override void Process(HttpMessage message, ReadOnlyMemory - /// Pipeline policy to buffer response content or add a timeout to response content managed by the client - /// - internal class ResponseBodyPolicy : HttpPipelinePolicy - { - private readonly ResponseBufferingPolicy _policy; - private readonly TimeSpan _networkTimeout; - - public ResponseBodyPolicy(TimeSpan networkTimeout) - { - _policy = new ResponseBufferingPolicy(); - _networkTimeout = networkTimeout; - } - - public override async ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory pipeline) - => await ProcessSyncOrAsync(message, pipeline, async: true).ConfigureAwait(false); - - public override void Process(HttpMessage message, ReadOnlyMemory pipeline) - => ProcessSyncOrAsync(message, pipeline, async: false).EnsureCompleted(); - - private async ValueTask ProcessSyncOrAsync(HttpMessage message, ReadOnlyMemory pipeline, bool async) - { - AzureCorePipelineProcessor processor = new(pipeline); - - // Get the network timeout for this particular invocation of the pipeline. - // We either use the default that the policy was constructed with at - // pipeline-creation time, or we get an override value from the message that - // we use for the duration of this invocation only. - message.NetworkTimeout ??= _networkTimeout; - TimeSpan invocationNetworkTimeout = (TimeSpan)message.NetworkTimeout!; - - try - { - if (async) - { - await _policy.ProcessAsync(message, processor, -1).ConfigureAwait(false); - } - else - { - _policy.Process(message, processor, -1); - } - } - catch (TaskCanceledException e) - { - if (e.Message.Contains("The operation was cancelled because it exceeded the configured timeout")) - { - string exceptionMessage = e.Message + - $"Network timeout can be adjusted in {nameof(ClientOptions)}.{nameof(ClientOptions.Retry)}.{nameof(RetryOptions.NetworkTimeout)}."; -#if NETCOREAPP2_1_OR_GREATER - throw new TaskCanceledException(exceptionMessage, e.InnerException, e.CancellationToken); -#else - throw new TaskCanceledException(exceptionMessage, e.InnerException); -#endif - } - else - { - throw e; - } - } - } - - /// Throws a cancellation exception if cancellation has been requested via or . - /// The customer provided token. - /// The linked token that is cancelled on timeout provided token. - /// The inner exception to use. - /// The timeout used for the operation. -#pragma warning disable CA1068 // Cancellation token has to be the last parameter - internal static void ThrowIfCancellationRequestedOrTimeout(CancellationToken originalToken, CancellationToken timeoutToken, Exception? inner, TimeSpan timeout) -#pragma warning restore CA1068 - { - CancellationHelper.ThrowIfCancellationRequested(originalToken); - - if (timeoutToken.IsCancellationRequested) - { - throw CancellationHelper.CreateOperationCanceledException( - inner, - timeoutToken, - $"The operation was cancelled because it exceeded the configured timeout of {timeout:g}. " + - $"Network timeout can be adjusted in {nameof(ClientOptions)}.{nameof(ClientOptions.Retry)}.{nameof(RetryOptions.NetworkTimeout)}."); - } - } - } -} diff --git a/sdk/core/Azure.Core/src/Request.cs b/sdk/core/Azure.Core/src/Request.cs index 1da30c23ee7fd..d71dde70c4eec 100644 --- a/sdk/core/Azure.Core/src/Request.cs +++ b/sdk/core/Azure.Core/src/Request.cs @@ -73,6 +73,8 @@ public virtual string ClientRequestId /// public new RequestHeaders Headers => new(this); + internal TimeSpan? NetworkTimeout { get; set; } + #region Overrides for "Core" methods from the PipelineRequest Template pattern /// diff --git a/sdk/core/Azure.Core/src/RequestFailedException.cs b/sdk/core/Azure.Core/src/RequestFailedException.cs index 2b93d44047f66..0b179002e162b 100644 --- a/sdk/core/Azure.Core/src/RequestFailedException.cs +++ b/sdk/core/Azure.Core/src/RequestFailedException.cs @@ -6,9 +6,9 @@ using System.Collections.Generic; using System.ComponentModel; using System.Globalization; -using System.IO; using System.Runtime.Serialization; using System.Text; +using System.Threading.Tasks; using Azure.Core; using Azure.Core.Pipeline; @@ -22,6 +22,19 @@ public class RequestFailedException : ClientResultException, ISerializable { private const string DefaultMessage = "Service request failed."; + /// + /// TBD. + /// + /// + /// + /// + /// + public static async ValueTask CreateAsync(Response response, RequestFailedDetailsParser? parser = default, Exception? innerException = default) + { + ErrorDetails details = await CreateExceptionDetailsAsync(response, parser).ConfigureAwait(false); + return new RequestFailedException(response, details, innerException); + } + /// /// Gets the service specific error code if available. Please refer to the client documentation for the list of supported error codes. /// @@ -52,7 +65,7 @@ public RequestFailedException(Response response, Exception? innerException) /// An inner exception to associate with the new . /// The parser to use to parse the response content. public RequestFailedException(Response response, Exception? innerException, RequestFailedDetailsParser? detailsParser) - : this(response, CreateRequestFailedExceptionContent(response, detailsParser), innerException) + : this(response, CreateExceptionDetails(response, detailsParser), innerException) { } @@ -148,9 +161,22 @@ public override void GetObjectData(SerializationInfo info, StreamingContext cont /// public new Response? GetRawResponse() => (Response?)base.GetRawResponse(); - private static ErrorDetails CreateRequestFailedExceptionContent(Response response, RequestFailedDetailsParser? parser) + private static ErrorDetails CreateExceptionDetails(Response response, RequestFailedDetailsParser? parser) + => CreateExceptionDetailsSyncOrAsync(response, parser, async: false).EnsureCompleted(); + + private static async ValueTask CreateExceptionDetailsAsync(Response response, RequestFailedDetailsParser? parser) + => await CreateExceptionDetailsSyncOrAsync(response, parser, async: true).ConfigureAwait(false); + + private static async ValueTask CreateExceptionDetailsSyncOrAsync(Response response, RequestFailedDetailsParser? parser, bool async) { - BufferResponseIfNeeded(response); + if (async) + { + await response.ReadContentAsync().ConfigureAwait(false); + } + else + { + response.ReadContent(); + } parser ??= response.RequestFailedDetailsParser; @@ -202,7 +228,7 @@ private static ErrorDetails CreateRequestFailedExceptionContent(Response respons } } - if (response.ContentStream is MemoryStream && ContentTypeUtilities.TryGetTextEncoding(response.Headers.ContentType, out Encoding _)) + if (ContentTypeUtilities.TryGetTextEncoding(response.Headers.ContentType, out Encoding _)) { messageBuilder .AppendLine() @@ -224,25 +250,6 @@ private static ErrorDetails CreateRequestFailedExceptionContent(Response respons return new ErrorDetails(messageBuilder.ToString(), error?.Code, additionalInfo); } - private static void BufferResponseIfNeeded(Response response) - { - // Buffer into a memory stream if not already buffered - if (response.ContentStream is null or MemoryStream) - { - return; - } - - var bufferedStream = new MemoryStream(); - response.ContentStream.CopyTo(bufferedStream); - - // Dispose the unbuffered stream - response.ContentStream.Dispose(); - - // Reset the position of the buffered stream and set it on the response - bufferedStream.Position = 0; - response.ContentStream = bufferedStream; - } - // This class needs to be internal rather than private so that it can be used // by the System.Text.Json source generator. internal class ErrorResponse diff --git a/sdk/core/Azure.Core/src/Response.cs b/sdk/core/Azure.Core/src/Response.cs index 915394de89393..93500999122d2 100644 --- a/sdk/core/Azure.Core/src/Response.cs +++ b/sdk/core/Azure.Core/src/Response.cs @@ -5,9 +5,13 @@ using System.ClientModel.Primitives; using System.Collections.Generic; using System.ComponentModel; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.IO; +using System.Threading; +using System.Threading.Tasks; using Azure.Core; +using Azure.Core.Buffers; namespace Azure { @@ -18,6 +22,9 @@ namespace Azure public abstract class Response : PipelineResponse #pragma warning restore AZC0012 // Avoid single word type names { + // TODO(matell): The .NET Framework team plans to add BinaryData.Empty in dotnet/runtime#49670, and we can use it then. + private static readonly BinaryData s_EmptyBinaryData = new(Array.Empty()); + /// /// Gets the client request id that was sent to the server as x-ms-client-request-id headers. /// @@ -28,14 +35,6 @@ public abstract class Response : PipelineResponse /// public new virtual ResponseHeaders Headers => new ResponseHeaders(this); - /// - /// Gets the contents of HTTP response, if it is available. - /// - /// - /// Throws when is not a . - /// - public new virtual BinaryData Content => base.Content; - /// /// TBD. /// @@ -48,6 +47,25 @@ protected override PipelineResponseHeaders GetHeadersCore() throw new NotImplementedException(); } + /// + /// Gets the contents of HTTP response, if it is available. + /// + /// + /// Throws when content is not buffered. + /// + public override BinaryData Content + { + get + { + if (ContentStream is null || ContentStream is MemoryStream) + { + return ReadContent(); + } + + throw new InvalidOperationException($"The response is not buffered."); + } + } + internal HttpMessageSanitizer Sanitizer { get; set; } = HttpMessageSanitizer.Default; internal RequestFailedDetailsParser? RequestFailedDetailsParser { get; set; } @@ -124,6 +142,82 @@ internal static void DisposeStreamIfNotBuffered(ref Stream? stream) } } + /// + /// TBD. + /// + /// + /// + /// + public override BinaryData ReadContent(CancellationToken cancellationToken = default) + { + // Derived types should provide an implementation that allows caching + // to improve performance. + if (ContentStream is null) + { + return s_EmptyBinaryData; + } + + if (ContentStream is MemoryStream memoryStream) + { + return memoryStream.TryGetBuffer(out ArraySegment segment) ? + new BinaryData(segment.AsMemory()) : + new BinaryData(memoryStream.ToArray()); + } + + BufferedContentStream bufferStream = new(); + + Stream? contentStream = ContentStream; + contentStream.CopyTo(bufferStream, cancellationToken); + contentStream.Dispose(); + + bufferStream.Position = 0; + ContentStream = bufferStream; + + BinaryData content = BinaryData.FromStream(bufferStream); + bufferStream.Position = 0; + + return content; + } + + /// + /// TBD. + /// + /// + /// + /// + public override async ValueTask ReadContentAsync(CancellationToken cancellationToken = default) + { + // Derived types should provide an implementation that allows caching + // to improve performance. + if (ContentStream is null) + { + return s_EmptyBinaryData; + } + + if (ContentStream is MemoryStream memoryStream) + { + return memoryStream.TryGetBuffer(out ArraySegment segment) ? + new BinaryData(segment.AsMemory()) : + new BinaryData(memoryStream.ToArray()); + } + + BufferedContentStream bufferStream = new(); + + Stream? contentStream = ContentStream; + await contentStream.CopyToAsync(bufferStream, cancellationToken).ConfigureAwait(false); + contentStream.Dispose(); + + bufferStream.Position = 0; + ContentStream = bufferStream; + + BinaryData content = BinaryData.FromStream(bufferStream); + bufferStream.Position = 0; + + return content; + } + + private class BufferedContentStream : MemoryStream { } + #region Private implementation subtypes of abstract Response types private class AzureCoreResponse : Response { @@ -180,6 +274,16 @@ protected internal override bool TryGetHeaderValues(string name, [NotNullWhen(tr { throw new NotSupportedException(DefaultMessage); } + + public override BinaryData ReadContent(CancellationToken cancellationToken = default) + { + throw new NotSupportedException(DefaultMessage); + } + + public override ValueTask ReadContentAsync(CancellationToken cancellationToken = default) + { + throw new NotSupportedException(DefaultMessage); + } } #endregion } diff --git a/sdk/core/Azure.Core/src/RetryOptions.cs b/sdk/core/Azure.Core/src/RetryOptions.cs index f3f84dcac5887..d056f884ef068 100644 --- a/sdk/core/Azure.Core/src/RetryOptions.cs +++ b/sdk/core/Azure.Core/src/RetryOptions.cs @@ -20,7 +20,7 @@ public class RetryOptions private TimeSpan _delay = DefaultInitialDelay; private TimeSpan _maxDelay = DefaultMaxDelay; private RetryMode _retryMode = RetryMode.Exponential; - private TimeSpan _networkTimeout = TimeSpan.FromSeconds(100); + private TimeSpan _networkTimeout = ClientOptions.DefaultNetworkTimeout; private bool _frozen; diff --git a/sdk/core/Azure.Core/src/Shared/CancellationHelper.cs b/sdk/core/Azure.Core/src/Shared/CancellationHelper.cs index 73773b9a3fcda..43019e0c5274a 100644 --- a/sdk/core/Azure.Core/src/Shared/CancellationHelper.cs +++ b/sdk/core/Azure.Core/src/Shared/CancellationHelper.cs @@ -53,5 +53,25 @@ internal static void ThrowIfCancellationRequested(CancellationToken cancellation ThrowOperationCanceledException(innerException: null, cancellationToken); } } + + /// Throws a cancellation exception if cancellation has been requested via or . + /// The customer provided token. + /// The linked token that is cancelled on timeout provided token. + /// The inner exception to use. + /// The timeout used for the operation. +#pragma warning disable CA1068 // Cancellation token has to be the last parameter + internal static void ThrowIfCancellationRequestedOrTimeout(CancellationToken cancellationToken, CancellationToken timeoutToken, Exception? innerException, TimeSpan timeout) +#pragma warning restore CA1068 + { + ThrowIfCancellationRequested(cancellationToken); + + if (timeoutToken.IsCancellationRequested) + { + throw CreateOperationCanceledException( + innerException, + timeoutToken, + $"The operation was cancelled because it exceeded the configured timeout of {timeout:g}. "); + } + } } -} \ No newline at end of file +} diff --git a/sdk/core/Azure.Core/tests/DisposableHttpPipelineTests.cs b/sdk/core/Azure.Core/tests/DisposableHttpPipelineTests.cs index b81913b71edd0..41ce8fd4239d6 100644 --- a/sdk/core/Azure.Core/tests/DisposableHttpPipelineTests.cs +++ b/sdk/core/Azure.Core/tests/DisposableHttpPipelineTests.cs @@ -15,7 +15,7 @@ public class DisposableHttpPipelineTests public void DisposeWithDisposableTransport([Values(true, false)] bool isOwned) { var transport = new MockDisposableHttpPipelineTransport(); - var target = new DisposableHttpPipeline(transport, 0, 0, new[] { new MockPolicy(transport, HttpMessageSanitizer.Default) }, ResponseClassifier.Shared, isOwned); + var target = new DisposableHttpPipeline(transport, 0, 0, new[] { new MockPolicy(transport, HttpMessageSanitizer.Default) }, ResponseClassifier.Shared, isOwned, ClientOptions.DefaultNetworkTimeout); target.Dispose(); Assert.AreEqual(isOwned, transport.DisposeCalled); @@ -25,7 +25,7 @@ public void DisposeWithDisposableTransport([Values(true, false)] bool isOwned) public void DisposeWithoutDisposableTransport([Values(true, false)] bool isOwned) { var transport = new MockHttpPipelineTransport(); - var target = new DisposableHttpPipeline(transport, 0, 0, new[] { new MockPolicy(transport, HttpMessageSanitizer.Default) }, ResponseClassifier.Shared, isOwned); + var target = new DisposableHttpPipeline(transport, 0, 0, new[] { new MockPolicy(transport, HttpMessageSanitizer.Default) }, ResponseClassifier.Shared, isOwned, ClientOptions.DefaultNetworkTimeout); target.Dispose(); } diff --git a/sdk/core/Azure.Core/tests/EventSourceTests.cs b/sdk/core/Azure.Core/tests/EventSourceTests.cs index 6f7df6e3a2c95..7e8911b35eb1b 100644 --- a/sdk/core/Azure.Core/tests/EventSourceTests.cs +++ b/sdk/core/Azure.Core/tests/EventSourceTests.cs @@ -654,7 +654,12 @@ private async Task SendRequest(bool isSeekable, bool isError, Action(async () => await ExecuteRequest(message, httpPipeline)); - Assert.AreEqual("The operation was cancelled because it exceeded the configured timeout of 0:00:00.5. " + - "Network timeout can be adjusted in ClientOptions.Retry.NetworkTimeout.", exception.Message); + Assert.AreEqual("The operation was cancelled because it exceeded the configured timeout of 0:00:00.5. ", + exception.Message); testDoneTcs.Cancel(); } @@ -530,8 +530,8 @@ public void TimeoutsBodyBuffering() message.BufferResponse = true; var exception = Assert.ThrowsAsync(async () => await ExecuteRequest(message, httpPipeline)); - Assert.AreEqual("The operation was cancelled because it exceeded the configured timeout of 0:00:00.5. " + - "Network timeout can be adjusted in ClientOptions.Retry.NetworkTimeout.", exception.Message); + Assert.AreEqual("The operation was cancelled because it exceeded the configured timeout of 0:00:00.5. ", + exception.Message); testDoneTcs.Cancel(); } diff --git a/sdk/core/Azure.Core/tests/ResponseBodyPolicyTests.cs b/sdk/core/Azure.Core/tests/ResponseBodyPolicyTests.cs deleted file mode 100644 index 26fbc80452100..0000000000000 --- a/sdk/core/Azure.Core/tests/ResponseBodyPolicyTests.cs +++ /dev/null @@ -1,402 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.Collections.Generic; -using System.IO; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; -using Azure.Core.Pipeline; -using Azure.Core.TestFramework; -using NUnit.Framework; -using NUnit.Framework.Internal; - -namespace Azure.Core.Tests -{ - public class ResponseBodyPolicyTests : SyncAsyncPolicyTestBase - { - private static HttpPipelinePolicy NoTimeoutPolicy = new ResponseBodyPolicy(Timeout.InfiniteTimeSpan); - - private static HttpPipelinePolicy TimeoutPolicy = new ResponseBodyPolicy(TimeSpan.FromMilliseconds(50)); - - public ResponseBodyPolicyTests(bool isAsync) : base(isAsync) { } - - [Test] - public async Task ReadsEntireBodyIntoMemoryStream() - { - MockResponse mockResponse = new MockResponse(200); - var readTrackingStream = new ReadTrackingStream(128, int.MaxValue); - mockResponse.ContentStream = readTrackingStream; - - MockTransport mockTransport = CreateMockTransport(mockResponse); - Response response = await SendGetRequest(mockTransport, NoTimeoutPolicy); - - Assert.IsInstanceOf(response.ContentStream); - var ms = (MemoryStream)response.ContentStream; - - Assert.AreEqual(128, ms.Length); - foreach (var b in ms.ToArray()) - { - Assert.AreEqual(ReadTrackingStream.ContentByteValue, b); - } - Assert.AreEqual(128, readTrackingStream.BytesRead); - Assert.AreEqual(0, ms.Position); - } - - [Test] - public void SurfacesStreamReadingExceptions() - { - MockResponse mockResponse = new MockResponse(200) - { - ContentStream = new ReadTrackingStream(128, 64) - }; - - MockTransport mockTransport = CreateMockTransport(mockResponse); - Assert.ThrowsAsync(async () => await SendGetRequest(mockTransport, NoTimeoutPolicy)); - } - - [Test] - public async Task SkipsResponsesWithoutContent() - { - MockResponse mockResponse = new MockResponse(200); - - MockTransport mockTransport = CreateMockTransport(mockResponse); - Response response = await SendGetRequest(mockTransport, NoTimeoutPolicy); - Assert.Null(response.ContentStream); - } - - [Test] - public async Task ClosesStreamAfterCopying() - { - ReadTrackingStream readTrackingStream = new ReadTrackingStream(128, int.MaxValue); - MockResponse mockResponse = new MockResponse(200) - { - ContentStream = readTrackingStream - }; - - MockTransport mockTransport = CreateMockTransport(mockResponse); - await SendGetRequest(mockTransport, NoTimeoutPolicy); - - Assert.True(readTrackingStream.IsClosed); - } - - [Test] - public async Task DoesntBufferWhenDisabled() - { - ReadTrackingStream readTrackingStream = new ReadTrackingStream(128, int.MaxValue); - MockResponse mockResponse = new MockResponse(200) - { - ContentStream = readTrackingStream - }; - - MockTransport mockTransport = CreateMockTransport(mockResponse); - Response response = await SendGetRequest(mockTransport, NoTimeoutPolicy, bufferResponse: false); - - Assert.IsNotInstanceOf(response.ContentStream); - } - - [Test] - public async Task WrapsNonBufferedStreamsWithTimeoutStream() - { - var hangingStream = new HangingReadStream(); - - MockResponse mockResponse = new MockResponse(200) - { - ContentStream = hangingStream - }; - - MockTransport mockTransport = new MockTransport(mockResponse); - Response response = await SendGetRequest(mockTransport, TimeoutPolicy, bufferResponse: false); - - var buffer = new byte[100]; - Assert.ThrowsAsync(async () => await response.ContentStream.ReadAsync(buffer, 0, 100)); - Assert.AreEqual(50, hangingStream.ReadTimeout); - } - - [Test] - public async Task WrapsNonBufferedStreamsWithTimeoutStreamCopyToAsync() - { - var hangingStream = new HangingReadStream(); - - MockResponse mockResponse = new MockResponse(200) - { - ContentStream = hangingStream - }; - - MockTransport mockTransport = new MockTransport(mockResponse); - Response response = await SendGetRequest(mockTransport, TimeoutPolicy, bufferResponse: false); - - var memoryStream = new MemoryStream(); - Assert.ThrowsAsync(async () => await response.ContentStream.CopyToAsync(memoryStream)); - Assert.AreEqual(50, hangingStream.ReadTimeout); - } - - [Test] - public async Task SetsReadTimeoutToProvidedValue() - { - var hangingStream = new HangingReadStream(); - - MockResponse mockResponse = new MockResponse(200) - { - ContentStream = hangingStream - }; - - MockTransport mockTransport = new MockTransport(mockResponse); - Response response = await SendGetRequest(mockTransport, new ResponseBodyPolicy(TimeSpan.FromMilliseconds(1234567)), bufferResponse: false); - - //Assert.IsInstanceOf(response.ContentStream); - Assert.IsFalse(response.ContentStream.CanWrite); - Assert.AreEqual(1234567, hangingStream.ReadTimeout); - } - - [Test] - public async Task BufferingRespectsCancellationToken() - { - var slowReadStream = new SlowReadStream(); - - MockResponse mockResponse = new MockResponse(200) - { - ContentStream = slowReadStream - }; - - MockTransport mockTransport = CreateMockTransport(mockResponse); - CancellationTokenSource cts = new CancellationTokenSource(100); - - Task getRequestTask = Task.Run(async () => await SendGetRequest(mockTransport, NoTimeoutPolicy, bufferResponse: true, cancellationToken: cts.Token)); - - await slowReadStream.StartedReader.Task; - - cts.Cancel(); - - Assert.That(async () => await getRequestTask, Throws.InstanceOf()); - } - - [Test] - public void CanOverrideDefaultNetworkTimeout() - { - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - MockTransport mockTransport = MockTransport.FromMessageCallback(message => - { - tcs.Task.Wait(message.CancellationToken); - return null; - }); - - var exception = Assert.ThrowsAsync(async () => await SendRequestAsync(mockTransport, message => - { - message.NetworkTimeout = TimeSpan.FromMilliseconds(30); - }, new ResponseBodyPolicy(TimeSpan.MaxValue), bufferResponse: false)); - Assert.AreEqual("The operation was cancelled because it exceeded the configured timeout of 0:00:00.03. " + - "Network timeout can be adjusted in ClientOptions.Retry.NetworkTimeout.", exception.Message); - } - - [Test] - public async Task CanOverrideDefaultNetworkTimeout_Stream() - { - var hangingStream = new HangingReadStream(); - - MockResponse mockResponse = new MockResponse(200) - { - ContentStream = hangingStream - }; - - MockTransport mockTransport = new MockTransport(mockResponse); - Response response = await SendRequestAsync(mockTransport, message => - { - message.NetworkTimeout = TimeSpan.FromMilliseconds(30); - }, new ResponseBodyPolicy(TimeSpan.MaxValue), bufferResponse: false); - - //Assert.IsInstanceOf(response.ContentStream); - Assert.IsFalse(response.ContentStream.CanWrite); - Assert.AreEqual(30, hangingStream.ReadTimeout); - } - - private static IEnumerable GetExceptionCases() - { - yield return new object[] { new IOException(), TimeoutPolicy}; - yield return new object[] { new IOException(), NoTimeoutPolicy}; - yield return new object[] { new ObjectDisposedException("test"), TimeoutPolicy}; - yield return new object[] { new ObjectDisposedException("test"), NoTimeoutPolicy}; - yield return new object[] { new OperationCanceledException(), TimeoutPolicy}; - yield return new object[] { new OperationCanceledException(), NoTimeoutPolicy}; - yield return new object[] { new NotSupportedException(), TimeoutPolicy}; - yield return new object[] { new NotSupportedException(), NoTimeoutPolicy}; - } - - [TestCaseSource(nameof(GetExceptionCases))] - public void ExceptionsTranslatedCorrectlyWhenCanceled(Exception exception, HttpPipelinePolicy policy) - { - var cts = new CancellationTokenSource(); - var stream = new CancelingStream(cts, exception); - MockResponse mockResponse = new MockResponse(200) - { - ContentStream = stream - }; - - MockTransport mockTransport = CreateMockTransport(mockResponse); - Assert.ThrowsAsync(async () => await SendGetRequest(mockTransport, policy, cancellationToken: cts.Token)); - Assert.IsTrue(stream.IsClosed); - } - - [TestCaseSource(nameof(GetExceptionCases))] - public void ExceptionsNotTranslatedWhenNotCanceled(Exception exception, HttpPipelinePolicy policy) - { - var cts = new CancellationTokenSource(); - var stream = new CancelingStream(cts, exception, false); - MockResponse mockResponse = new MockResponse(200) - { - ContentStream = stream - }; - - MockTransport mockTransport = CreateMockTransport(mockResponse); - var thrown = Assert.CatchAsync(async () => await SendGetRequest(mockTransport, policy, cancellationToken: cts.Token)); - Assert.AreSame(exception, thrown); - Assert.IsFalse(stream.IsClosed); - } - - private class SlowReadStream : TestReadStream - { - public readonly TaskCompletionSource StartedReader = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - StartedReader.TrySetResult(null); - await Task.Delay(20, cancellationToken); - return 10; - } - - public override int Read(byte[] buffer, int offset, int count) - { - StartedReader.TrySetResult(null); - Thread.Sleep(20); - return 10; - } - } - - private class HangingReadStream : TestReadStream - { - public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - await Task.Delay(Timeout.Infinite, cancellationToken); - return 0; - } - - public override int Read(byte[] buffer, int offset, int count) - { - throw new NotImplementedException(); - } - - public override int ReadTimeout { get; set; } - - public override bool CanTimeout { get; } = true; - } - - private class ReadTrackingStream : TestReadStream - { - public const int ContentByteValue = 233; - - private readonly int _size; - - private readonly int _throwAfter; - - public ReadTrackingStream(int size, int throwAfter) - { - _size = size; - _throwAfter = throwAfter; - } - - public int BytesRead { get; set; } - - public override int Read(byte[] buffer, int offset, int count) - { - if (BytesRead == _size) - { - return 0; - } - - int left = Math.Min(count, _size); - Span span = buffer.AsSpan(offset, left); - - for (int i = 0; i < span.Length; i++) - { - span[i] = ContentByteValue; - } - - BytesRead += left; - - if (BytesRead > _throwAfter) - { - throw new IOException(); - } - - return left; - } - - public override void Close() - { - IsClosed = true; - base.Close(); - } - - public bool IsClosed { get; set; } - } - - private class CancelingStream : TestReadStream - { - private readonly Exception _exceptionToThrow; - private readonly CancellationTokenSource _cancellationTokenSource; - private readonly bool _cancel; - - public CancelingStream(CancellationTokenSource cancellationTokenSource, Exception exceptionToThrow, bool cancel = true) - { - _exceptionToThrow = exceptionToThrow; - _cancellationTokenSource = cancellationTokenSource; - _cancel = cancel; - } - - public override int Read(byte[] buffer, int offset, int count) - { - if (_cancel) - _cancellationTokenSource.Cancel(); - throw _exceptionToThrow; - } - - public override void Close() - { - IsClosed = true; - base.Close(); - } - - public bool IsClosed { get; set; } - } - - private abstract class TestReadStream: Stream - { - public override bool CanRead { get; } = true; - public override bool CanSeek { get; } - public override bool CanWrite { get; } - public override long Length { get; } - public override long Position { get; set; } - - public override void Flush() - { - throw new System.NotImplementedException(); - } - - public override long Seek(long offset, SeekOrigin origin) - { - throw new System.NotImplementedException(); - } - - public override void SetLength(long value) - { - throw new System.NotImplementedException(); - } - - public override void Write(byte[] buffer, int offset, int count) - { - throw new System.NotImplementedException(); - } - } - } -} diff --git a/sdk/core/Azure.Core/tests/ResponseBufferingTests.cs b/sdk/core/Azure.Core/tests/ResponseBufferingTests.cs new file mode 100644 index 0000000000000..b0954f01cb836 --- /dev/null +++ b/sdk/core/Azure.Core/tests/ResponseBufferingTests.cs @@ -0,0 +1,415 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.Pipeline; +using Azure.Core.TestFramework; +using NUnit.Framework; +using NUnit.Framework.Internal; + +namespace Azure.Core.Tests; + +public class ResponseBufferingTests : SyncAsyncPolicyTestBase +{ + public ResponseBufferingTests(bool isAsync) : base(isAsync) + { + } + + [Test] + public async Task ReadsEntireBodyIntoMemoryStream() + { + MockResponse mockResponse = new MockResponse(200); + var readTrackingStream = new ReadTrackingStream(128, int.MaxValue); + mockResponse.ContentStream = readTrackingStream; + + MockTransport mockTransport = CreateMockTransport(mockResponse); + Response response = await SendGetRequestAsync(mockTransport, Timeout.InfiniteTimeSpan); + + Assert.IsInstanceOf(response.ContentStream); + var ms = (MemoryStream)response.ContentStream; + + Assert.AreEqual(128, ms.Length); + foreach (var b in ms.ToArray()) + { + Assert.AreEqual(ReadTrackingStream.ContentByteValue, b); + } + Assert.AreEqual(128, readTrackingStream.BytesRead); + Assert.AreEqual(0, ms.Position); + } + + [Test] + public void SurfacesStreamReadingExceptions() + { + MockResponse mockResponse = new MockResponse(200) + { + ContentStream = new ReadTrackingStream(128, 64) + }; + + MockTransport mockTransport = CreateMockTransport(mockResponse); + Assert.ThrowsAsync(async () => await SendGetRequestAsync(mockTransport, Timeout.InfiniteTimeSpan)); + } + + [Test] + public async Task SkipsResponsesWithoutContent() + { + MockResponse mockResponse = new MockResponse(200); + + MockTransport mockTransport = CreateMockTransport(mockResponse); + Response response = await SendGetRequestAsync(mockTransport, Timeout.InfiniteTimeSpan); + Assert.Null(response.ContentStream); + } + + [Test] + public async Task ClosesStreamAfterCopying() + { + ReadTrackingStream readTrackingStream = new ReadTrackingStream(128, int.MaxValue); + MockResponse mockResponse = new MockResponse(200) + { + ContentStream = readTrackingStream + }; + + MockTransport mockTransport = CreateMockTransport(mockResponse); + await SendGetRequestAsync(mockTransport, Timeout.InfiniteTimeSpan); + + Assert.True(readTrackingStream.IsClosed); + } + + [Test] + public async Task DoesntBufferWhenDisabled() + { + ReadTrackingStream readTrackingStream = new ReadTrackingStream(128, int.MaxValue); + MockResponse mockResponse = new MockResponse(200) + { + ContentStream = readTrackingStream + }; + + MockTransport mockTransport = CreateMockTransport(mockResponse); + Response response = await SendGetRequestAsync(mockTransport, Timeout.InfiniteTimeSpan, bufferResponse: false); + + Assert.IsNotInstanceOf(response.ContentStream); + } + + [Test] + public async Task WrapsNonBufferedStreamsWithTimeoutStream() + { + var hangingStream = new HangingReadStream(); + + MockResponse mockResponse = new MockResponse(200) + { + ContentStream = hangingStream + }; + + MockTransport mockTransport = new MockTransport(mockResponse); + Response response = await SendGetRequestAsync(mockTransport, TimeSpan.FromMilliseconds(50), bufferResponse: false); + + var buffer = new byte[100]; + Assert.ThrowsAsync(async () => await response.ContentStream.ReadAsync(buffer, 0, 100)); + Assert.AreEqual(50, hangingStream.ReadTimeout); + } + + [Test] + public async Task WrapsNonBufferedStreamsWithTimeoutStreamCopyToAsync() + { + var hangingStream = new HangingReadStream(); + + MockResponse mockResponse = new MockResponse(200) + { + ContentStream = hangingStream + }; + + MockTransport mockTransport = new MockTransport(mockResponse); + Response response = await SendGetRequestAsync(mockTransport, TimeSpan.FromMilliseconds(50), bufferResponse: false); + + var memoryStream = new MemoryStream(); + Assert.ThrowsAsync(async () => await response.ContentStream.CopyToAsync(memoryStream)); + Assert.AreEqual(50, hangingStream.ReadTimeout); + } + + [Test] + public async Task SetsReadTimeoutToProvidedValue() + { + var hangingStream = new HangingReadStream(); + + MockResponse mockResponse = new MockResponse(200) + { + ContentStream = hangingStream + }; + + MockTransport mockTransport = new MockTransport(mockResponse); + Response response = await SendGetRequestAsync(mockTransport, TimeSpan.FromMilliseconds(1234567), bufferResponse: false); + + //Assert.IsInstanceOf(response.ContentStream); + Assert.IsFalse(response.ContentStream.CanWrite); + Assert.AreEqual(1234567, hangingStream.ReadTimeout); + } + + [Test] + public async Task BufferingRespectsCancellationToken() + { + var slowReadStream = new SlowReadStream(); + + MockResponse mockResponse = new MockResponse(200) + { + ContentStream = slowReadStream + }; + + MockTransport mockTransport = CreateMockTransport(mockResponse); + CancellationTokenSource cts = new CancellationTokenSource(100); + + Task getRequestTask = Task.Run(async () => await SendGetRequestAsync(mockTransport, Timeout.InfiniteTimeSpan, bufferResponse: true, cancellationToken: cts.Token)); + + await slowReadStream.StartedReader.Task; + + cts.Cancel(); + + Assert.That(async () => await getRequestTask, Throws.InstanceOf()); + } + + [Test] + public void CanOverrideDefaultNetworkTimeout() + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + MockTransport mockTransport = MockTransport.FromMessageCallback(message => + { + tcs.Task.Wait(message.CancellationToken); + return null; + }); + + var exception = Assert.ThrowsAsync(async () => await SendGetRequestAsync(mockTransport, TimeSpan.FromMilliseconds(30), bufferResponse: false)); + Assert.AreEqual("The operation was cancelled because it exceeded the configured timeout of 0:00:00.03. ", + exception.Message); + } + + private static IEnumerable GetExceptionCases() + { + yield return new object[] { new IOException(), TimeSpan.FromMilliseconds(50) }; + yield return new object[] { new IOException(), Timeout.InfiniteTimeSpan }; + yield return new object[] { new ObjectDisposedException("test"), TimeSpan.FromMilliseconds(50) }; + yield return new object[] { new ObjectDisposedException("test"), Timeout.InfiniteTimeSpan }; + yield return new object[] { new OperationCanceledException(), TimeSpan.FromMilliseconds(50) }; + yield return new object[] { new OperationCanceledException(), Timeout.InfiniteTimeSpan }; + yield return new object[] { new NotSupportedException(), TimeSpan.FromMilliseconds(50) }; + yield return new object[] { new NotSupportedException(), Timeout.InfiniteTimeSpan }; + } + + [TestCaseSource(nameof(GetExceptionCases))] + public void ExceptionsTranslatedCorrectlyWhenCanceled(Exception exception, TimeSpan timeout) + { + var cts = new CancellationTokenSource(); + var stream = new CancelingStream(cts, exception); + MockResponse mockResponse = new MockResponse(200) + { + ContentStream = stream + }; + + MockTransport mockTransport = CreateMockTransport(mockResponse); + Assert.ThrowsAsync(async () => await SendGetRequestAsync(mockTransport, timeout, cancellationToken: cts.Token)); + Assert.IsTrue(stream.IsClosed); + } + + [TestCaseSource(nameof(GetExceptionCases))] + public void ExceptionsNotTranslatedWhenNotCanceled(Exception exception, TimeSpan timeout) + { + var cts = new CancellationTokenSource(); + var stream = new CancelingStream(cts, exception, false); + MockResponse mockResponse = new MockResponse(200) + { + ContentStream = stream + }; + + MockTransport mockTransport = CreateMockTransport(mockResponse); + var thrown = Assert.CatchAsync(async () => await SendGetRequestAsync(mockTransport, timeout, cancellationToken: cts.Token)); + Assert.AreSame(exception, thrown); + Assert.IsFalse(stream.IsClosed); + } + + [Test] + public async Task CanOverrideDefaultNetworkTimeout_Stream() + { + var hangingStream = new HangingReadStream(); + + MockResponse mockResponse = new MockResponse(200) + { + ContentStream = hangingStream + }; + + MockTransport mockTransport = new MockTransport(mockResponse); + Response response = await SendGetRequestAsync(mockTransport, TimeSpan.FromMilliseconds(30), bufferResponse: false); + + //Assert.IsInstanceOf(response.ContentStream); + Assert.IsFalse(response.ContentStream.CanWrite); + Assert.AreEqual(30, hangingStream.ReadTimeout); + } + + #region Helpers + + protected async Task SendGetRequestAsync(HttpPipelineTransport transport, TimeSpan networkTimeout, bool bufferResponse = true, CancellationToken cancellationToken = default) + { + HttpPipeline pipeline = new(transport); + HttpMessage message = pipeline.CreateMessage(); + message.NetworkTimeout = networkTimeout; + message.BufferResponse = bufferResponse; + + if (IsAsync) + { + await pipeline.SendAsync(message, cancellationToken).ConfigureAwait(false); + } + else + { + pipeline.Send(message, cancellationToken); + } + + return message.Response; + } + + private class SlowReadStream : TestReadStream + { + public readonly TaskCompletionSource StartedReader = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + StartedReader.TrySetResult(null); + await Task.Delay(20, cancellationToken); + return 10; + } + + public override int Read(byte[] buffer, int offset, int count) + { + StartedReader.TrySetResult(null); + Thread.Sleep(20); + return 10; + } + } + + private class HangingReadStream : TestReadStream + { + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + await Task.Delay(Timeout.Infinite, cancellationToken); + return 0; + } + + public override int Read(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + + public override int ReadTimeout { get; set; } + + public override bool CanTimeout { get; } = true; + } + + private class ReadTrackingStream : TestReadStream + { + public const int ContentByteValue = 233; + + private readonly int _size; + + private readonly int _throwAfter; + + public ReadTrackingStream(int size, int throwAfter) + { + _size = size; + _throwAfter = throwAfter; + } + + public int BytesRead { get; set; } + + public override int Read(byte[] buffer, int offset, int count) + { + if (BytesRead == _size) + { + return 0; + } + + int left = Math.Min(count, _size); + Span span = buffer.AsSpan(offset, left); + + for (int i = 0; i < span.Length; i++) + { + span[i] = ContentByteValue; + } + + BytesRead += left; + + if (BytesRead > _throwAfter) + { + throw new IOException(); + } + + return left; + } + + public override void Close() + { + IsClosed = true; + base.Close(); + } + + public bool IsClosed { get; set; } + } + + private class CancelingStream : TestReadStream + { + private readonly Exception _exceptionToThrow; + private readonly CancellationTokenSource _cancellationTokenSource; + private readonly bool _cancel; + + public CancelingStream(CancellationTokenSource cancellationTokenSource, Exception exceptionToThrow, bool cancel = true) + { + _exceptionToThrow = exceptionToThrow; + _cancellationTokenSource = cancellationTokenSource; + _cancel = cancel; + } + + public override int Read(byte[] buffer, int offset, int count) + { + if (_cancel) + _cancellationTokenSource.Cancel(); + throw _exceptionToThrow; + } + + public override void Close() + { + IsClosed = true; + base.Close(); + } + + public bool IsClosed { get; set; } + } + + private abstract class TestReadStream : Stream + { + public override bool CanRead { get; } = true; + public override bool CanSeek { get; } + public override bool CanWrite { get; } + public override long Length { get; } + public override long Position { get; set; } + + public override void Flush() + { + throw new System.NotImplementedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new System.NotImplementedException(); + } + + public override void SetLength(long value) + { + throw new System.NotImplementedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new System.NotImplementedException(); + } + } + + #endregion +} diff --git a/sdk/core/Azure.Core/tests/RetriableStreamTests.cs b/sdk/core/Azure.Core/tests/RetriableStreamTests.cs index 4e8e1721641f9..386e5a956bb80 100644 --- a/sdk/core/Azure.Core/tests/RetriableStreamTests.cs +++ b/sdk/core/Azure.Core/tests/RetriableStreamTests.cs @@ -370,27 +370,42 @@ private Task ReadAsync(Stream stream, byte[] buffer, int offset, int length private static Stream SendTestRequest(HttpPipeline pipeline, long offset) { - using Request request = CreateRequest(pipeline, offset); + using HttpMessage message = CreateMessage(pipeline, offset); - Response response = pipeline.SendRequest(request, CancellationToken.None); - return response.ContentStream; + pipeline.Send(message, CancellationToken.None); + Response response = message.Response; + Stream stream = message.ExtractResponseContent(); + + return stream; } private static async ValueTask SendTestRequestAsync(HttpPipeline pipeline, long offset) { - using Request request = CreateRequest(pipeline, offset); + using HttpMessage message = CreateMessage(pipeline, offset); + + await pipeline.SendAsync(message, CancellationToken.None); + Response response = message.Response; + Stream stream = message.ExtractResponseContent(); - Response response = await pipeline.SendRequestAsync(request, CancellationToken.None); - return response.ContentStream; + return stream; } - private static Request CreateRequest(HttpPipeline pipeline, long offset) + private static HttpMessage CreateMessage(HttpPipeline pipeline, long offset) { Request request = pipeline.CreateRequest(); request.Method = RequestMethod.Get; request.Uri.Reset(new Uri("https://example.com")); request.Headers.Add("Range", "bytes=" + offset); - return request; + HttpMessage message = new(request, ResponseClassifier.Shared); + + // RetriableStream is only used in clients where streaming APIs + // return the network stream to the end-user. RetriableStream lets + // us do this in a way that if a request fails, it can be retried + // according to the retry logic configured for the client's pipeline. + // As such, when it is used clients must set message.BufferResponse + // to false, so we do this in the validation tests as well. + message.BufferResponse = false; + return message; } private class NoLengthStream : ReadOnlyStream diff --git a/sdk/core/Azure.Core/tests/TransportFunctionalTests.cs b/sdk/core/Azure.Core/tests/TransportFunctionalTests.cs index 6411ae3ba22e3..2eda351fd96e0 100644 --- a/sdk/core/Azure.Core/tests/TransportFunctionalTests.cs +++ b/sdk/core/Azure.Core/tests/TransportFunctionalTests.cs @@ -987,7 +987,14 @@ public async Task StreamReadingExceptionsAreIOExceptions() var transport = GetTransport(); Request request = transport.CreateRequest(); request.Uri.Reset(testServer.Address); - Response response = await ExecuteRequest(request, transport); + HttpMessage messsage = new(request, ResponseClassifier.Shared); + + // This test is explicitly testing the behavior of a response that + // holds a live network stream, so we set BufferResponse to false. + messsage.BufferResponse = false; + + await ProcessAsync(messsage, transport); + Response response = messsage.Response; tcs.SetResult(null); diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs index 3af3e84a6abd2..a5442ec9e4e88 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs @@ -30,6 +30,7 @@ public ClientResultException(System.ClientModel.Primitives.PipelineResponse resp protected ClientResultException(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { } public ClientResultException(string message, System.ClientModel.Primitives.PipelineResponse? response = null, System.Exception? innerException = null) { } public int Status { get { throw null; } protected set { } } + public static System.Threading.Tasks.Task CreateAsync(System.ClientModel.Primitives.PipelineResponse response, System.Exception? innerException = null) { throw null; } public override void GetObjectData(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { } public System.ClientModel.Primitives.PipelineResponse? GetRawResponse() { throw null; } } @@ -205,7 +206,7 @@ protected PipelineRequestHeaders() { } public abstract partial class PipelineResponse : System.IDisposable { protected PipelineResponse() { } - public virtual System.BinaryData Content { get { throw null; } } + public abstract System.BinaryData Content { get; } public abstract System.IO.Stream? ContentStream { get; set; } public System.ClientModel.Primitives.PipelineResponseHeaders Headers { get { throw null; } } public virtual bool IsError { get { throw null; } } @@ -213,6 +214,8 @@ protected PipelineResponse() { } public abstract int Status { get; } public abstract void Dispose(); protected abstract System.ClientModel.Primitives.PipelineResponseHeaders GetHeadersCore(); + public abstract System.BinaryData ReadContent(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)); + public abstract System.Threading.Tasks.ValueTask ReadContentAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)); protected virtual void SetIsErrorCore(bool isError) { } } public abstract partial class PipelineResponseHeaders : System.Collections.Generic.IEnumerable>, System.Collections.IEnumerable @@ -246,10 +249,4 @@ protected void AssertNotFrozen() { } public virtual void Freeze() { } public void SetHeader(string name, string value) { } } - public partial class ResponseBufferingPolicy : System.ClientModel.Primitives.PipelinePolicy - { - public ResponseBufferingPolicy() { } - public sealed override void Process(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { } - public sealed override System.Threading.Tasks.ValueTask ProcessAsync(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { throw null; } - } } diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs index 0e44a8d8acece..928285a6edded 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs @@ -30,6 +30,7 @@ public ClientResultException(System.ClientModel.Primitives.PipelineResponse resp protected ClientResultException(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { } public ClientResultException(string message, System.ClientModel.Primitives.PipelineResponse? response = null, System.Exception? innerException = null) { } public int Status { get { throw null; } protected set { } } + public static System.Threading.Tasks.Task CreateAsync(System.ClientModel.Primitives.PipelineResponse response, System.Exception? innerException = null) { throw null; } public override void GetObjectData(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { } public System.ClientModel.Primitives.PipelineResponse? GetRawResponse() { throw null; } } @@ -204,7 +205,7 @@ protected PipelineRequestHeaders() { } public abstract partial class PipelineResponse : System.IDisposable { protected PipelineResponse() { } - public virtual System.BinaryData Content { get { throw null; } } + public abstract System.BinaryData Content { get; } public abstract System.IO.Stream? ContentStream { get; set; } public System.ClientModel.Primitives.PipelineResponseHeaders Headers { get { throw null; } } public virtual bool IsError { get { throw null; } } @@ -212,6 +213,8 @@ protected PipelineResponse() { } public abstract int Status { get; } public abstract void Dispose(); protected abstract System.ClientModel.Primitives.PipelineResponseHeaders GetHeadersCore(); + public abstract System.BinaryData ReadContent(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)); + public abstract System.Threading.Tasks.ValueTask ReadContentAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)); protected virtual void SetIsErrorCore(bool isError) { } } public abstract partial class PipelineResponseHeaders : System.Collections.Generic.IEnumerable>, System.Collections.IEnumerable @@ -245,10 +248,4 @@ protected void AssertNotFrozen() { } public virtual void Freeze() { } public void SetHeader(string name, string value) { } } - public partial class ResponseBufferingPolicy : System.ClientModel.Primitives.PipelinePolicy - { - public ResponseBufferingPolicy() { } - public sealed override void Process(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { } - public sealed override System.Threading.Tasks.ValueTask ProcessAsync(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { throw null; } - } } diff --git a/sdk/core/System.ClientModel/src/Convenience/ClientResultException.cs b/sdk/core/System.ClientModel/src/Convenience/ClientResultException.cs index bf86832f1dbab..b27ba2167b378 100644 --- a/sdk/core/System.ClientModel/src/Convenience/ClientResultException.cs +++ b/sdk/core/System.ClientModel/src/Convenience/ClientResultException.cs @@ -6,6 +6,7 @@ using System.Globalization; using System.Runtime.Serialization; using System.Text; +using System.Threading.Tasks; namespace System.ClientModel; @@ -17,6 +18,12 @@ public class ClientResultException : Exception, ISerializable private readonly PipelineResponse? _response; private int _status; + public static async Task CreateAsync(PipelineResponse response, Exception? innerException = default) + { + string message = await CreateMessageAsync(response).ConfigureAwait(false); + return new ClientResultException(message, response, innerException); + } + /// /// Gets the HTTP status code of the response. Returns. 0 if response was not received. /// @@ -66,8 +73,21 @@ public override void GetObjectData(SerializationInfo info, StreamingContext cont public PipelineResponse? GetRawResponse() => _response; private static string CreateMessage(PipelineResponse response) + => CreateMessageSyncOrAsync(response, async: false).EnsureCompleted(); + + private static async ValueTask CreateMessageAsync(PipelineResponse response) + => await CreateMessageSyncOrAsync(response, async: true).ConfigureAwait(false); + + private static async ValueTask CreateMessageSyncOrAsync(PipelineResponse response, bool async) { - response.BufferContent(); + if (async) + { + await response.ReadContentAsync().ConfigureAwait(false); + } + else + { + response.ReadContent(); + } StringBuilder messageBuilder = new(); diff --git a/sdk/core/System.ClientModel/src/Internal/CancellationHelper.cs b/sdk/core/System.ClientModel/src/Internal/CancellationHelper.cs index 869516e461bf9..41680950be29e 100644 --- a/sdk/core/System.ClientModel/src/Internal/CancellationHelper.cs +++ b/sdk/core/System.ClientModel/src/Internal/CancellationHelper.cs @@ -41,11 +41,32 @@ private static void ThrowOperationCanceledException(Exception? innerException, C /// Throws a cancellation exception if cancellation has been requested via . /// The token to check for a cancellation request. - internal static void ThrowIfCancellationRequested(CancellationToken cancellationToken) + /// The inner exception to wrap. May be null. + internal static void ThrowIfCancellationRequested(CancellationToken cancellationToken, Exception? innerException = default) { if (cancellationToken.IsCancellationRequested) { - ThrowOperationCanceledException(innerException: null, cancellationToken); + ThrowOperationCanceledException(innerException, cancellationToken); + } + } + + /// Throws a cancellation exception if cancellation has been requested via or . + /// The user-provided token. + /// The linked token that is cancelled on timeout provided token. + /// The inner exception to use. + /// The timeout used for the operation. +#pragma warning disable CA1068 // Cancellation token has to be the last parameter + internal static void ThrowIfCancellationRequestedOrTimeout(CancellationToken messageToken, CancellationToken timeoutToken, Exception? innerException, TimeSpan timeout) +#pragma warning restore CA1068 + { + ThrowIfCancellationRequested(messageToken, innerException); + + if (timeoutToken.IsCancellationRequested) + { + throw CreateOperationCanceledException( + innerException, + timeoutToken, + $"The operation was cancelled because it exceeded the configured timeout of {timeout:g}. "); } } -} \ No newline at end of file +} diff --git a/sdk/core/System.ClientModel/src/Internal/ReadTimeoutStream.cs b/sdk/core/System.ClientModel/src/Internal/ReadTimeoutStream.cs index fd8c3bdd96d76..95d9b6c1cd2df 100644 --- a/sdk/core/System.ClientModel/src/Internal/ReadTimeoutStream.cs +++ b/sdk/core/System.ClientModel/src/Internal/ReadTimeoutStream.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -using System.ClientModel.Primitives; using System.IO; using System.Threading; using System.Threading.Tasks; @@ -60,7 +59,7 @@ public override void Flush() public override int Read(byte[] buffer, int offset, int count) { - var source = StartTimeout(default, out bool dispose); + CancellationTokenSource source = StartTimeout(default, out bool dispose); try { return _stream.Read(buffer, offset, count); @@ -68,18 +67,18 @@ public override int Read(byte[] buffer, int offset, int count) // We dispose stream on timeout so catch and check if cancellation token was cancelled catch (IOException ex) { - ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout); + CancellationHelper.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout); throw; } // We dispose stream on timeout so catch and check if cancellation token was cancelled catch (ObjectDisposedException ex) { - ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout); + CancellationHelper.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout); throw; } catch (OperationCanceledException ex) { - ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout); + CancellationHelper.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout); throw; } finally @@ -90,7 +89,7 @@ public override int Read(byte[] buffer, int offset, int count) public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - var source = StartTimeout(cancellationToken, out bool dispose); + CancellationTokenSource source = StartTimeout(cancellationToken, out bool dispose); try { #pragma warning disable CA1835 // ReadAsync(Memory<>) overload is not available in all targets @@ -100,18 +99,18 @@ public override async Task ReadAsync(byte[] buffer, int offset, int count, // We dispose stream on timeout so catch and check if cancellation token was cancelled catch (IOException ex) { - ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout); + CancellationHelper.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout); throw; } // We dispose stream on timeout so catch and check if cancellation token was cancelled catch (ObjectDisposedException ex) { - ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout); + CancellationHelper.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout); throw; } catch (OperationCanceledException ex) { - ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout); + CancellationHelper.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout); throw; } finally diff --git a/sdk/core/System.ClientModel/src/Internal/StreamExtensions.cs b/sdk/core/System.ClientModel/src/Internal/StreamExtensions.cs index 21fdef014f56d..a7c206fa7debd 100644 --- a/sdk/core/System.ClientModel/src/Internal/StreamExtensions.cs +++ b/sdk/core/System.ClientModel/src/Internal/StreamExtensions.cs @@ -12,6 +12,9 @@ namespace System.ClientModel.Internal; internal static class StreamExtensions { + // Same value as Stream.CopyTo uses by default + private const int DefaultCopyBufferSize = 81920; + public static async Task WriteAsync(this Stream stream, ReadOnlyMemory buffer, CancellationToken cancellation = default) { Argument.AssertNotNull(stream, nameof(stream)); @@ -86,4 +89,45 @@ public static async Task WriteAsync(this Stream stream, ReadOnlySequence b ArrayPool.Shared.Return(array); } } + + public static async Task CopyToAsync(this Stream source, Stream destination, CancellationToken cancellationToken) + { + byte[] buffer = ArrayPool.Shared.Rent(DefaultCopyBufferSize); + + try + { + while (true) + { +#pragma warning disable CA1835 // ReadAsync(Memory<>) overload is not available in all targets + int bytesRead = await source.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); +#pragma warning restore // ReadAsync(Memory<>) overload is not available in all targets + if (bytesRead == 0) + break; + await destination.WriteAsync(new ReadOnlyMemory(buffer, 0, bytesRead), cancellationToken).ConfigureAwait(false); + } + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + + public static void CopyTo(this Stream source, Stream destination, CancellationToken cancellationToken) + { + byte[] buffer = ArrayPool.Shared.Rent(DefaultCopyBufferSize); + + try + { + int read; + while ((read = source.Read(buffer, 0, buffer.Length)) != 0) + { + cancellationToken.ThrowIfCancellationRequested(); + destination.Write(buffer, 0, read); + } + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } } diff --git a/sdk/core/System.ClientModel/src/Message/PipelineResponse.cs b/sdk/core/System.ClientModel/src/Message/PipelineResponse.cs index 8e13b24ff2d03..5e280ea5d5938 100644 --- a/sdk/core/System.ClientModel/src/Message/PipelineResponse.cs +++ b/sdk/core/System.ClientModel/src/Message/PipelineResponse.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -using System.Buffers; -using System.ClientModel.Internal; using System.IO; using System.Threading; using System.Threading.Tasks; @@ -12,7 +10,7 @@ namespace System.ClientModel.Primitives; public abstract class PipelineResponse : IDisposable { // TODO(matell): The .NET Framework team plans to add BinaryData.Empty in dotnet/runtime#49670, and we can use it then. - private static readonly BinaryData s_emptyBinaryData = new(Array.Empty()); + internal static readonly BinaryData s_EmptyBinaryData = new(Array.Empty()); private bool _isError = false; @@ -35,30 +33,11 @@ public abstract class PipelineResponse : IDisposable /// public abstract Stream? ContentStream { get; set; } - public virtual BinaryData Content - { - get - { - if (ContentStream == null) - { - return s_emptyBinaryData; - } - - if (!TryGetBufferedContent(out MemoryStream bufferedContent)) - { - throw new InvalidOperationException($"The response is not buffered."); - } - - if (bufferedContent.TryGetBuffer(out ArraySegment segment)) - { - return new BinaryData(segment.AsMemory()); - } - else - { - return new BinaryData(bufferedContent.ToArray()); - } - } - } + public abstract BinaryData Content { get; } + + public abstract BinaryData ReadContent(CancellationToken cancellationToken = default); + + public abstract ValueTask ReadContentAsync(CancellationToken cancellationToken = default); /// /// Indicates whether the status code of the returned response is considered @@ -73,100 +52,5 @@ public virtual BinaryData Content protected virtual void SetIsErrorCore(bool isError) => _isError = isError; - internal TimeSpan NetworkTimeout { get; set; } = ClientPipeline.DefaultNetworkTimeout; - public abstract void Dispose(); - - #region Response Buffering - - // Same value as Stream.CopyTo uses by default - private const int DefaultCopyBufferSize = 81920; - - internal bool TryGetBufferedContent(out MemoryStream bufferedContent) - { - if (ContentStream is MemoryStream content) - { - bufferedContent = content; - return true; - } - - bufferedContent = default!; - return false; - } - - internal void BufferContent(TimeSpan? timeout = default, CancellationTokenSource? cts = default) - { - Stream? responseContentStream = ContentStream; - if (responseContentStream == null || TryGetBufferedContent(out _)) - { - // No need to buffer content. - return; - } - - MemoryStream bufferStream = new(); - CopyTo(responseContentStream, bufferStream, timeout ?? NetworkTimeout, cts ?? new CancellationTokenSource()); - responseContentStream.Dispose(); - bufferStream.Position = 0; - ContentStream = bufferStream; - } - - internal async Task BufferContentAsync(TimeSpan? timeout = default, CancellationTokenSource? cts = default) - { - Stream? responseContentStream = ContentStream; - if (responseContentStream == null || TryGetBufferedContent(out _)) - { - // No need to buffer content. - return; - } - - MemoryStream bufferStream = new(); - await CopyToAsync(responseContentStream, bufferStream, timeout ?? NetworkTimeout, cts ?? new CancellationTokenSource()).ConfigureAwait(false); - responseContentStream.Dispose(); - bufferStream.Position = 0; - ContentStream = bufferStream; - } - - private static async Task CopyToAsync(Stream source, Stream destination, TimeSpan timeout, CancellationTokenSource cancellationTokenSource) - { - byte[] buffer = ArrayPool.Shared.Rent(DefaultCopyBufferSize); - try - { - while (true) - { - cancellationTokenSource.CancelAfter(timeout); -#pragma warning disable CA1835 // ReadAsync(Memory<>) overload is not available in all targets - int bytesRead = await source.ReadAsync(buffer, 0, buffer.Length, cancellationTokenSource.Token).ConfigureAwait(false); -#pragma warning restore // ReadAsync(Memory<>) overload is not available in all targets - if (bytesRead == 0) break; - await destination.WriteAsync(new ReadOnlyMemory(buffer, 0, bytesRead), cancellationTokenSource.Token).ConfigureAwait(false); - } - } - finally - { - cancellationTokenSource.CancelAfter(Timeout.InfiniteTimeSpan); - ArrayPool.Shared.Return(buffer); - } - } - - private static void CopyTo(Stream source, Stream destination, TimeSpan timeout, CancellationTokenSource cancellationTokenSource) - { - byte[] buffer = ArrayPool.Shared.Rent(DefaultCopyBufferSize); - try - { - int read; - while ((read = source.Read(buffer, 0, buffer.Length)) != 0) - { - cancellationTokenSource.Token.ThrowIfCancellationRequested(); - cancellationTokenSource.CancelAfter(timeout); - destination.Write(buffer, 0, read); - } - } - finally - { - cancellationTokenSource.CancelAfter(Timeout.InfiniteTimeSpan); - ArrayPool.Shared.Return(buffer); - } - } - - #endregion } diff --git a/sdk/core/System.ClientModel/src/Pipeline/ClientPipeline.cs b/sdk/core/System.ClientModel/src/Pipeline/ClientPipeline.cs index 7d49e441cdfdf..4f518de1ccc84 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/ClientPipeline.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/ClientPipeline.cs @@ -69,7 +69,6 @@ public static ClientPipeline Create( pipelineLength += options.BeforeTransportPolicies?.Length ?? 0; pipelineLength++; // for retry policy - pipelineLength++; // for response buffering policy pipelineLength++; // for transport PipelinePolicy[] policies = new PipelinePolicy[pipelineLength]; @@ -103,9 +102,6 @@ public static ClientPipeline Create( int perTryIndex = index; - // Response buffering comes before the transport. - policies[index++] = ResponseBufferingPolicy.Default; - // Before transport policies come before the transport. beforeTransportPolicies.CopyTo(policies.AsSpan(index)); index += beforeTransportPolicies.Length; diff --git a/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.Response.cs b/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.Response.cs index dc7c8e2f5228e..cfddb4a4c6fed 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.Response.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.Response.cs @@ -1,14 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System.ClientModel.Internal; using System.IO; using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; namespace System.ClientModel.Primitives; public partial class HttpClientPipelineTransport { - private class HttpPipelineResponse : PipelineResponse + private class HttpClientTransportResponse : PipelineResponse { private readonly HttpResponseMessage _httpResponse; @@ -20,13 +23,17 @@ private class HttpPipelineResponse : PipelineResponse private readonly HttpContent _httpResponseContent; private Stream? _contentStream; + private BinaryData? _bufferedContent; private bool _disposed; - public HttpPipelineResponse(HttpResponseMessage httpResponse) + public HttpClientTransportResponse(HttpResponseMessage httpResponse) { _httpResponse = httpResponse ?? throw new ArgumentNullException(nameof(httpResponse)); _httpResponseContent = _httpResponse.Content; + + // Don't dispose the content so it remains available for reading headers. + _httpResponse.Content = null; } public override int Status => (int)_httpResponse.StatusCode; @@ -39,17 +46,103 @@ protected override PipelineResponseHeaders GetHeadersCore() public override Stream? ContentStream { - get => _contentStream; - set + get { - // Make sure we don't dispose the content if the stream was replaced - _httpResponse.Content = null; + if (_contentStream is not null) + { + return _contentStream; + } + + if (_bufferedContent is not null) + { + return _bufferedContent.ToStream(); + } + return null; + } + set + { _contentStream = value; + + // Invalidate the cache since the source-stream has been replaced. + _bufferedContent = null; } } - #region IDisposable + public override BinaryData Content + { + get + { + if (_bufferedContent is not null) + { + return _bufferedContent; + } + + if (_contentStream is null || _contentStream is MemoryStream) + { + return ReadContent(); + } + + throw new InvalidOperationException($"The response is not buffered."); + } + } + + public override BinaryData ReadContent(CancellationToken cancellationToken = default) + => ReadContentSyncOrAsync(cancellationToken, async: false).EnsureCompleted(); + + public override async ValueTask ReadContentAsync(CancellationToken cancellationToken = default) + => await ReadContentSyncOrAsync(cancellationToken, async: true).ConfigureAwait(false); + + private async ValueTask ReadContentSyncOrAsync(CancellationToken cancellationToken, bool async) + { + if (_bufferedContent is not null) + { + // Content has already been buffered. + return _bufferedContent; + } + + if (_contentStream == null) + { + // Content is not buffered but there is no source stream. + // Our contract from Azure.Core is to return BinaryData.Empty in this case. + _bufferedContent = s_EmptyBinaryData; + return _bufferedContent; + } + + if (_contentStream.CanSeek && _contentStream.Position != 0) + { + throw new InvalidOperationException("Content stream position is not at beginning of stream."); + } + + // ContentStream still holds the source stream. Buffer the content + // and dispose the source stream. + BufferedContentStream bufferStream = new(); + + if (async) + { + await _contentStream.CopyToAsync(bufferStream, cancellationToken).ConfigureAwait(false); +#if NETSTANDARD2_0 + _contentStream.Dispose(); +#else + await _contentStream.DisposeAsync().ConfigureAwait(false); +#endif + } + else + { + _contentStream.CopyTo(bufferStream, cancellationToken); + _contentStream.Dispose(); + } + + _contentStream = null; + + bufferStream.Position = 0; + + _bufferedContent = bufferStream.TryGetBuffer(out ArraySegment segment) ? + new BinaryData(segment.AsMemory()) : + new BinaryData(bufferStream.ToArray()); + + return _bufferedContent; + } public override void Dispose() { @@ -62,41 +155,22 @@ protected virtual void Dispose(bool disposing) { if (disposing && !_disposed) { - var httpResponse = _httpResponse; + HttpResponseMessage httpResponse = _httpResponse; httpResponse?.Dispose(); - // Some notes on this: - // - // 1. If the content is buffered, we want it to remain available to the - // client for model deserialization and in case the end user of the - // client calls OutputMessage.GetRawResponse. So, we don't dispose it. - // - // If the content is buffered, we assume that the entity that did the - // buffering took responsibility for disposing the network stream. - // - // 2. If the content is not buffered, we dispose it so that we don't leave - // a network connection open. - // - // One tricky piece here is that in some cases, we may not have buffered - // the content because we wanted to pass the live network stream out of - // the client method and back to the end-user caller of the client e.g. - // for a streaming API. If the latter is the case, the client should have - // called the HttpMessage.ExtractResponseContent method to obtain a reference - // to the network stream, and the response content was replaced by a stream - // that we are ok to dispose here. In this case, the network stream is - // not disposed, because the entity that replaced the response content - // intentionally left the network stream undisposed. - - var contentStream = _contentStream; - if (contentStream is not null && !TryGetBufferedContent(out _)) + if (ContentStream is MemoryStream) { - contentStream?.Dispose(); - _contentStream = null; + ReadContent(); } + Stream? contentStream = _contentStream; + contentStream?.Dispose(); + _contentStream = null; + _disposed = true; } } - #endregion + + private class BufferedContentStream : MemoryStream { } } -} \ No newline at end of file +} diff --git a/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.cs b/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.cs index fb800ca3e7ca9..d25a89823dc9e 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.cs @@ -147,7 +147,7 @@ private async ValueTask ProcessSyncOrAsync(PipelineMessage message, bool async) throw new ClientResultException(e.Message, response: default, e); } - message.Response = new HttpPipelineResponse(responseMessage); + message.Response = new HttpClientTransportResponse(responseMessage); // This extensibility point lets derived types do the following: // 1. Set message.Response to an implementation-specific type, e.g. Azure.Core.Response. diff --git a/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs b/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs index a516ed3a4cef6..c34fb164ee36b 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs @@ -1,83 +1,170 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System.ClientModel.Internal; using System.Collections.Generic; using System.Diagnostics; +using System.IO; +using System.Threading; using System.Threading.Tasks; namespace System.ClientModel.Primitives; public abstract class PipelineTransport : PipelinePolicy { + #region CreateMessage + /// /// TBD: needed for inheritdoc. /// - /// - public void Process(PipelineMessage message) + public PipelineMessage CreateMessage() { - ProcessCore(message); + PipelineMessage message = CreateMessageCore(); + message.NetworkTimeout ??= ClientPipeline.DefaultNetworkTimeout; + + if (message.Request is null) + { + throw new InvalidOperationException("Request was not set on message."); + } - if (message.Response is null) + if (message.Response is not null) { - throw new InvalidOperationException("Response was not set by transport."); + throw new InvalidOperationException("Response should not be set before transport is invoked."); } - message.Response.SetIsError(ClassifyResponse(message)); + return message; } + protected abstract PipelineMessage CreateMessageCore(); + + #endregion + + #region Process message + + /// + /// TBD: needed for inheritdoc. + /// + /// + public void Process(PipelineMessage message) + => ProcessSyncOrAsync(message, async: false).EnsureCompleted(); + /// /// TBD: needed for inheritdoc. /// /// public async ValueTask ProcessAsync(PipelineMessage message) + => await ProcessSyncOrAsync(message, async: true).ConfigureAwait(false); + + private async ValueTask ProcessSyncOrAsync(PipelineMessage message, bool async) { - await ProcessCoreAsync(message).ConfigureAwait(false); + Debug.Assert(message.NetworkTimeout is not null, "NetworkTimeout is not set on PipelineMessage."); - if (message.Response is null) + // Implement network timeout behavior around call to ProcessCore. + TimeSpan networkTimeout = (TimeSpan)message.NetworkTimeout!; + CancellationToken messageToken = message.CancellationToken; + using CancellationTokenSource timeoutTokenSource = CancellationTokenSource.CreateLinkedTokenSource(messageToken); + timeoutTokenSource.CancelAfter(networkTimeout); + + try + { + message.CancellationToken = timeoutTokenSource.Token; + + if (async) + { + await ProcessCoreAsync(message).ConfigureAwait(false); + } + else + { + ProcessCore(message); + } + } + catch (OperationCanceledException ex) { - throw new InvalidOperationException("Response was not set by transport."); + CancellationHelper.ThrowIfCancellationRequestedOrTimeout(messageToken, timeoutTokenSource.Token, ex, networkTimeout); + throw; + } + finally + { + message.CancellationToken = messageToken; + timeoutTokenSource.CancelAfter(Timeout.Infinite); } - message.Response.SetIsError(ClassifyResponse(message)); - } + message.AssertResponse(); + message.Response!.SetIsError(ClassifyResponse(message)); - private static bool ClassifyResponse(PipelineMessage message) - { - if (!message.ResponseClassifier.TryClassify(message, out bool isError)) + // The remainder of the method handles response content according to + // buffering logic specified by value of message.BufferResponse. + + Stream? contentStream = message.Response!.ContentStream; + if (contentStream is null) { - bool classified = PipelineMessageClassifier.Default.TryClassify(message, out isError); + // There is no response content. + return; + } - Debug.Assert(classified); + if (!message.BufferResponse) + { + // Client has requested not to buffer the message response content. + // If applicable, wrap it in a read-timeout stream. + if (networkTimeout != Timeout.InfiniteTimeSpan) + { + message.Response.ContentStream = new ReadTimeoutStream(contentStream, networkTimeout); + } + + return; } - return isError; + try + { + // If cancellation is possible (whether due to network timeout or a user + // cancellation token being passed), then register callback to dispose the + // stream on cancellation. + if (networkTimeout != Timeout.InfiniteTimeSpan || messageToken.CanBeCanceled) + { + timeoutTokenSource.Token.Register(state => ((Stream?)state)?.Dispose(), contentStream); + timeoutTokenSource.CancelAfter(networkTimeout); + } + + if (async) + { + await message.Response.ReadContentAsync(timeoutTokenSource.Token).ConfigureAwait(false); + } + else + { + message.Response.ReadContent(timeoutTokenSource.Token); + } + } + // We dispose stream on timeout or user cancellation so catch and check if + // cancellation token was cancelled + catch (Exception ex) when (ex is ObjectDisposedException + or IOException + or OperationCanceledException + or NotSupportedException) + { + CancellationHelper.ThrowIfCancellationRequestedOrTimeout(messageToken, timeoutTokenSource.Token, ex, networkTimeout); + throw; + } } protected abstract void ProcessCore(PipelineMessage message); protected abstract ValueTask ProcessCoreAsync(PipelineMessage message); - /// - /// TBD: needed for inheritdoc. - /// - public PipelineMessage CreateMessage() + private static bool ClassifyResponse(PipelineMessage message) { - PipelineMessage message = CreateMessageCore(); - - if (message.Request is null) + if (!message.ResponseClassifier.TryClassify(message, out bool isError)) { - throw new InvalidOperationException("Request was not set on message."); - } + bool classified = PipelineMessageClassifier.Default.TryClassify(message, out isError); - if (message.Response is not null) - { - throw new InvalidOperationException("Response should not be set before transport is invoked."); + Debug.Assert(classified, "Error classifier did not classify message."); } - return message; + return isError; } - protected abstract PipelineMessage CreateMessageCore(); + #endregion + + #region PipelinePolicy.Process overrides // These methods from PipelinePolicy just say "you've reached the end // of the line", i.e. they stop the invocation of the policy chain. @@ -85,13 +172,15 @@ public sealed override void Process(PipelineMessage message, IReadOnlyList pipeline, int currentIndex) { await ProcessAsync(message).ConfigureAwait(false); - Debug.Assert(++currentIndex == pipeline.Count); + Debug.Assert(++currentIndex == pipeline.Count, "Transport is not at last position in pipeline."); } + + #endregion } diff --git a/sdk/core/System.ClientModel/src/Pipeline/ResponseBufferingPolicy.cs b/sdk/core/System.ClientModel/src/Pipeline/ResponseBufferingPolicy.cs deleted file mode 100644 index 9bec762e3cd22..0000000000000 --- a/sdk/core/System.ClientModel/src/Pipeline/ResponseBufferingPolicy.cs +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.ClientModel.Internal; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -namespace System.ClientModel.Primitives; - -/// -/// Pipeline policy to buffer response content or add a timeout to response content -/// managed by the client. -/// -public class ResponseBufferingPolicy : PipelinePolicy -{ - internal static readonly ResponseBufferingPolicy Default = new(); - - public ResponseBufferingPolicy() - { - } - - public sealed override void Process(PipelineMessage message, IReadOnlyList pipeline, int currentIndex) - => ProcessSyncOrAsync(message, pipeline, currentIndex, async: false).EnsureCompleted(); - - public sealed override async ValueTask ProcessAsync(PipelineMessage message, IReadOnlyList pipeline, int currentIndex) - => await ProcessSyncOrAsync(message, pipeline, currentIndex, async: true).ConfigureAwait(false); - - private async ValueTask ProcessSyncOrAsync(PipelineMessage message, IReadOnlyList pipeline, int currentIndex, bool async) - { - Debug.Assert(message.NetworkTimeout is not null); - - TimeSpan invocationNetworkTimeout = (TimeSpan)message.NetworkTimeout!; - - CancellationToken oldToken = message.CancellationToken; - using CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(oldToken); - cts.CancelAfter(invocationNetworkTimeout); - - try - { - message.CancellationToken = cts.Token; - if (async) - { - await ProcessNextAsync(message, pipeline, currentIndex).ConfigureAwait(false); - } - else - { - ProcessNext(message, pipeline, currentIndex); - } - } - catch (OperationCanceledException ex) - { - ThrowIfCancellationRequestedOrTimeout(oldToken, cts.Token, ex, invocationNetworkTimeout); - throw; - } - finally - { - message.CancellationToken = oldToken; - cts.CancelAfter(Timeout.Infinite); - } - - message.AssertResponse(); - message.Response!.NetworkTimeout = invocationNetworkTimeout; - - Stream? responseContentStream = message.Response!.ContentStream; - if (responseContentStream is null || - message.Response.TryGetBufferedContent(out var _)) - { - // There is either no content on the response, or the content has already - // been buffered. - return; - } - - if (!message.BufferResponse) - { - // Client has requested not to buffer the message response content. - // If applicable, wrap it in a read-timeout stream. - if (invocationNetworkTimeout != Timeout.InfiniteTimeSpan) - { - message.Response.ContentStream = new ReadTimeoutStream(responseContentStream, invocationNetworkTimeout); - } - - return; - } - - // If we got this far, buffer the response. - - // If cancellation is possible (whether due to network timeout or a user cancellation token being passed), then - // register callback to dispose the stream on cancellation. - if (invocationNetworkTimeout != Timeout.InfiniteTimeSpan || oldToken.CanBeCanceled) - { - cts.Token.Register(state => ((Stream?)state)?.Dispose(), responseContentStream); - } - - try - { - if (async) - { - await message.Response.BufferContentAsync(invocationNetworkTimeout, cts).ConfigureAwait(false); - } - else - { - message.Response.BufferContent(invocationNetworkTimeout, cts); - } - } - // We dispose stream on timeout or user cancellation so catch and check if cancellation token was cancelled - catch (Exception ex) - when (ex is ObjectDisposedException - or IOException - or OperationCanceledException - or NotSupportedException) - { - ThrowIfCancellationRequestedOrTimeout(oldToken, cts.Token, ex, invocationNetworkTimeout); - throw; - } - } - - /// Throws a cancellation exception if cancellation has been requested via or . - /// The customer provided token. - /// The linked token that is cancelled on timeout provided token. - /// The inner exception to use. - /// The timeout used for the operation. -#pragma warning disable CA1068 // Cancellation token has to be the last parameter - internal static void ThrowIfCancellationRequestedOrTimeout(CancellationToken originalToken, CancellationToken timeoutToken, Exception? inner, TimeSpan timeout) -#pragma warning restore CA1068 - { - CancellationHelper.ThrowIfCancellationRequested(originalToken); - - if (timeoutToken.IsCancellationRequested) - { - throw CancellationHelper.CreateOperationCanceledException( - inner, - timeoutToken, - $"The operation was cancelled because it exceeded the configured timeout of {timeout:g}. "); - } - } -} \ No newline at end of file diff --git a/sdk/core/System.ClientModel/tests/Convenience/ClientRequestExceptionTests.cs b/sdk/core/System.ClientModel/tests/Convenience/ClientRequestExceptionTests.cs index 173090c729d9b..3f538573846ee 100644 --- a/sdk/core/System.ClientModel/tests/Convenience/ClientRequestExceptionTests.cs +++ b/sdk/core/System.ClientModel/tests/Convenience/ClientRequestExceptionTests.cs @@ -5,6 +5,7 @@ using NUnit.Framework; using System.ClientModel.Primitives; using System.IO; +using System.Threading.Tasks; namespace System.ClientModel.Tests.Exceptions; @@ -24,6 +25,20 @@ public void CanCreateFromResponse() exception.Message); } + [Test] + public async Task CanCreateFromAsyncFactory() + { + PipelineResponse response = new MockPipelineResponse(200, "MockReason"); + + ClientResultException exception = await ClientResultException.CreateAsync(response); + + Assert.AreEqual(response.Status, exception.Status); + Assert.AreEqual(response, exception.GetRawResponse()); + Assert.AreEqual( + $"Service request failed.{Environment.NewLine}Status: 200 (MockReason){Environment.NewLine}", + exception.Message); + } + [Test] public void PassingMessageOverridesResponseMessage() { diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs index 1bc236899fcdd..e4163ffc85f59 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs @@ -5,6 +5,8 @@ using System.ClientModel.Primitives; using System.IO; using System.Text; +using System.Threading; +using System.Threading.Tasks; namespace ClientModel.Tests.Mocks; @@ -13,6 +15,7 @@ public class MockPipelineResponse : PipelineResponse private int _status; private string _reasonPhrase; private Stream? _contentStream; + private BinaryData? _bufferedContent; private readonly PipelineResponseHeaders _headers; @@ -50,6 +53,31 @@ public override Stream? ContentStream set => _contentStream = value; } + public override BinaryData Content + { + get + { + if (_contentStream is null) + { + return new BinaryData(Array.Empty()); + } + + if (ContentStream is not MemoryStream memoryContent) + { + throw new InvalidOperationException($"The response is not buffered."); + } + + if (memoryContent.TryGetBuffer(out ArraySegment segment)) + { + return new BinaryData(segment.AsMemory()); + } + else + { + return new BinaryData(memoryContent.ToArray()); + } + } + } + protected override PipelineResponseHeaders GetHeadersCore() => _headers; public sealed override void Dispose() @@ -63,7 +91,7 @@ protected void Dispose(bool disposing) { if (disposing && !_disposed) { - var content = _contentStream; + Stream? content = _contentStream; if (content != null) { _contentStream = null; @@ -73,4 +101,59 @@ protected void Dispose(bool disposing) _disposed = true; } } + + public override BinaryData ReadContent(CancellationToken cancellationToken = default) + { + if (_bufferedContent is not null) + { + return _bufferedContent; + } + + if (_contentStream is null) + { + _bufferedContent = new BinaryData(Array.Empty()); + return _bufferedContent; + } + + MemoryStream bufferStream = new(); + _contentStream.CopyTo(bufferStream); + _contentStream.Dispose(); + _contentStream = bufferStream; + + // Less efficient FromStream method called here because it is a mock. + // For intended production implementation, see HttpClientTransportResponse. + _bufferedContent = BinaryData.FromStream(bufferStream); + return _bufferedContent; + } + + public override async ValueTask ReadContentAsync(CancellationToken cancellationToken = default) + { + if (_bufferedContent is not null) + { + return _bufferedContent; + } + + if (_contentStream is null) + { + _bufferedContent = new BinaryData(Array.Empty()); + return _bufferedContent; + } + + MemoryStream bufferStream = new(); + +#if NETSTANDARD2_0 || NETFRAMEWORK + await _contentStream.CopyToAsync(bufferStream).ConfigureAwait(false); + _contentStream.Dispose(); +#else + await _contentStream.CopyToAsync(bufferStream, cancellationToken).ConfigureAwait(false); + await _contentStream.DisposeAsync().ConfigureAwait(false); +#endif + + _contentStream = bufferStream; + + // Less efficient FromStream method called here because it is a mock. + // For intended production implementation, see HttpClientTransportResponse. + _bufferedContent = BinaryData.FromStream(bufferStream); + return _bufferedContent; + } } diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineTransport.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineTransport.cs index 17cb7968dd7d5..c811ae1304462 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineTransport.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineTransport.cs @@ -6,6 +6,7 @@ using System.ClientModel.Primitives; using System.Collections.Generic; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace ClientModel.Tests.Mocks; @@ -176,11 +177,23 @@ public override Stream? ContentStream set => throw new NotImplementedException(); } + public override BinaryData Content => throw new NotImplementedException(); + protected override PipelineResponseHeaders GetHeadersCore() { throw new NotImplementedException(); } public override void Dispose() { } + + public override BinaryData ReadContent(CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override ValueTask ReadContentAsync(CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } } } diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/ObservableTransport.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/ObservableTransport.cs index 7bb84c2bf79f8..b9cc98db95d90 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/ObservableTransport.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/ObservableTransport.cs @@ -6,6 +6,7 @@ using System.ClientModel.Primitives; using System.Collections.Generic; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace ClientModel.Tests.Mocks; @@ -137,6 +138,8 @@ public override Stream? ContentStream set => throw new NotImplementedException(); } + public override BinaryData Content => throw new NotImplementedException(); + protected override PipelineResponseHeaders GetHeadersCore() { throw new NotImplementedException(); @@ -146,5 +149,15 @@ public override void Dispose() { throw new NotImplementedException(); } + + public override BinaryData ReadContent(CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override ValueTask ReadContentAsync(CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } } } diff --git a/sdk/core/System.ClientModel/tests/client/MapsClient/MapsClient.cs b/sdk/core/System.ClientModel/tests/client/MapsClient/MapsClient.cs index a4c18222a771a..95179679b9c04 100644 --- a/sdk/core/System.ClientModel/tests/client/MapsClient/MapsClient.cs +++ b/sdk/core/System.ClientModel/tests/client/MapsClient/MapsClient.cs @@ -6,6 +6,7 @@ using System.ClientModel.Primitives; using System.Net; using System.Text; +using System.Threading.Tasks; namespace Maps; @@ -34,6 +35,40 @@ public MapsClient(Uri endpoint, ApiKeyCredential credential, MapsClientOptions o beforeTransportPolicies: ReadOnlySpan.Empty); } + public virtual async Task> GetCountryCodeAsync(IPAddress ipAddress) + { + if (ipAddress is null) + throw new ArgumentNullException(nameof(ipAddress)); + + ClientResult output = await GetCountryCodeAsync(ipAddress.ToString()).ConfigureAwait(false); + + PipelineResponse response = output.GetRawResponse(); + IPAddressCountryPair value = IPAddressCountryPair.FromResponse(response); + + return ClientResult.FromValue(value, response); + } + + public virtual async Task GetCountryCodeAsync(string ipAddress, RequestOptions options = null) + { + if (ipAddress is null) + throw new ArgumentNullException(nameof(ipAddress)); + + options ??= new RequestOptions(); + + using PipelineMessage message = CreateGetLocationRequest(ipAddress, options); + + _pipeline.Send(message); + + PipelineResponse response = message.Response!; + + if (response.IsError && options.ErrorOptions == ClientErrorBehaviors.Default) + { + throw await ClientResultException.CreateAsync(response).ConfigureAwait(false); + } + + return ClientResult.FromResponse(response); + } + public virtual ClientResult GetCountryCode(IPAddress ipAddress) { if (ipAddress is null) throw new ArgumentNullException(nameof(ipAddress)); diff --git a/sdk/core/System.ClientModel/tests/client/MapsClientTests.cs b/sdk/core/System.ClientModel/tests/client/MapsClientTests.cs index 20aea50d50446..3cc23bd2e4433 100644 --- a/sdk/core/System.ClientModel/tests/client/MapsClientTests.cs +++ b/sdk/core/System.ClientModel/tests/client/MapsClientTests.cs @@ -375,6 +375,8 @@ public override Stream ContentStream set => _stream = value; } + public override BinaryData Content => BinaryData.FromStream(_stream); + public override void Dispose() { _stream?.Dispose(); @@ -384,6 +386,16 @@ protected override PipelineResponseHeaders GetHeadersCore() { throw new NotImplementedException(); } + + public override BinaryData ReadContent(CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override ValueTask ReadContentAsync(CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } } private class CustomHeaders : PipelineRequestHeaders diff --git a/sdk/core/System.ClientModel/tests/nullableenabledclient/MapsClientTests.cs b/sdk/core/System.ClientModel/tests/nullableenabledclient/MapsClientTests.cs index e400e916fafed..4ab39c69547c5 100644 --- a/sdk/core/System.ClientModel/tests/nullableenabledclient/MapsClientTests.cs +++ b/sdk/core/System.ClientModel/tests/nullableenabledclient/MapsClientTests.cs @@ -407,6 +407,8 @@ public override Stream? ContentStream set => _stream = value; } + public override BinaryData Content => BinaryData.FromStream(_stream!); + public override void Dispose() { _stream?.Dispose(); @@ -416,6 +418,16 @@ protected override PipelineResponseHeaders GetHeadersCore() { throw new NotImplementedException(); } + + public override BinaryData ReadContent(CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override ValueTask ReadContentAsync(CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } } private class CustomHeaders : PipelineRequestHeaders diff --git a/sdk/textanalytics/Azure.AI.TextAnalytics/src/TextAnalyticsFailedDetailsParser.cs b/sdk/textanalytics/Azure.AI.TextAnalytics/src/TextAnalyticsFailedDetailsParser.cs index bfef5d0416677..88397cf7d03a5 100644 --- a/sdk/textanalytics/Azure.AI.TextAnalytics/src/TextAnalyticsFailedDetailsParser.cs +++ b/sdk/textanalytics/Azure.AI.TextAnalytics/src/TextAnalyticsFailedDetailsParser.cs @@ -20,84 +20,73 @@ public override bool TryParse(Response response, out ResponseError? error, out I data = default; error = default; - if (response.ContentStream is { CanSeek: true }) + try { - long position = response.ContentStream.Position; + // Try to extract the standard Azure Error object from the response so that we can use it as the + // default value for the message, error code, etc. + using JsonDocument doc = JsonDocument.Parse(response.Content); + if (doc.RootElement.TryGetProperty("error", out JsonElement errorElement)) + { + TextAnalyticsError textAnalyticsError = Transforms.ConvertToError(Error.DeserializeError(errorElement)); + error = new ResponseError(textAnalyticsError.ErrorCode.ToString(), textAnalyticsError.Message); + return true; + } - try + // If the response does not straight up correspond to the standard Azure Error object that we are + // looking for, the Error object must actually be nested somewhere in there instead. For example, + // this can happen in the case of the convenience methods that receive a single input document as a + // parameter instead of a list of input documents. Here, rather than returning the typical + // successful response that includes a list of errors that the user needs to look through, we + // want to grab the first error in that list (which inevitably corresponds to a problem with the + // single input document), and use that error to throw a useful RequestFailedException. Now, + // depending on the circumstances, that standard Azure Error could be inside an InputError + // object, a DocumentError object, etc., so we need to look for it among a handful of well-known + // cases like those. + + if (doc.RootElement.TryGetProperty("errors", out JsonElement errorsElement)) { - // Try to extract the standard Azure Error object from the response so that we can use it as the - // default value for the message, error code, etc. + List errors = new(); - response.ContentStream.Position = 0; - using JsonDocument doc = JsonDocument.Parse(response.ContentStream); - if (doc.RootElement.TryGetProperty("error", out JsonElement errorElement)) + foreach (JsonElement item in errorsElement.EnumerateArray()) { - TextAnalyticsError textAnalyticsError = Transforms.ConvertToError(Error.DeserializeError(errorElement)); - error = new ResponseError(textAnalyticsError.ErrorCode.ToString(), textAnalyticsError.Message); - return true; + if (item.TryGetProperty("error", out errorElement)) + { + errors.Add(Error.DeserializeError(errorElement)); + } + else + { + errors.Add(Error.DeserializeError(item)); + } } - // If the response does not straight up correspond to the standard Azure Error object that we are - // looking for, the Error object must actually be nested somewhere in there instead. For example, - // this can happen in the case of the convenience methods that receive a single input document as a - // parameter instead of a list of input documents. Here, rather than returning the typical - // successful response that includes a list of errors that the user needs to look through, we - // want to grab the first error in that list (which inevitably corresponds to a problem with the - // single input document), and use that error to throw a useful RequestFailedException. Now, - // depending on the circumstances, that standard Azure Error could be inside an InputError - // object, a DocumentError object, etc., so we need to look for it among a handful of well-known - // cases like those. + GetResponseError(errors, out error, out data); + return true; + } - if (doc.RootElement.TryGetProperty("errors", out JsonElement errorsElement)) - { - List errors = new(); + if (doc.RootElement.TryGetProperty("results", out JsonElement results) && results.TryGetProperty("errors", out errorsElement)) + { + List errors = new(); - foreach (JsonElement item in errorsElement.EnumerateArray()) + foreach (JsonElement item in errorsElement.EnumerateArray()) + { + if (item.TryGetProperty("error", out errorElement)) { - if (item.TryGetProperty("error", out errorElement)) - { - errors.Add(Error.DeserializeError(errorElement)); - } - else - { - errors.Add(Error.DeserializeError(item)); - } + errors.Add(Error.DeserializeError(errorElement)); } - - GetResponseError(errors, out error, out data); - return true; - } - - if (doc.RootElement.TryGetProperty("results", out JsonElement results) && results.TryGetProperty("errors", out errorsElement)) - { - List errors = new(); - - foreach (JsonElement item in errorsElement.EnumerateArray()) + else { - if (item.TryGetProperty("error", out errorElement)) - { - errors.Add(Error.DeserializeError(errorElement)); - } - else - { - errors.Add(Error.DeserializeError(item)); - } + errors.Add(Error.DeserializeError(item)); } - - GetResponseError(errors, out error, out data); - return true; } + + GetResponseError(errors, out error, out data); + return true; } - catch (JsonException) - { - // Ignore any failures - unexpected content will be - // included verbatim in the detailed error message - } - finally - { - response.ContentStream.Position = position; - } + } + catch (JsonException) + { + // Ignore any failures - unexpected content will be + // included verbatim in the detailed error message } return false;