diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs index 3af3e84a6abd2..a5442ec9e4e88 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs @@ -30,6 +30,7 @@ public ClientResultException(System.ClientModel.Primitives.PipelineResponse resp protected ClientResultException(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { } public ClientResultException(string message, System.ClientModel.Primitives.PipelineResponse? response = null, System.Exception? innerException = null) { } public int Status { get { throw null; } protected set { } } + public static System.Threading.Tasks.Task CreateAsync(System.ClientModel.Primitives.PipelineResponse response, System.Exception? innerException = null) { throw null; } public override void GetObjectData(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { } public System.ClientModel.Primitives.PipelineResponse? GetRawResponse() { throw null; } } @@ -205,7 +206,7 @@ protected PipelineRequestHeaders() { } public abstract partial class PipelineResponse : System.IDisposable { protected PipelineResponse() { } - public virtual System.BinaryData Content { get { throw null; } } + public abstract System.BinaryData Content { get; } public abstract System.IO.Stream? ContentStream { get; set; } public System.ClientModel.Primitives.PipelineResponseHeaders Headers { get { throw null; } } public virtual bool IsError { get { throw null; } } @@ -213,6 +214,8 @@ protected PipelineResponse() { } public abstract int Status { get; } public abstract void Dispose(); protected abstract System.ClientModel.Primitives.PipelineResponseHeaders GetHeadersCore(); + public abstract System.BinaryData ReadContent(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)); + public abstract System.Threading.Tasks.ValueTask ReadContentAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)); protected virtual void SetIsErrorCore(bool isError) { } } public abstract partial class PipelineResponseHeaders : System.Collections.Generic.IEnumerable>, System.Collections.IEnumerable @@ -246,10 +249,4 @@ protected void AssertNotFrozen() { } public virtual void Freeze() { } public void SetHeader(string name, string value) { } } - public partial class ResponseBufferingPolicy : System.ClientModel.Primitives.PipelinePolicy - { - public ResponseBufferingPolicy() { } - public sealed override void Process(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { } - public sealed override System.Threading.Tasks.ValueTask ProcessAsync(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { throw null; } - } } diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs index 0e44a8d8acece..928285a6edded 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs @@ -30,6 +30,7 @@ public ClientResultException(System.ClientModel.Primitives.PipelineResponse resp protected ClientResultException(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { } public ClientResultException(string message, System.ClientModel.Primitives.PipelineResponse? response = null, System.Exception? innerException = null) { } public int Status { get { throw null; } protected set { } } + public static System.Threading.Tasks.Task CreateAsync(System.ClientModel.Primitives.PipelineResponse response, System.Exception? innerException = null) { throw null; } public override void GetObjectData(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { } public System.ClientModel.Primitives.PipelineResponse? GetRawResponse() { throw null; } } @@ -204,7 +205,7 @@ protected PipelineRequestHeaders() { } public abstract partial class PipelineResponse : System.IDisposable { protected PipelineResponse() { } - public virtual System.BinaryData Content { get { throw null; } } + public abstract System.BinaryData Content { get; } public abstract System.IO.Stream? ContentStream { get; set; } public System.ClientModel.Primitives.PipelineResponseHeaders Headers { get { throw null; } } public virtual bool IsError { get { throw null; } } @@ -212,6 +213,8 @@ protected PipelineResponse() { } public abstract int Status { get; } public abstract void Dispose(); protected abstract System.ClientModel.Primitives.PipelineResponseHeaders GetHeadersCore(); + public abstract System.BinaryData ReadContent(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)); + public abstract System.Threading.Tasks.ValueTask ReadContentAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)); protected virtual void SetIsErrorCore(bool isError) { } } public abstract partial class PipelineResponseHeaders : System.Collections.Generic.IEnumerable>, System.Collections.IEnumerable @@ -245,10 +248,4 @@ protected void AssertNotFrozen() { } public virtual void Freeze() { } public void SetHeader(string name, string value) { } } - public partial class ResponseBufferingPolicy : System.ClientModel.Primitives.PipelinePolicy - { - public ResponseBufferingPolicy() { } - public sealed override void Process(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { } - public sealed override System.Threading.Tasks.ValueTask ProcessAsync(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { throw null; } - } } diff --git a/sdk/core/System.ClientModel/src/Convenience/ClientResultException.cs b/sdk/core/System.ClientModel/src/Convenience/ClientResultException.cs index bf86832f1dbab..b27ba2167b378 100644 --- a/sdk/core/System.ClientModel/src/Convenience/ClientResultException.cs +++ b/sdk/core/System.ClientModel/src/Convenience/ClientResultException.cs @@ -6,6 +6,7 @@ using System.Globalization; using System.Runtime.Serialization; using System.Text; +using System.Threading.Tasks; namespace System.ClientModel; @@ -17,6 +18,12 @@ public class ClientResultException : Exception, ISerializable private readonly PipelineResponse? _response; private int _status; + public static async Task CreateAsync(PipelineResponse response, Exception? innerException = default) + { + string message = await CreateMessageAsync(response).ConfigureAwait(false); + return new ClientResultException(message, response, innerException); + } + /// /// Gets the HTTP status code of the response. Returns. 0 if response was not received. /// @@ -66,8 +73,21 @@ public override void GetObjectData(SerializationInfo info, StreamingContext cont public PipelineResponse? GetRawResponse() => _response; private static string CreateMessage(PipelineResponse response) + => CreateMessageSyncOrAsync(response, async: false).EnsureCompleted(); + + private static async ValueTask CreateMessageAsync(PipelineResponse response) + => await CreateMessageSyncOrAsync(response, async: true).ConfigureAwait(false); + + private static async ValueTask CreateMessageSyncOrAsync(PipelineResponse response, bool async) { - response.BufferContent(); + if (async) + { + await response.ReadContentAsync().ConfigureAwait(false); + } + else + { + response.ReadContent(); + } StringBuilder messageBuilder = new(); diff --git a/sdk/core/System.ClientModel/src/Internal/CancellationHelper.cs b/sdk/core/System.ClientModel/src/Internal/CancellationHelper.cs index 869516e461bf9..41680950be29e 100644 --- a/sdk/core/System.ClientModel/src/Internal/CancellationHelper.cs +++ b/sdk/core/System.ClientModel/src/Internal/CancellationHelper.cs @@ -41,11 +41,32 @@ private static void ThrowOperationCanceledException(Exception? innerException, C /// Throws a cancellation exception if cancellation has been requested via . /// The token to check for a cancellation request. - internal static void ThrowIfCancellationRequested(CancellationToken cancellationToken) + /// The inner exception to wrap. May be null. + internal static void ThrowIfCancellationRequested(CancellationToken cancellationToken, Exception? innerException = default) { if (cancellationToken.IsCancellationRequested) { - ThrowOperationCanceledException(innerException: null, cancellationToken); + ThrowOperationCanceledException(innerException, cancellationToken); + } + } + + /// Throws a cancellation exception if cancellation has been requested via or . + /// The user-provided token. + /// The linked token that is cancelled on timeout provided token. + /// The inner exception to use. + /// The timeout used for the operation. +#pragma warning disable CA1068 // Cancellation token has to be the last parameter + internal static void ThrowIfCancellationRequestedOrTimeout(CancellationToken messageToken, CancellationToken timeoutToken, Exception? innerException, TimeSpan timeout) +#pragma warning restore CA1068 + { + ThrowIfCancellationRequested(messageToken, innerException); + + if (timeoutToken.IsCancellationRequested) + { + throw CreateOperationCanceledException( + innerException, + timeoutToken, + $"The operation was cancelled because it exceeded the configured timeout of {timeout:g}. "); } } -} \ No newline at end of file +} diff --git a/sdk/core/System.ClientModel/src/Internal/ReadTimeoutStream.cs b/sdk/core/System.ClientModel/src/Internal/ReadTimeoutStream.cs index fd8c3bdd96d76..95d9b6c1cd2df 100644 --- a/sdk/core/System.ClientModel/src/Internal/ReadTimeoutStream.cs +++ b/sdk/core/System.ClientModel/src/Internal/ReadTimeoutStream.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -using System.ClientModel.Primitives; using System.IO; using System.Threading; using System.Threading.Tasks; @@ -60,7 +59,7 @@ public override void Flush() public override int Read(byte[] buffer, int offset, int count) { - var source = StartTimeout(default, out bool dispose); + CancellationTokenSource source = StartTimeout(default, out bool dispose); try { return _stream.Read(buffer, offset, count); @@ -68,18 +67,18 @@ public override int Read(byte[] buffer, int offset, int count) // We dispose stream on timeout so catch and check if cancellation token was cancelled catch (IOException ex) { - ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout); + CancellationHelper.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout); throw; } // We dispose stream on timeout so catch and check if cancellation token was cancelled catch (ObjectDisposedException ex) { - ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout); + CancellationHelper.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout); throw; } catch (OperationCanceledException ex) { - ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout); + CancellationHelper.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout); throw; } finally @@ -90,7 +89,7 @@ public override int Read(byte[] buffer, int offset, int count) public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - var source = StartTimeout(cancellationToken, out bool dispose); + CancellationTokenSource source = StartTimeout(cancellationToken, out bool dispose); try { #pragma warning disable CA1835 // ReadAsync(Memory<>) overload is not available in all targets @@ -100,18 +99,18 @@ public override async Task ReadAsync(byte[] buffer, int offset, int count, // We dispose stream on timeout so catch and check if cancellation token was cancelled catch (IOException ex) { - ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout); + CancellationHelper.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout); throw; } // We dispose stream on timeout so catch and check if cancellation token was cancelled catch (ObjectDisposedException ex) { - ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout); + CancellationHelper.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout); throw; } catch (OperationCanceledException ex) { - ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout); + CancellationHelper.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout); throw; } finally diff --git a/sdk/core/System.ClientModel/src/Internal/StreamExtensions.cs b/sdk/core/System.ClientModel/src/Internal/StreamExtensions.cs index 21fdef014f56d..a7c206fa7debd 100644 --- a/sdk/core/System.ClientModel/src/Internal/StreamExtensions.cs +++ b/sdk/core/System.ClientModel/src/Internal/StreamExtensions.cs @@ -12,6 +12,9 @@ namespace System.ClientModel.Internal; internal static class StreamExtensions { + // Same value as Stream.CopyTo uses by default + private const int DefaultCopyBufferSize = 81920; + public static async Task WriteAsync(this Stream stream, ReadOnlyMemory buffer, CancellationToken cancellation = default) { Argument.AssertNotNull(stream, nameof(stream)); @@ -86,4 +89,45 @@ public static async Task WriteAsync(this Stream stream, ReadOnlySequence b ArrayPool.Shared.Return(array); } } + + public static async Task CopyToAsync(this Stream source, Stream destination, CancellationToken cancellationToken) + { + byte[] buffer = ArrayPool.Shared.Rent(DefaultCopyBufferSize); + + try + { + while (true) + { +#pragma warning disable CA1835 // ReadAsync(Memory<>) overload is not available in all targets + int bytesRead = await source.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); +#pragma warning restore // ReadAsync(Memory<>) overload is not available in all targets + if (bytesRead == 0) + break; + await destination.WriteAsync(new ReadOnlyMemory(buffer, 0, bytesRead), cancellationToken).ConfigureAwait(false); + } + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + + public static void CopyTo(this Stream source, Stream destination, CancellationToken cancellationToken) + { + byte[] buffer = ArrayPool.Shared.Rent(DefaultCopyBufferSize); + + try + { + int read; + while ((read = source.Read(buffer, 0, buffer.Length)) != 0) + { + cancellationToken.ThrowIfCancellationRequested(); + destination.Write(buffer, 0, read); + } + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } } diff --git a/sdk/core/System.ClientModel/src/Message/PipelineResponse.cs b/sdk/core/System.ClientModel/src/Message/PipelineResponse.cs index 8e13b24ff2d03..5e280ea5d5938 100644 --- a/sdk/core/System.ClientModel/src/Message/PipelineResponse.cs +++ b/sdk/core/System.ClientModel/src/Message/PipelineResponse.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -using System.Buffers; -using System.ClientModel.Internal; using System.IO; using System.Threading; using System.Threading.Tasks; @@ -12,7 +10,7 @@ namespace System.ClientModel.Primitives; public abstract class PipelineResponse : IDisposable { // TODO(matell): The .NET Framework team plans to add BinaryData.Empty in dotnet/runtime#49670, and we can use it then. - private static readonly BinaryData s_emptyBinaryData = new(Array.Empty()); + internal static readonly BinaryData s_EmptyBinaryData = new(Array.Empty()); private bool _isError = false; @@ -35,30 +33,11 @@ public abstract class PipelineResponse : IDisposable /// public abstract Stream? ContentStream { get; set; } - public virtual BinaryData Content - { - get - { - if (ContentStream == null) - { - return s_emptyBinaryData; - } - - if (!TryGetBufferedContent(out MemoryStream bufferedContent)) - { - throw new InvalidOperationException($"The response is not buffered."); - } - - if (bufferedContent.TryGetBuffer(out ArraySegment segment)) - { - return new BinaryData(segment.AsMemory()); - } - else - { - return new BinaryData(bufferedContent.ToArray()); - } - } - } + public abstract BinaryData Content { get; } + + public abstract BinaryData ReadContent(CancellationToken cancellationToken = default); + + public abstract ValueTask ReadContentAsync(CancellationToken cancellationToken = default); /// /// Indicates whether the status code of the returned response is considered @@ -73,100 +52,5 @@ public virtual BinaryData Content protected virtual void SetIsErrorCore(bool isError) => _isError = isError; - internal TimeSpan NetworkTimeout { get; set; } = ClientPipeline.DefaultNetworkTimeout; - public abstract void Dispose(); - - #region Response Buffering - - // Same value as Stream.CopyTo uses by default - private const int DefaultCopyBufferSize = 81920; - - internal bool TryGetBufferedContent(out MemoryStream bufferedContent) - { - if (ContentStream is MemoryStream content) - { - bufferedContent = content; - return true; - } - - bufferedContent = default!; - return false; - } - - internal void BufferContent(TimeSpan? timeout = default, CancellationTokenSource? cts = default) - { - Stream? responseContentStream = ContentStream; - if (responseContentStream == null || TryGetBufferedContent(out _)) - { - // No need to buffer content. - return; - } - - MemoryStream bufferStream = new(); - CopyTo(responseContentStream, bufferStream, timeout ?? NetworkTimeout, cts ?? new CancellationTokenSource()); - responseContentStream.Dispose(); - bufferStream.Position = 0; - ContentStream = bufferStream; - } - - internal async Task BufferContentAsync(TimeSpan? timeout = default, CancellationTokenSource? cts = default) - { - Stream? responseContentStream = ContentStream; - if (responseContentStream == null || TryGetBufferedContent(out _)) - { - // No need to buffer content. - return; - } - - MemoryStream bufferStream = new(); - await CopyToAsync(responseContentStream, bufferStream, timeout ?? NetworkTimeout, cts ?? new CancellationTokenSource()).ConfigureAwait(false); - responseContentStream.Dispose(); - bufferStream.Position = 0; - ContentStream = bufferStream; - } - - private static async Task CopyToAsync(Stream source, Stream destination, TimeSpan timeout, CancellationTokenSource cancellationTokenSource) - { - byte[] buffer = ArrayPool.Shared.Rent(DefaultCopyBufferSize); - try - { - while (true) - { - cancellationTokenSource.CancelAfter(timeout); -#pragma warning disable CA1835 // ReadAsync(Memory<>) overload is not available in all targets - int bytesRead = await source.ReadAsync(buffer, 0, buffer.Length, cancellationTokenSource.Token).ConfigureAwait(false); -#pragma warning restore // ReadAsync(Memory<>) overload is not available in all targets - if (bytesRead == 0) break; - await destination.WriteAsync(new ReadOnlyMemory(buffer, 0, bytesRead), cancellationTokenSource.Token).ConfigureAwait(false); - } - } - finally - { - cancellationTokenSource.CancelAfter(Timeout.InfiniteTimeSpan); - ArrayPool.Shared.Return(buffer); - } - } - - private static void CopyTo(Stream source, Stream destination, TimeSpan timeout, CancellationTokenSource cancellationTokenSource) - { - byte[] buffer = ArrayPool.Shared.Rent(DefaultCopyBufferSize); - try - { - int read; - while ((read = source.Read(buffer, 0, buffer.Length)) != 0) - { - cancellationTokenSource.Token.ThrowIfCancellationRequested(); - cancellationTokenSource.CancelAfter(timeout); - destination.Write(buffer, 0, read); - } - } - finally - { - cancellationTokenSource.CancelAfter(Timeout.InfiniteTimeSpan); - ArrayPool.Shared.Return(buffer); - } - } - - #endregion } diff --git a/sdk/core/System.ClientModel/src/Pipeline/ClientPipeline.cs b/sdk/core/System.ClientModel/src/Pipeline/ClientPipeline.cs index 7d49e441cdfdf..4f518de1ccc84 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/ClientPipeline.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/ClientPipeline.cs @@ -69,7 +69,6 @@ public static ClientPipeline Create( pipelineLength += options.BeforeTransportPolicies?.Length ?? 0; pipelineLength++; // for retry policy - pipelineLength++; // for response buffering policy pipelineLength++; // for transport PipelinePolicy[] policies = new PipelinePolicy[pipelineLength]; @@ -103,9 +102,6 @@ public static ClientPipeline Create( int perTryIndex = index; - // Response buffering comes before the transport. - policies[index++] = ResponseBufferingPolicy.Default; - // Before transport policies come before the transport. beforeTransportPolicies.CopyTo(policies.AsSpan(index)); index += beforeTransportPolicies.Length; diff --git a/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.Response.cs b/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.Response.cs index dc7c8e2f5228e..cfddb4a4c6fed 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.Response.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.Response.cs @@ -1,14 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System.ClientModel.Internal; using System.IO; using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; namespace System.ClientModel.Primitives; public partial class HttpClientPipelineTransport { - private class HttpPipelineResponse : PipelineResponse + private class HttpClientTransportResponse : PipelineResponse { private readonly HttpResponseMessage _httpResponse; @@ -20,13 +23,17 @@ private class HttpPipelineResponse : PipelineResponse private readonly HttpContent _httpResponseContent; private Stream? _contentStream; + private BinaryData? _bufferedContent; private bool _disposed; - public HttpPipelineResponse(HttpResponseMessage httpResponse) + public HttpClientTransportResponse(HttpResponseMessage httpResponse) { _httpResponse = httpResponse ?? throw new ArgumentNullException(nameof(httpResponse)); _httpResponseContent = _httpResponse.Content; + + // Don't dispose the content so it remains available for reading headers. + _httpResponse.Content = null; } public override int Status => (int)_httpResponse.StatusCode; @@ -39,17 +46,103 @@ protected override PipelineResponseHeaders GetHeadersCore() public override Stream? ContentStream { - get => _contentStream; - set + get { - // Make sure we don't dispose the content if the stream was replaced - _httpResponse.Content = null; + if (_contentStream is not null) + { + return _contentStream; + } + + if (_bufferedContent is not null) + { + return _bufferedContent.ToStream(); + } + return null; + } + set + { _contentStream = value; + + // Invalidate the cache since the source-stream has been replaced. + _bufferedContent = null; } } - #region IDisposable + public override BinaryData Content + { + get + { + if (_bufferedContent is not null) + { + return _bufferedContent; + } + + if (_contentStream is null || _contentStream is MemoryStream) + { + return ReadContent(); + } + + throw new InvalidOperationException($"The response is not buffered."); + } + } + + public override BinaryData ReadContent(CancellationToken cancellationToken = default) + => ReadContentSyncOrAsync(cancellationToken, async: false).EnsureCompleted(); + + public override async ValueTask ReadContentAsync(CancellationToken cancellationToken = default) + => await ReadContentSyncOrAsync(cancellationToken, async: true).ConfigureAwait(false); + + private async ValueTask ReadContentSyncOrAsync(CancellationToken cancellationToken, bool async) + { + if (_bufferedContent is not null) + { + // Content has already been buffered. + return _bufferedContent; + } + + if (_contentStream == null) + { + // Content is not buffered but there is no source stream. + // Our contract from Azure.Core is to return BinaryData.Empty in this case. + _bufferedContent = s_EmptyBinaryData; + return _bufferedContent; + } + + if (_contentStream.CanSeek && _contentStream.Position != 0) + { + throw new InvalidOperationException("Content stream position is not at beginning of stream."); + } + + // ContentStream still holds the source stream. Buffer the content + // and dispose the source stream. + BufferedContentStream bufferStream = new(); + + if (async) + { + await _contentStream.CopyToAsync(bufferStream, cancellationToken).ConfigureAwait(false); +#if NETSTANDARD2_0 + _contentStream.Dispose(); +#else + await _contentStream.DisposeAsync().ConfigureAwait(false); +#endif + } + else + { + _contentStream.CopyTo(bufferStream, cancellationToken); + _contentStream.Dispose(); + } + + _contentStream = null; + + bufferStream.Position = 0; + + _bufferedContent = bufferStream.TryGetBuffer(out ArraySegment segment) ? + new BinaryData(segment.AsMemory()) : + new BinaryData(bufferStream.ToArray()); + + return _bufferedContent; + } public override void Dispose() { @@ -62,41 +155,22 @@ protected virtual void Dispose(bool disposing) { if (disposing && !_disposed) { - var httpResponse = _httpResponse; + HttpResponseMessage httpResponse = _httpResponse; httpResponse?.Dispose(); - // Some notes on this: - // - // 1. If the content is buffered, we want it to remain available to the - // client for model deserialization and in case the end user of the - // client calls OutputMessage.GetRawResponse. So, we don't dispose it. - // - // If the content is buffered, we assume that the entity that did the - // buffering took responsibility for disposing the network stream. - // - // 2. If the content is not buffered, we dispose it so that we don't leave - // a network connection open. - // - // One tricky piece here is that in some cases, we may not have buffered - // the content because we wanted to pass the live network stream out of - // the client method and back to the end-user caller of the client e.g. - // for a streaming API. If the latter is the case, the client should have - // called the HttpMessage.ExtractResponseContent method to obtain a reference - // to the network stream, and the response content was replaced by a stream - // that we are ok to dispose here. In this case, the network stream is - // not disposed, because the entity that replaced the response content - // intentionally left the network stream undisposed. - - var contentStream = _contentStream; - if (contentStream is not null && !TryGetBufferedContent(out _)) + if (ContentStream is MemoryStream) { - contentStream?.Dispose(); - _contentStream = null; + ReadContent(); } + Stream? contentStream = _contentStream; + contentStream?.Dispose(); + _contentStream = null; + _disposed = true; } } - #endregion + + private class BufferedContentStream : MemoryStream { } } -} \ No newline at end of file +} diff --git a/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.cs b/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.cs index fb800ca3e7ca9..d25a89823dc9e 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.cs @@ -147,7 +147,7 @@ private async ValueTask ProcessSyncOrAsync(PipelineMessage message, bool async) throw new ClientResultException(e.Message, response: default, e); } - message.Response = new HttpPipelineResponse(responseMessage); + message.Response = new HttpClientTransportResponse(responseMessage); // This extensibility point lets derived types do the following: // 1. Set message.Response to an implementation-specific type, e.g. Azure.Core.Response. diff --git a/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs b/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs index a516ed3a4cef6..c34fb164ee36b 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs @@ -1,83 +1,170 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System.ClientModel.Internal; using System.Collections.Generic; using System.Diagnostics; +using System.IO; +using System.Threading; using System.Threading.Tasks; namespace System.ClientModel.Primitives; public abstract class PipelineTransport : PipelinePolicy { + #region CreateMessage + /// /// TBD: needed for inheritdoc. /// - /// - public void Process(PipelineMessage message) + public PipelineMessage CreateMessage() { - ProcessCore(message); + PipelineMessage message = CreateMessageCore(); + message.NetworkTimeout ??= ClientPipeline.DefaultNetworkTimeout; + + if (message.Request is null) + { + throw new InvalidOperationException("Request was not set on message."); + } - if (message.Response is null) + if (message.Response is not null) { - throw new InvalidOperationException("Response was not set by transport."); + throw new InvalidOperationException("Response should not be set before transport is invoked."); } - message.Response.SetIsError(ClassifyResponse(message)); + return message; } + protected abstract PipelineMessage CreateMessageCore(); + + #endregion + + #region Process message + + /// + /// TBD: needed for inheritdoc. + /// + /// + public void Process(PipelineMessage message) + => ProcessSyncOrAsync(message, async: false).EnsureCompleted(); + /// /// TBD: needed for inheritdoc. /// /// public async ValueTask ProcessAsync(PipelineMessage message) + => await ProcessSyncOrAsync(message, async: true).ConfigureAwait(false); + + private async ValueTask ProcessSyncOrAsync(PipelineMessage message, bool async) { - await ProcessCoreAsync(message).ConfigureAwait(false); + Debug.Assert(message.NetworkTimeout is not null, "NetworkTimeout is not set on PipelineMessage."); - if (message.Response is null) + // Implement network timeout behavior around call to ProcessCore. + TimeSpan networkTimeout = (TimeSpan)message.NetworkTimeout!; + CancellationToken messageToken = message.CancellationToken; + using CancellationTokenSource timeoutTokenSource = CancellationTokenSource.CreateLinkedTokenSource(messageToken); + timeoutTokenSource.CancelAfter(networkTimeout); + + try + { + message.CancellationToken = timeoutTokenSource.Token; + + if (async) + { + await ProcessCoreAsync(message).ConfigureAwait(false); + } + else + { + ProcessCore(message); + } + } + catch (OperationCanceledException ex) { - throw new InvalidOperationException("Response was not set by transport."); + CancellationHelper.ThrowIfCancellationRequestedOrTimeout(messageToken, timeoutTokenSource.Token, ex, networkTimeout); + throw; + } + finally + { + message.CancellationToken = messageToken; + timeoutTokenSource.CancelAfter(Timeout.Infinite); } - message.Response.SetIsError(ClassifyResponse(message)); - } + message.AssertResponse(); + message.Response!.SetIsError(ClassifyResponse(message)); - private static bool ClassifyResponse(PipelineMessage message) - { - if (!message.ResponseClassifier.TryClassify(message, out bool isError)) + // The remainder of the method handles response content according to + // buffering logic specified by value of message.BufferResponse. + + Stream? contentStream = message.Response!.ContentStream; + if (contentStream is null) { - bool classified = PipelineMessageClassifier.Default.TryClassify(message, out isError); + // There is no response content. + return; + } - Debug.Assert(classified); + if (!message.BufferResponse) + { + // Client has requested not to buffer the message response content. + // If applicable, wrap it in a read-timeout stream. + if (networkTimeout != Timeout.InfiniteTimeSpan) + { + message.Response.ContentStream = new ReadTimeoutStream(contentStream, networkTimeout); + } + + return; } - return isError; + try + { + // If cancellation is possible (whether due to network timeout or a user + // cancellation token being passed), then register callback to dispose the + // stream on cancellation. + if (networkTimeout != Timeout.InfiniteTimeSpan || messageToken.CanBeCanceled) + { + timeoutTokenSource.Token.Register(state => ((Stream?)state)?.Dispose(), contentStream); + timeoutTokenSource.CancelAfter(networkTimeout); + } + + if (async) + { + await message.Response.ReadContentAsync(timeoutTokenSource.Token).ConfigureAwait(false); + } + else + { + message.Response.ReadContent(timeoutTokenSource.Token); + } + } + // We dispose stream on timeout or user cancellation so catch and check if + // cancellation token was cancelled + catch (Exception ex) when (ex is ObjectDisposedException + or IOException + or OperationCanceledException + or NotSupportedException) + { + CancellationHelper.ThrowIfCancellationRequestedOrTimeout(messageToken, timeoutTokenSource.Token, ex, networkTimeout); + throw; + } } protected abstract void ProcessCore(PipelineMessage message); protected abstract ValueTask ProcessCoreAsync(PipelineMessage message); - /// - /// TBD: needed for inheritdoc. - /// - public PipelineMessage CreateMessage() + private static bool ClassifyResponse(PipelineMessage message) { - PipelineMessage message = CreateMessageCore(); - - if (message.Request is null) + if (!message.ResponseClassifier.TryClassify(message, out bool isError)) { - throw new InvalidOperationException("Request was not set on message."); - } + bool classified = PipelineMessageClassifier.Default.TryClassify(message, out isError); - if (message.Response is not null) - { - throw new InvalidOperationException("Response should not be set before transport is invoked."); + Debug.Assert(classified, "Error classifier did not classify message."); } - return message; + return isError; } - protected abstract PipelineMessage CreateMessageCore(); + #endregion + + #region PipelinePolicy.Process overrides // These methods from PipelinePolicy just say "you've reached the end // of the line", i.e. they stop the invocation of the policy chain. @@ -85,13 +172,15 @@ public sealed override void Process(PipelineMessage message, IReadOnlyList pipeline, int currentIndex) { await ProcessAsync(message).ConfigureAwait(false); - Debug.Assert(++currentIndex == pipeline.Count); + Debug.Assert(++currentIndex == pipeline.Count, "Transport is not at last position in pipeline."); } + + #endregion } diff --git a/sdk/core/System.ClientModel/src/Pipeline/ResponseBufferingPolicy.cs b/sdk/core/System.ClientModel/src/Pipeline/ResponseBufferingPolicy.cs deleted file mode 100644 index 9bec762e3cd22..0000000000000 --- a/sdk/core/System.ClientModel/src/Pipeline/ResponseBufferingPolicy.cs +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.ClientModel.Internal; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -namespace System.ClientModel.Primitives; - -/// -/// Pipeline policy to buffer response content or add a timeout to response content -/// managed by the client. -/// -public class ResponseBufferingPolicy : PipelinePolicy -{ - internal static readonly ResponseBufferingPolicy Default = new(); - - public ResponseBufferingPolicy() - { - } - - public sealed override void Process(PipelineMessage message, IReadOnlyList pipeline, int currentIndex) - => ProcessSyncOrAsync(message, pipeline, currentIndex, async: false).EnsureCompleted(); - - public sealed override async ValueTask ProcessAsync(PipelineMessage message, IReadOnlyList pipeline, int currentIndex) - => await ProcessSyncOrAsync(message, pipeline, currentIndex, async: true).ConfigureAwait(false); - - private async ValueTask ProcessSyncOrAsync(PipelineMessage message, IReadOnlyList pipeline, int currentIndex, bool async) - { - Debug.Assert(message.NetworkTimeout is not null); - - TimeSpan invocationNetworkTimeout = (TimeSpan)message.NetworkTimeout!; - - CancellationToken oldToken = message.CancellationToken; - using CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(oldToken); - cts.CancelAfter(invocationNetworkTimeout); - - try - { - message.CancellationToken = cts.Token; - if (async) - { - await ProcessNextAsync(message, pipeline, currentIndex).ConfigureAwait(false); - } - else - { - ProcessNext(message, pipeline, currentIndex); - } - } - catch (OperationCanceledException ex) - { - ThrowIfCancellationRequestedOrTimeout(oldToken, cts.Token, ex, invocationNetworkTimeout); - throw; - } - finally - { - message.CancellationToken = oldToken; - cts.CancelAfter(Timeout.Infinite); - } - - message.AssertResponse(); - message.Response!.NetworkTimeout = invocationNetworkTimeout; - - Stream? responseContentStream = message.Response!.ContentStream; - if (responseContentStream is null || - message.Response.TryGetBufferedContent(out var _)) - { - // There is either no content on the response, or the content has already - // been buffered. - return; - } - - if (!message.BufferResponse) - { - // Client has requested not to buffer the message response content. - // If applicable, wrap it in a read-timeout stream. - if (invocationNetworkTimeout != Timeout.InfiniteTimeSpan) - { - message.Response.ContentStream = new ReadTimeoutStream(responseContentStream, invocationNetworkTimeout); - } - - return; - } - - // If we got this far, buffer the response. - - // If cancellation is possible (whether due to network timeout or a user cancellation token being passed), then - // register callback to dispose the stream on cancellation. - if (invocationNetworkTimeout != Timeout.InfiniteTimeSpan || oldToken.CanBeCanceled) - { - cts.Token.Register(state => ((Stream?)state)?.Dispose(), responseContentStream); - } - - try - { - if (async) - { - await message.Response.BufferContentAsync(invocationNetworkTimeout, cts).ConfigureAwait(false); - } - else - { - message.Response.BufferContent(invocationNetworkTimeout, cts); - } - } - // We dispose stream on timeout or user cancellation so catch and check if cancellation token was cancelled - catch (Exception ex) - when (ex is ObjectDisposedException - or IOException - or OperationCanceledException - or NotSupportedException) - { - ThrowIfCancellationRequestedOrTimeout(oldToken, cts.Token, ex, invocationNetworkTimeout); - throw; - } - } - - /// Throws a cancellation exception if cancellation has been requested via or . - /// The customer provided token. - /// The linked token that is cancelled on timeout provided token. - /// The inner exception to use. - /// The timeout used for the operation. -#pragma warning disable CA1068 // Cancellation token has to be the last parameter - internal static void ThrowIfCancellationRequestedOrTimeout(CancellationToken originalToken, CancellationToken timeoutToken, Exception? inner, TimeSpan timeout) -#pragma warning restore CA1068 - { - CancellationHelper.ThrowIfCancellationRequested(originalToken); - - if (timeoutToken.IsCancellationRequested) - { - throw CancellationHelper.CreateOperationCanceledException( - inner, - timeoutToken, - $"The operation was cancelled because it exceeded the configured timeout of {timeout:g}. "); - } - } -} \ No newline at end of file diff --git a/sdk/core/System.ClientModel/tests/Convenience/ClientRequestExceptionTests.cs b/sdk/core/System.ClientModel/tests/Convenience/ClientRequestExceptionTests.cs index 173090c729d9b..3f538573846ee 100644 --- a/sdk/core/System.ClientModel/tests/Convenience/ClientRequestExceptionTests.cs +++ b/sdk/core/System.ClientModel/tests/Convenience/ClientRequestExceptionTests.cs @@ -5,6 +5,7 @@ using NUnit.Framework; using System.ClientModel.Primitives; using System.IO; +using System.Threading.Tasks; namespace System.ClientModel.Tests.Exceptions; @@ -24,6 +25,20 @@ public void CanCreateFromResponse() exception.Message); } + [Test] + public async Task CanCreateFromAsyncFactory() + { + PipelineResponse response = new MockPipelineResponse(200, "MockReason"); + + ClientResultException exception = await ClientResultException.CreateAsync(response); + + Assert.AreEqual(response.Status, exception.Status); + Assert.AreEqual(response, exception.GetRawResponse()); + Assert.AreEqual( + $"Service request failed.{Environment.NewLine}Status: 200 (MockReason){Environment.NewLine}", + exception.Message); + } + [Test] public void PassingMessageOverridesResponseMessage() { diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs index 1bc236899fcdd..e4163ffc85f59 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs @@ -5,6 +5,8 @@ using System.ClientModel.Primitives; using System.IO; using System.Text; +using System.Threading; +using System.Threading.Tasks; namespace ClientModel.Tests.Mocks; @@ -13,6 +15,7 @@ public class MockPipelineResponse : PipelineResponse private int _status; private string _reasonPhrase; private Stream? _contentStream; + private BinaryData? _bufferedContent; private readonly PipelineResponseHeaders _headers; @@ -50,6 +53,31 @@ public override Stream? ContentStream set => _contentStream = value; } + public override BinaryData Content + { + get + { + if (_contentStream is null) + { + return new BinaryData(Array.Empty()); + } + + if (ContentStream is not MemoryStream memoryContent) + { + throw new InvalidOperationException($"The response is not buffered."); + } + + if (memoryContent.TryGetBuffer(out ArraySegment segment)) + { + return new BinaryData(segment.AsMemory()); + } + else + { + return new BinaryData(memoryContent.ToArray()); + } + } + } + protected override PipelineResponseHeaders GetHeadersCore() => _headers; public sealed override void Dispose() @@ -63,7 +91,7 @@ protected void Dispose(bool disposing) { if (disposing && !_disposed) { - var content = _contentStream; + Stream? content = _contentStream; if (content != null) { _contentStream = null; @@ -73,4 +101,59 @@ protected void Dispose(bool disposing) _disposed = true; } } + + public override BinaryData ReadContent(CancellationToken cancellationToken = default) + { + if (_bufferedContent is not null) + { + return _bufferedContent; + } + + if (_contentStream is null) + { + _bufferedContent = new BinaryData(Array.Empty()); + return _bufferedContent; + } + + MemoryStream bufferStream = new(); + _contentStream.CopyTo(bufferStream); + _contentStream.Dispose(); + _contentStream = bufferStream; + + // Less efficient FromStream method called here because it is a mock. + // For intended production implementation, see HttpClientTransportResponse. + _bufferedContent = BinaryData.FromStream(bufferStream); + return _bufferedContent; + } + + public override async ValueTask ReadContentAsync(CancellationToken cancellationToken = default) + { + if (_bufferedContent is not null) + { + return _bufferedContent; + } + + if (_contentStream is null) + { + _bufferedContent = new BinaryData(Array.Empty()); + return _bufferedContent; + } + + MemoryStream bufferStream = new(); + +#if NETSTANDARD2_0 || NETFRAMEWORK + await _contentStream.CopyToAsync(bufferStream).ConfigureAwait(false); + _contentStream.Dispose(); +#else + await _contentStream.CopyToAsync(bufferStream, cancellationToken).ConfigureAwait(false); + await _contentStream.DisposeAsync().ConfigureAwait(false); +#endif + + _contentStream = bufferStream; + + // Less efficient FromStream method called here because it is a mock. + // For intended production implementation, see HttpClientTransportResponse. + _bufferedContent = BinaryData.FromStream(bufferStream); + return _bufferedContent; + } } diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineTransport.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineTransport.cs index 17cb7968dd7d5..c811ae1304462 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineTransport.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineTransport.cs @@ -6,6 +6,7 @@ using System.ClientModel.Primitives; using System.Collections.Generic; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace ClientModel.Tests.Mocks; @@ -176,11 +177,23 @@ public override Stream? ContentStream set => throw new NotImplementedException(); } + public override BinaryData Content => throw new NotImplementedException(); + protected override PipelineResponseHeaders GetHeadersCore() { throw new NotImplementedException(); } public override void Dispose() { } + + public override BinaryData ReadContent(CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override ValueTask ReadContentAsync(CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } } } diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/ObservableTransport.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/ObservableTransport.cs index 7bb84c2bf79f8..b9cc98db95d90 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/ObservableTransport.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/ObservableTransport.cs @@ -6,6 +6,7 @@ using System.ClientModel.Primitives; using System.Collections.Generic; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace ClientModel.Tests.Mocks; @@ -137,6 +138,8 @@ public override Stream? ContentStream set => throw new NotImplementedException(); } + public override BinaryData Content => throw new NotImplementedException(); + protected override PipelineResponseHeaders GetHeadersCore() { throw new NotImplementedException(); @@ -146,5 +149,15 @@ public override void Dispose() { throw new NotImplementedException(); } + + public override BinaryData ReadContent(CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override ValueTask ReadContentAsync(CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } } }