Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ClientModel: Move buffering into the transport #41772

Merged
merged 12 commits into from
Feb 13, 2024
11 changes: 4 additions & 7 deletions sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public ClientResultException(System.ClientModel.Primitives.PipelineResponse resp
protected ClientResultException(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { }
public ClientResultException(string message, System.ClientModel.Primitives.PipelineResponse? response = null, System.Exception? innerException = null) { }
public int Status { get { throw null; } protected set { } }
public static System.Threading.Tasks.Task<System.ClientModel.ClientResultException> CreateAsync(System.ClientModel.Primitives.PipelineResponse response, System.Exception? innerException = null) { throw null; }
public override void GetObjectData(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { }
public System.ClientModel.Primitives.PipelineResponse? GetRawResponse() { throw null; }
}
Expand Down Expand Up @@ -205,14 +206,16 @@ protected PipelineRequestHeaders() { }
public abstract partial class PipelineResponse : System.IDisposable
{
protected PipelineResponse() { }
public virtual System.BinaryData Content { get { throw null; } }
public abstract System.BinaryData Content { get; }
public abstract System.IO.Stream? ContentStream { get; set; }
public System.ClientModel.Primitives.PipelineResponseHeaders Headers { get { throw null; } }
public virtual bool IsError { get { throw null; } }
public abstract string ReasonPhrase { get; }
public abstract int Status { get; }
public abstract void Dispose();
protected abstract System.ClientModel.Primitives.PipelineResponseHeaders GetHeadersCore();
public abstract System.BinaryData ReadContent(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken));
public abstract System.Threading.Tasks.ValueTask<System.BinaryData> ReadContentAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken));
protected virtual void SetIsErrorCore(bool isError) { }
}
public abstract partial class PipelineResponseHeaders : System.Collections.Generic.IEnumerable<System.Collections.Generic.KeyValuePair<string, string>>, System.Collections.IEnumerable
Expand Down Expand Up @@ -246,10 +249,4 @@ protected void AssertNotFrozen() { }
public virtual void Freeze() { }
public void SetHeader(string name, string value) { }
}
public partial class ResponseBufferingPolicy : System.ClientModel.Primitives.PipelinePolicy
{
public ResponseBufferingPolicy() { }
public sealed override void Process(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList<System.ClientModel.Primitives.PipelinePolicy> pipeline, int currentIndex) { }
public sealed override System.Threading.Tasks.ValueTask ProcessAsync(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList<System.ClientModel.Primitives.PipelinePolicy> pipeline, int currentIndex) { throw null; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public ClientResultException(System.ClientModel.Primitives.PipelineResponse resp
protected ClientResultException(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { }
public ClientResultException(string message, System.ClientModel.Primitives.PipelineResponse? response = null, System.Exception? innerException = null) { }
public int Status { get { throw null; } protected set { } }
public static System.Threading.Tasks.Task<System.ClientModel.ClientResultException> CreateAsync(System.ClientModel.Primitives.PipelineResponse response, System.Exception? innerException = null) { throw null; }
public override void GetObjectData(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { }
public System.ClientModel.Primitives.PipelineResponse? GetRawResponse() { throw null; }
}
Expand Down Expand Up @@ -204,14 +205,16 @@ protected PipelineRequestHeaders() { }
public abstract partial class PipelineResponse : System.IDisposable
{
protected PipelineResponse() { }
public virtual System.BinaryData Content { get { throw null; } }
public abstract System.BinaryData Content { get; }
public abstract System.IO.Stream? ContentStream { get; set; }
public System.ClientModel.Primitives.PipelineResponseHeaders Headers { get { throw null; } }
public virtual bool IsError { get { throw null; } }
public abstract string ReasonPhrase { get; }
public abstract int Status { get; }
public abstract void Dispose();
protected abstract System.ClientModel.Primitives.PipelineResponseHeaders GetHeadersCore();
public abstract System.BinaryData ReadContent(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken));
public abstract System.Threading.Tasks.ValueTask<System.BinaryData> ReadContentAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken));
protected virtual void SetIsErrorCore(bool isError) { }
}
public abstract partial class PipelineResponseHeaders : System.Collections.Generic.IEnumerable<System.Collections.Generic.KeyValuePair<string, string>>, System.Collections.IEnumerable
Expand Down Expand Up @@ -245,10 +248,4 @@ protected void AssertNotFrozen() { }
public virtual void Freeze() { }
public void SetHeader(string name, string value) { }
}
public partial class ResponseBufferingPolicy : System.ClientModel.Primitives.PipelinePolicy
{
public ResponseBufferingPolicy() { }
public sealed override void Process(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList<System.ClientModel.Primitives.PipelinePolicy> pipeline, int currentIndex) { }
public sealed override System.Threading.Tasks.ValueTask ProcessAsync(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList<System.ClientModel.Primitives.PipelinePolicy> pipeline, int currentIndex) { throw null; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Globalization;
using System.Runtime.Serialization;
using System.Text;
using System.Threading.Tasks;

namespace System.ClientModel;

Expand All @@ -17,6 +18,12 @@ public class ClientResultException : Exception, ISerializable
private readonly PipelineResponse? _response;
private int _status;

public static async Task<ClientResultException> CreateAsync(PipelineResponse response, Exception? innerException = default)
{
string message = await CreateMessageAsync(response).ConfigureAwait(false);
return new ClientResultException(message, response, innerException);
}

/// <summary>
/// Gets the HTTP status code of the response. Returns. <code>0</code> if response was not received.
/// </summary>
Expand Down Expand Up @@ -66,8 +73,21 @@ public override void GetObjectData(SerializationInfo info, StreamingContext cont
public PipelineResponse? GetRawResponse() => _response;

private static string CreateMessage(PipelineResponse response)
=> CreateMessageSyncOrAsync(response, async: false).EnsureCompleted();

private static async ValueTask<string> CreateMessageAsync(PipelineResponse response)
=> await CreateMessageSyncOrAsync(response, async: true).ConfigureAwait(false);

private static async ValueTask<string> CreateMessageSyncOrAsync(PipelineResponse response, bool async)
{
response.BufferContent();
if (async)
{
await response.ReadContentAsync().ConfigureAwait(false);
}
else
{
response.ReadContent();
}

StringBuilder messageBuilder = new();

Expand Down
27 changes: 24 additions & 3 deletions sdk/core/System.ClientModel/src/Internal/CancellationHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,32 @@ private static void ThrowOperationCanceledException(Exception? innerException, C

/// <summary>Throws a cancellation exception if cancellation has been requested via <paramref name="cancellationToken"/>.</summary>
/// <param name="cancellationToken">The token to check for a cancellation request.</param>
internal static void ThrowIfCancellationRequested(CancellationToken cancellationToken)
/// <param name="innerException">The inner exception to wrap. May be null.</param>
internal static void ThrowIfCancellationRequested(CancellationToken cancellationToken, Exception? innerException = default)
{
if (cancellationToken.IsCancellationRequested)
{
ThrowOperationCanceledException(innerException: null, cancellationToken);
ThrowOperationCanceledException(innerException, cancellationToken);
}
}

/// <summary>Throws a cancellation exception if cancellation has been requested via <paramref name="messageToken"/> or <paramref name="timeoutToken"/>.</summary>
/// <param name="messageToken">The user-provided token.</param>
/// <param name="timeoutToken">The linked token that is cancelled on timeout provided token.</param>
/// <param name="innerException">The inner exception to use.</param>
/// <param name="timeout">The timeout used for the operation.</param>
#pragma warning disable CA1068 // Cancellation token has to be the last parameter
internal static void ThrowIfCancellationRequestedOrTimeout(CancellationToken messageToken, CancellationToken timeoutToken, Exception? innerException, TimeSpan timeout)
#pragma warning restore CA1068
{
ThrowIfCancellationRequested(messageToken, innerException);

if (timeoutToken.IsCancellationRequested)
{
throw CreateOperationCanceledException(
innerException,
timeoutToken,
$"The operation was cancelled because it exceeded the configured timeout of {timeout:g}. ");
}
}
}
}
17 changes: 8 additions & 9 deletions sdk/core/System.ClientModel/src/Internal/ReadTimeoutStream.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.ClientModel.Primitives;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -60,26 +59,26 @@ public override void Flush()

public override int Read(byte[] buffer, int offset, int count)
{
var source = StartTimeout(default, out bool dispose);
CancellationTokenSource source = StartTimeout(default, out bool dispose);
annelo-msft marked this conversation as resolved.
Show resolved Hide resolved
try
{
return _stream.Read(buffer, offset, count);
}
// We dispose stream on timeout so catch and check if cancellation token was cancelled
catch (IOException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
throw;
}
// We dispose stream on timeout so catch and check if cancellation token was cancelled
catch (ObjectDisposedException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
throw;
}
catch (OperationCanceledException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
throw;
}
finally
Expand All @@ -90,7 +89,7 @@ public override int Read(byte[] buffer, int offset, int count)

public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
var source = StartTimeout(cancellationToken, out bool dispose);
CancellationTokenSource source = StartTimeout(cancellationToken, out bool dispose);
annelo-msft marked this conversation as resolved.
Show resolved Hide resolved
try
{
#pragma warning disable CA1835 // ReadAsync(Memory<>) overload is not available in all targets
Expand All @@ -100,18 +99,18 @@ public override async Task<int> ReadAsync(byte[] buffer, int offset, int count,
// We dispose stream on timeout so catch and check if cancellation token was cancelled
catch (IOException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
throw;
}
// We dispose stream on timeout so catch and check if cancellation token was cancelled
catch (ObjectDisposedException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
throw;
}
catch (OperationCanceledException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
throw;
}
finally
Expand Down
44 changes: 44 additions & 0 deletions sdk/core/System.ClientModel/src/Internal/StreamExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ namespace System.ClientModel.Internal;

internal static class StreamExtensions
{
// Same value as Stream.CopyTo uses by default
private const int DefaultCopyBufferSize = 81920;

public static async Task WriteAsync(this Stream stream, ReadOnlyMemory<byte> buffer, CancellationToken cancellation = default)
{
Argument.AssertNotNull(stream, nameof(stream));
Expand Down Expand Up @@ -86,4 +89,45 @@ public static async Task WriteAsync(this Stream stream, ReadOnlySequence<byte> b
ArrayPool<byte>.Shared.Return(array);
}
}

public static async Task CopyToAsync(this Stream source, Stream destination, CancellationToken cancellationToken)
KrzysztofCwalina marked this conversation as resolved.
Show resolved Hide resolved
{
byte[] buffer = ArrayPool<byte>.Shared.Rent(DefaultCopyBufferSize);

try
{
while (true)
{
#pragma warning disable CA1835 // ReadAsync(Memory<>) overload is not available in all targets
int bytesRead = await source.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false);
#pragma warning restore // ReadAsync(Memory<>) overload is not available in all targets
if (bytesRead == 0)
break;
await destination.WriteAsync(new ReadOnlyMemory<byte>(buffer, 0, bytesRead), cancellationToken).ConfigureAwait(false);
}
}
finally
{
ArrayPool<byte>.Shared.Return(buffer);
}
}

public static void CopyTo(this Stream source, Stream destination, CancellationToken cancellationToken)
{
byte[] buffer = ArrayPool<byte>.Shared.Rent(DefaultCopyBufferSize);

try
{
int read;
while ((read = source.Read(buffer, 0, buffer.Length)) != 0)
{
cancellationToken.ThrowIfCancellationRequested();
destination.Write(buffer, 0, read);
}
}
finally
{
ArrayPool<byte>.Shared.Return(buffer);
}
}
}
Loading