Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ClientModel: Move buffering into the transport #41772

Merged
merged 12 commits into from
Feb 13, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,4 @@ protected void AssertNotFrozen() { }
public virtual void Freeze() { }
public void SetHeader(string name, string value) { }
}
public partial class ResponseBufferingPolicy : System.ClientModel.Primitives.PipelinePolicy
{
public ResponseBufferingPolicy() { }
public sealed override void Process(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList<System.ClientModel.Primitives.PipelinePolicy> pipeline, int currentIndex) { }
public sealed override System.Threading.Tasks.ValueTask ProcessAsync(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList<System.ClientModel.Primitives.PipelinePolicy> pipeline, int currentIndex) { throw null; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,4 @@ protected void AssertNotFrozen() { }
public virtual void Freeze() { }
public void SetHeader(string name, string value) { }
}
public partial class ResponseBufferingPolicy : System.ClientModel.Primitives.PipelinePolicy
{
public ResponseBufferingPolicy() { }
public sealed override void Process(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList<System.ClientModel.Primitives.PipelinePolicy> pipeline, int currentIndex) { }
public sealed override System.Threading.Tasks.ValueTask ProcessAsync(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList<System.ClientModel.Primitives.PipelinePolicy> pipeline, int currentIndex) { throw null; }
}
}
22 changes: 21 additions & 1 deletion sdk/core/System.ClientModel/src/Internal/CancellationHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,24 @@ internal static void ThrowIfCancellationRequested(CancellationToken cancellation
ThrowOperationCanceledException(innerException: null, cancellationToken);
}
}
}

/// <summary>Throws a cancellation exception if cancellation has been requested via <paramref name="messageToken"/> or <paramref name="timeoutToken"/>.</summary>
/// <param name="messageToken">The user-provided token.</param>
/// <param name="timeoutToken">The linked token that is cancelled on timeout provided token.</param>
/// <param name="innerException">The inner exception to use.</param>
/// <param name="timeout">The timeout used for the operation.</param>
#pragma warning disable CA1068 // Cancellation token has to be the last parameter
internal static void ThrowIfCancellationRequestedOrTimeout(CancellationToken messageToken, CancellationToken timeoutToken, Exception? innerException, TimeSpan timeout)
#pragma warning restore CA1068
{
annelo-msft marked this conversation as resolved.
Show resolved Hide resolved
ThrowIfCancellationRequested(messageToken);

if (timeoutToken.IsCancellationRequested)
{
throw CreateOperationCanceledException(
innerException,
timeoutToken,
$"The operation was cancelled because it exceeded the configured timeout of {timeout:g}. ");
}
}
}
17 changes: 8 additions & 9 deletions sdk/core/System.ClientModel/src/Internal/ReadTimeoutStream.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.ClientModel.Primitives;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -60,26 +59,26 @@ public override void Flush()

public override int Read(byte[] buffer, int offset, int count)
{
var source = StartTimeout(default, out bool dispose);
CancellationTokenSource source = StartTimeout(default, out bool dispose);
annelo-msft marked this conversation as resolved.
Show resolved Hide resolved
try
{
return _stream.Read(buffer, offset, count);
}
// We dispose stream on timeout so catch and check if cancellation token was cancelled
catch (IOException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
throw;
}
// We dispose stream on timeout so catch and check if cancellation token was cancelled
catch (ObjectDisposedException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
throw;
}
catch (OperationCanceledException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(default, source.Token, ex, _readTimeout);
throw;
}
finally
Expand All @@ -90,7 +89,7 @@ public override int Read(byte[] buffer, int offset, int count)

public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
var source = StartTimeout(cancellationToken, out bool dispose);
CancellationTokenSource source = StartTimeout(cancellationToken, out bool dispose);
annelo-msft marked this conversation as resolved.
Show resolved Hide resolved
try
{
#pragma warning disable CA1835 // ReadAsync(Memory<>) overload is not available in all targets
Expand All @@ -100,18 +99,18 @@ public override async Task<int> ReadAsync(byte[] buffer, int offset, int count,
// We dispose stream on timeout so catch and check if cancellation token was cancelled
catch (IOException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
throw;
}
// We dispose stream on timeout so catch and check if cancellation token was cancelled
catch (ObjectDisposedException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
throw;
}
catch (OperationCanceledException ex)
{
ResponseBufferingPolicy.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(cancellationToken, source.Token, ex, _readTimeout);
throw;
}
finally
Expand Down
1 change: 1 addition & 0 deletions sdk/core/System.ClientModel/src/Message/PipelineMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ protected internal PipelineMessage(PipelineRequest request)

BufferResponse = true;
ResponseClassifier = PipelineMessageClassifier.Default;
NetworkTimeout = ClientPipeline.DefaultNetworkTimeout;
}

public PipelineRequest Request { get; }
Expand Down
85 changes: 72 additions & 13 deletions sdk/core/System.ClientModel/src/Message/PipelineResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public virtual BinaryData Content

protected virtual void SetIsErrorCore(bool isError) => _isError = isError;

internal TimeSpan NetworkTimeout { get; set; } = ClientPipeline.DefaultNetworkTimeout;
internal TimeSpan NetworkTimeout { get; set; }

public abstract void Dispose();

Expand All @@ -94,23 +94,72 @@ internal bool TryGetBufferedContent(out MemoryStream bufferedContent)
return false;
}

internal void BufferContent(TimeSpan? timeout = default, CancellationTokenSource? cts = default)
internal void ProcessContent(bool bufferResponse, CancellationToken userToken, CancellationTokenSource joinedTokenSource)
annelo-msft marked this conversation as resolved.
Show resolved Hide resolved
=> ProcessContentSyncOrAsync(bufferResponse, userToken, joinedTokenSource, async: false).EnsureCompleted();

internal async Task ProcessContentAsync(bool bufferResponse, CancellationToken userToken, CancellationTokenSource joinedTokenSource)
=> await ProcessContentSyncOrAsync(bufferResponse, userToken, joinedTokenSource, async: true).ConfigureAwait(false);

internal async Task ProcessContentSyncOrAsync(bool bufferResponse, CancellationToken userToken, CancellationTokenSource joinedTokenSource, bool async)
{
Stream? responseContentStream = ContentStream;
if (responseContentStream == null || TryGetBufferedContent(out _))
if (ContentStream is null)
{
// No need to buffer content.
return;
}

MemoryStream bufferStream = new();
CopyTo(responseContentStream, bufferStream, timeout ?? NetworkTimeout, cts ?? new CancellationTokenSource());
responseContentStream.Dispose();
bufferStream.Position = 0;
ContentStream = bufferStream;
if (!bufferResponse)
{
// Don't buffer the response content, e.g. in order to return the
// network stream to the end user of a client as part of a streaming
// API. In this case, we wrap the content stream in a read-timeout
// stream, to respect the client's network timeout setting.
if (NetworkTimeout != Timeout.InfiniteTimeSpan)
{
ContentStream = new ReadTimeoutStream(ContentStream!, NetworkTimeout);
}

return;
}

// If cancellation is possible (whether due to network timeout or a user
// cancellation token being passed), then register a callback to dispose
// the stream on cancellation.
if (NetworkTimeout != Timeout.InfiniteTimeSpan || userToken.CanBeCanceled)
{
joinedTokenSource.Token.Register(state => ((Stream?)state)?.Dispose(), ContentStream);
}

try
{
if (async)
{
await BufferContentAsync(joinedTokenSource).ConfigureAwait(false);
}
else
{
BufferContent(joinedTokenSource);
}
}
// We dispose the stream on timeout or user cancellation so catch and
// check if cancellation token was cancelled.
catch (Exception ex)
when (ex is ObjectDisposedException
or IOException
or OperationCanceledException
or NotSupportedException)
{
CancellationHelper.ThrowIfCancellationRequestedOrTimeout(userToken, joinedTokenSource.Token, ex, NetworkTimeout);
throw;
}
}

internal async Task BufferContentAsync(TimeSpan? timeout = default, CancellationTokenSource? cts = default)
internal void BufferContent(CancellationTokenSource? cts = default)
=> BufferContentSyncOrAsync(cts, async: false).EnsureCompleted();

internal async Task BufferContentAsync(CancellationTokenSource? cts = default)
=> await BufferContentSyncOrAsync(cts, async: true).ConfigureAwait(false);

private async Task BufferContentSyncOrAsync(CancellationTokenSource? cts, bool async)
{
Stream? responseContentStream = ContentStream;
if (responseContentStream == null || TryGetBufferedContent(out _))
Expand All @@ -120,7 +169,16 @@ internal async Task BufferContentAsync(TimeSpan? timeout = default, Cancellation
}

MemoryStream bufferStream = new();
await CopyToAsync(responseContentStream, bufferStream, timeout ?? NetworkTimeout, cts ?? new CancellationTokenSource()).ConfigureAwait(false);

if (async)
{
await CopyToAsync(responseContentStream, bufferStream, NetworkTimeout, cts ?? new CancellationTokenSource()).ConfigureAwait(false);
}
else
{
CopyTo(responseContentStream, bufferStream, NetworkTimeout, cts ?? new CancellationTokenSource());
}

responseContentStream.Dispose();
bufferStream.Position = 0;
ContentStream = bufferStream;
Expand All @@ -137,7 +195,8 @@ private static async Task CopyToAsync(Stream source, Stream destination, TimeSpa
#pragma warning disable CA1835 // ReadAsync(Memory<>) overload is not available in all targets
int bytesRead = await source.ReadAsync(buffer, 0, buffer.Length, cancellationTokenSource.Token).ConfigureAwait(false);
#pragma warning restore // ReadAsync(Memory<>) overload is not available in all targets
if (bytesRead == 0) break;
if (bytesRead == 0)
break;
await destination.WriteAsync(new ReadOnlyMemory<byte>(buffer, 0, bytesRead), cancellationTokenSource.Token).ConfigureAwait(false);
}
}
Expand Down
4 changes: 0 additions & 4 deletions sdk/core/System.ClientModel/src/Pipeline/ClientPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ public static ClientPipeline Create(
pipelineLength += options.BeforeTransportPolicies?.Length ?? 0;

pipelineLength++; // for retry policy
pipelineLength++; // for response buffering policy
pipelineLength++; // for transport

PipelinePolicy[] policies = new PipelinePolicy[pipelineLength];
Expand Down Expand Up @@ -103,9 +102,6 @@ public static ClientPipeline Create(

int perTryIndex = index;

// Response buffering comes before the transport.
policies[index++] = ResponseBufferingPolicy.Default;

// Before transport policies come before the transport.
beforeTransportPolicies.CopyTo(policies.AsSpan(index));
index += beforeTransportPolicies.Length;
Expand Down
Loading
Loading