From 83873e621af336d24f4e2a3a3138bdd88904e374 Mon Sep 17 00:00:00 2001 From: Anne Thompson Date: Thu, 25 Jan 2024 15:42:06 -0800 Subject: [PATCH 1/6] Add IsRetriable to MessageClassifier --- .../api/System.ClientModel.net6.0.cs | 13 +- .../api/System.ClientModel.netstandard2.0.cs | 13 +- .../src/Internal/ChainingClassifier.cs | 113 ++++++++++ .../src/Message/PipelineMessage.cs | 8 - .../Options/MessageClassificationHandler.cs | 35 ++++ .../src/Options/PipelineMessageClassifier.cs | 44 ++++ .../src/Options/RequestOptions.cs | 58 +++++- .../src/Options/ResponseStatusClassifier.cs | 64 +++--- .../tests/Message/PipelineMessageTests.cs | 4 +- .../tests/Options/RequestOptionsTests.cs | 16 +- .../Pipeline/ClientPipelineFunctionalTests.cs | 2 +- .../tests/Pipeline/ClientPipelineTests.cs | 2 +- .../Mocks/MockPipelineMessage.cs | 7 +- .../Mocks/MockPipelineResponse.cs | 5 +- .../Mocks/MockResponseHeaders.cs | 7 +- .../Options/ChainingClassifierTests.cs | 196 ++++++++++++++++++ 16 files changed, 526 insertions(+), 61 deletions(-) create mode 100644 sdk/core/System.ClientModel/src/Internal/ChainingClassifier.cs create mode 100644 sdk/core/System.ClientModel/src/Options/MessageClassificationHandler.cs create mode 100644 sdk/core/System.ClientModel/tests/internal/Options/ChainingClassifierTests.cs diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs index e21f7d222a0fc..92f8f4af64723 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs @@ -121,6 +121,12 @@ public partial interface IPersistableModel string GetFormatFromOptions(System.ClientModel.Primitives.ModelReaderWriterOptions options); System.BinaryData Write(System.ClientModel.Primitives.ModelReaderWriterOptions options); } + public partial class MessageClassificationHandler + { + public MessageClassificationHandler() { } + public virtual bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, out bool isError) { throw null; } + public virtual bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, System.Exception? exception, out bool isRetriable) { throw null; } + } public static partial class ModelReaderWriter { public static object? Read(System.BinaryData data, [System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembersAttribute(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.NonPublicConstructors | System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)] System.Type returnType, System.ClientModel.Primitives.ModelReaderWriterOptions? options = null) { throw null; } @@ -151,7 +157,6 @@ protected internal PipelineMessage(System.ClientModel.Primitives.PipelineRequest public System.TimeSpan? NetworkTimeout { get { throw null; } set { } } public System.ClientModel.Primitives.PipelineRequest Request { get { throw null; } } public System.ClientModel.Primitives.PipelineResponse? Response { get { throw null; } protected internal set { } } - public void Apply(System.ClientModel.Primitives.RequestOptions options) { } public void Dispose() { } protected virtual void Dispose(bool disposing) { } public void SetProperty(System.Type type, object value) { } @@ -162,6 +167,9 @@ public partial class PipelineMessageClassifier protected internal PipelineMessageClassifier() { } public static System.ClientModel.Primitives.PipelineMessageClassifier Create(System.ReadOnlySpan successStatusCodes) { throw null; } public virtual bool IsErrorResponse(System.ClientModel.Primitives.PipelineMessage message) { throw null; } + public virtual bool IsRetriable(System.ClientModel.Primitives.PipelineMessage message, System.Exception exception) { throw null; } + public virtual bool IsRetriableException(System.Exception exception) { throw null; } + public virtual bool IsRetriableResponse(System.ClientModel.Primitives.PipelineMessage message) { throw null; } } public abstract partial class PipelinePolicy { @@ -242,8 +250,11 @@ public partial class RequestOptions public RequestOptions() { } public System.Threading.CancellationToken CancellationToken { get { throw null; } set { } } public System.ClientModel.Primitives.ClientErrorBehaviors ErrorOptions { get { throw null; } set { } } + public void AddClassifier(System.ClientModel.Primitives.MessageClassificationHandler classifier) { } + public void AddClassifier(int statusCode, bool isError) { } public void AddHeader(string name, string value) { } public void AddPolicy(System.ClientModel.Primitives.PipelinePolicy policy, System.ClientModel.Primitives.PipelinePosition position) { } + public virtual void Apply(System.ClientModel.Primitives.PipelineMessage message) { } protected void AssertNotFrozen() { } public virtual void Freeze() { } public void SetHeader(string name, string value) { } diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs index cf21ffff8255e..e736c59f0c442 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs @@ -121,6 +121,12 @@ public partial interface IPersistableModel string GetFormatFromOptions(System.ClientModel.Primitives.ModelReaderWriterOptions options); System.BinaryData Write(System.ClientModel.Primitives.ModelReaderWriterOptions options); } + public partial class MessageClassificationHandler + { + public MessageClassificationHandler() { } + public virtual bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, out bool isError) { throw null; } + public virtual bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, System.Exception? exception, out bool isRetriable) { throw null; } + } public static partial class ModelReaderWriter { public static object? Read(System.BinaryData data, System.Type returnType, System.ClientModel.Primitives.ModelReaderWriterOptions? options = null) { throw null; } @@ -150,7 +156,6 @@ protected internal PipelineMessage(System.ClientModel.Primitives.PipelineRequest public System.TimeSpan? NetworkTimeout { get { throw null; } set { } } public System.ClientModel.Primitives.PipelineRequest Request { get { throw null; } } public System.ClientModel.Primitives.PipelineResponse? Response { get { throw null; } protected internal set { } } - public void Apply(System.ClientModel.Primitives.RequestOptions options) { } public void Dispose() { } protected virtual void Dispose(bool disposing) { } public void SetProperty(System.Type type, object value) { } @@ -161,6 +166,9 @@ public partial class PipelineMessageClassifier protected internal PipelineMessageClassifier() { } public static System.ClientModel.Primitives.PipelineMessageClassifier Create(System.ReadOnlySpan successStatusCodes) { throw null; } public virtual bool IsErrorResponse(System.ClientModel.Primitives.PipelineMessage message) { throw null; } + public virtual bool IsRetriable(System.ClientModel.Primitives.PipelineMessage message, System.Exception exception) { throw null; } + public virtual bool IsRetriableException(System.Exception exception) { throw null; } + public virtual bool IsRetriableResponse(System.ClientModel.Primitives.PipelineMessage message) { throw null; } } public abstract partial class PipelinePolicy { @@ -241,8 +249,11 @@ public partial class RequestOptions public RequestOptions() { } public System.Threading.CancellationToken CancellationToken { get { throw null; } set { } } public System.ClientModel.Primitives.ClientErrorBehaviors ErrorOptions { get { throw null; } set { } } + public void AddClassifier(System.ClientModel.Primitives.MessageClassificationHandler classifier) { } + public void AddClassifier(int statusCode, bool isError) { } public void AddHeader(string name, string value) { } public void AddPolicy(System.ClientModel.Primitives.PipelinePolicy policy, System.ClientModel.Primitives.PipelinePosition position) { } + public virtual void Apply(System.ClientModel.Primitives.PipelineMessage message) { } protected void AssertNotFrozen() { } public virtual void Freeze() { } public void SetHeader(string name, string value) { } diff --git a/sdk/core/System.ClientModel/src/Internal/ChainingClassifier.cs b/sdk/core/System.ClientModel/src/Internal/ChainingClassifier.cs new file mode 100644 index 0000000000000..a409c0c59a8ec --- /dev/null +++ b/sdk/core/System.ClientModel/src/Internal/ChainingClassifier.cs @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; + +namespace System.ClientModel.Internal; + +internal class ChainingClassifier : PipelineMessageClassifier +{ + private MessageClassificationHandler[]? _handlers; + private readonly PipelineMessageClassifier _endOfChain; + + public ChainingClassifier((int Status, bool IsError)[]? statusCodes, + MessageClassificationHandler[]? handlers, + PipelineMessageClassifier endOfChain) + { + if (handlers != null) + { + AddClassifiers(handlers); + } + + if (statusCodes != null) + { + StatusCodeHandler[] handler = { new StatusCodeHandler(statusCodes) }; + AddClassifiers(new ReadOnlySpan(handler)); + } + + _endOfChain = endOfChain; + } + + public override bool IsErrorResponse(PipelineMessage message) + { + if (_handlers != null) + { + foreach (var handler in _handlers) + { + if (handler.TryClassify(message, out bool isError)) + { + return isError; + } + } + } + + return _endOfChain.IsErrorResponse(message); + } + + public override bool IsRetriable(PipelineMessage message, Exception exception) + { + if (_handlers != null) + { + foreach (var handler in _handlers) + { + if (handler.TryClassify(message, exception, out bool isRetriable)) + { + return isRetriable; + } + } + } + + return _endOfChain.IsRetriable(message, exception); + } + + public override bool IsRetriableResponse(PipelineMessage message) + { + if (_handlers != null) + { + foreach (var handler in _handlers) + { + if (handler.TryClassify(message, default, out bool isRetriable)) + { + return isRetriable; + } + } + } + + return _endOfChain.IsRetriableResponse(message); + } + + private void AddClassifiers(ReadOnlySpan handlers) + { + int length = _handlers == null ? 0 : _handlers.Length; + Array.Resize(ref _handlers, length + handlers.Length); + Span target = new Span(_handlers, length, handlers.Length); + handlers.CopyTo(target); + } + + private class StatusCodeHandler : MessageClassificationHandler + { + private readonly (int Status, bool IsError)[] _statusCodes; + + public StatusCodeHandler((int Status, bool IsError)[] statusCodes) + { + _statusCodes = statusCodes; + } + + public override bool TryClassify(PipelineMessage message, out bool isError) + { + message.AssertResponse(); + + foreach (var classification in _statusCodes) + { + if (classification.Status == message.Response!.Status) + { + isError = classification.IsError; + return true; + } + } + + isError = false; + return false; + } + } +} diff --git a/sdk/core/System.ClientModel/src/Message/PipelineMessage.cs b/sdk/core/System.ClientModel/src/Message/PipelineMessage.cs index 736436adbee9a..dec5bbbf8c041 100644 --- a/sdk/core/System.ClientModel/src/Message/PipelineMessage.cs +++ b/sdk/core/System.ClientModel/src/Message/PipelineMessage.cs @@ -66,14 +66,6 @@ public CancellationToken CancellationToken // at the end of the chain. public PipelineMessageClassifier? MessageClassifier { get; set; } - public void Apply(RequestOptions options) - { - // This design moves the client-author API (options.Apply) off the - // client-user type RequestOptions. Its only purpose is to call through to - // the internal options.Apply method. - options.Apply(this); - } - public bool TryGetProperty(Type type, out object? value) => _propertyBag.TryGetValue((ulong)type.TypeHandle.Value, out value); diff --git a/sdk/core/System.ClientModel/src/Options/MessageClassificationHandler.cs b/sdk/core/System.ClientModel/src/Options/MessageClassificationHandler.cs new file mode 100644 index 0000000000000..b56f1e70258c8 --- /dev/null +++ b/sdk/core/System.ClientModel/src/Options/MessageClassificationHandler.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace System.ClientModel.Primitives; + +/// +/// A type that analyzes an HTTP message and determines if the response it holds +/// should be treated as an error response. A classifier of this type may use information +/// from the request, the response, or other message property to decide +/// whether and how to classify the message. +/// +/// This type's TryClassify method allows chaining together handlers before +/// applying default classifier logic. +/// If a handler in the chain returns false from TryClassify, +/// the next handler will be tried, and so on. The first handler that returns true +/// will determine whether the response is an error. +/// +public class MessageClassificationHandler +{ + public virtual bool TryClassify(PipelineMessage message, out bool isError) + { + isError = false; + + // Don't classify for errors unless overridden. + return false; + } + + public virtual bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable) + { + isRetriable = false; + + // Don't classify for retries unless overridden. + return false; + } +} diff --git a/sdk/core/System.ClientModel/src/Options/PipelineMessageClassifier.cs b/sdk/core/System.ClientModel/src/Options/PipelineMessageClassifier.cs index 6bbe849041063..8303f3af76a8d 100644 --- a/sdk/core/System.ClientModel/src/Options/PipelineMessageClassifier.cs +++ b/sdk/core/System.ClientModel/src/Options/PipelineMessageClassifier.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System.IO; + namespace System.ClientModel.Primitives; public class PipelineMessageClassifier @@ -22,4 +24,46 @@ public virtual bool IsErrorResponse(PipelineMessage message) int statusKind = message.Response!.Status / 100; return statusKind == 4 || statusKind == 5; } + + public virtual bool IsRetriableResponse(PipelineMessage message) + { + message.AssertResponse(); + + return message.Response!.Status switch + { + // Request Timeout + 408 => true, + + // Too Many Requests + 429 => true, + + // Internal Server Error + 500 => true, + + // Bad Gateway + 502 => true, + + // Service Unavailable + 503 => true, + + // Gateway Timeout + 504 => true, + + // Default case + _ => false + }; + } + + public virtual bool IsRetriableException(Exception exception) + { + return (exception is IOException) || + (exception is ClientResultException requestFailed && requestFailed.Status == 0); + } + + public virtual bool IsRetriable(PipelineMessage message, Exception exception) + { + return IsRetriableException(exception) || + // Retry non-user initiated cancellations + (exception is OperationCanceledException && !message.CancellationToken.IsCancellationRequested); + } } diff --git a/sdk/core/System.ClientModel/src/Options/RequestOptions.cs b/sdk/core/System.ClientModel/src/Options/RequestOptions.cs index 01ded73811162..364ca3c8f22a3 100644 --- a/sdk/core/System.ClientModel/src/Options/RequestOptions.cs +++ b/sdk/core/System.ClientModel/src/Options/RequestOptions.cs @@ -23,6 +23,12 @@ public class RequestOptions private PipelinePolicy[]? _perTryPolicies; private PipelinePolicy[]? _beforeTransportPolicies; + private (int Status, bool IsError)[]? _statusCodes; + internal (int Status, bool IsError)[]? StatusCodes => _statusCodes; + + private MessageClassificationHandler[]? _handlers; + internal MessageClassificationHandler[]? Handlers => _handlers; + private List? _headersUpdates; public RequestOptions() @@ -95,8 +101,30 @@ public void AddPolicy(PipelinePolicy policy, PipelinePosition position) } } + public void AddClassifier(int statusCode, bool isError) + { + Argument.AssertInRange(statusCode, 100, 599, nameof(statusCode)); + + AssertNotFrozen(); + + int length = _statusCodes == null ? 0 : _statusCodes.Length; + Array.Resize(ref _statusCodes, length + 1); + Array.Copy(_statusCodes, 0, _statusCodes, 1, length); + _statusCodes[0] = (statusCode, isError); + } + + public void AddClassifier(MessageClassificationHandler classifier) + { + AssertNotFrozen(); + + int length = _handlers == null ? 0 : _handlers.Length; + Array.Resize(ref _handlers, length + 1); + Array.Copy(_handlers, 0, _handlers, 1, length); + _handlers[0] = classifier; + } + // Set options on the message before sending it through the pipeline. - internal void Apply(PipelineMessage message) + public virtual void Apply(PipelineMessage message) { Freeze(); @@ -110,7 +138,7 @@ internal void Apply(PipelineMessage message) // This preserves any values set by the client author, and is also // needed for Azure.Core-based clients so we don't overwrite a default // Azure.Core ResponseClassifier. - message.MessageClassifier ??= PipelineMessageClassifier.Default; + message.MessageClassifier = ApplyClassifier(message.MessageClassifier ?? PipelineMessageClassifier.Default); // Copy custom pipeline policies to the message. message.PerCallPolicies = _perCallPolicies; @@ -138,6 +166,32 @@ internal void Apply(PipelineMessage message) } } + internal PipelineMessageClassifier ApplyClassifier(PipelineMessageClassifier classifier) + { + if (_statusCodes == null && _handlers == null) + { + return classifier; + } + + if (classifier is ResponseStatusClassifier statusCodeClassifier) + { + ResponseStatusClassifier clone = statusCodeClassifier.Clone(); + clone.Handlers = _handlers; + + if (_statusCodes != null) + { + foreach (var classification in _statusCodes) + { + clone.AddClassifier(classification.Status, classification.IsError); + } + } + + return clone; + } + + return new ChainingClassifier(_statusCodes, _handlers, classifier); + } + public virtual void Freeze() => _frozen = true; protected void AssertNotFrozen() diff --git a/sdk/core/System.ClientModel/src/Options/ResponseStatusClassifier.cs b/sdk/core/System.ClientModel/src/Options/ResponseStatusClassifier.cs index 44138c8d89fb4..8919b0bb9d088 100644 --- a/sdk/core/System.ClientModel/src/Options/ResponseStatusClassifier.cs +++ b/sdk/core/System.ClientModel/src/Options/ResponseStatusClassifier.cs @@ -3,39 +3,49 @@ using System.ClientModel.Internal; -namespace System.ClientModel.Primitives +namespace System.ClientModel.Primitives; + +internal class ResponseStatusClassifier : PipelineMessageClassifier { - internal class ResponseStatusClassifier : PipelineMessageClassifier + private BitVector640 _successCodes; + + internal MessageClassificationHandler[]? Handlers { get; set; } + + /// + /// Creates a new instance of . + /// + /// The status codes that this classifier + /// will consider not to be errors. + public ResponseStatusClassifier(ReadOnlySpan successStatusCodes) { - private BitVector640 _successCodes; - - /// - /// Creates a new instance of . - /// - /// The status codes that this classifier - /// will consider not to be errors. - public ResponseStatusClassifier(ReadOnlySpan successStatusCodes) - { - _successCodes = new(); + _successCodes = new(); - foreach (int statusCode in successStatusCodes) - { - AddClassifier(statusCode, isError: false); - } + foreach (int statusCode in successStatusCodes) + { + AddClassifier(statusCode, isError: false); } + } - public sealed override bool IsErrorResponse(PipelineMessage message) - { - message.AssertResponse(); + private ResponseStatusClassifier(BitVector640 successCodes, MessageClassificationHandler[]? handlers) + { + _successCodes = successCodes; + Handlers = handlers; + } - return !_successCodes[message.Response!.Status]; - } + public sealed override bool IsErrorResponse(PipelineMessage message) + { + message.AssertResponse(); - private void AddClassifier(int statusCode, bool isError) - { - Argument.AssertInRange(statusCode, 0, 639, nameof(statusCode)); + return !_successCodes[message.Response!.Status]; + } - _successCodes[statusCode] = !isError; - } + internal virtual ResponseStatusClassifier Clone() + => new(_successCodes, Handlers); + + internal void AddClassifier(int statusCode, bool isError) + { + Argument.AssertInRange(statusCode, 0, 639, nameof(statusCode)); + + _successCodes[statusCode] = !isError; } -} \ No newline at end of file +} diff --git a/sdk/core/System.ClientModel/tests/Message/PipelineMessageTests.cs b/sdk/core/System.ClientModel/tests/Message/PipelineMessageTests.cs index 28e632f32a321..734177c3ba2c2 100644 --- a/sdk/core/System.ClientModel/tests/Message/PipelineMessageTests.cs +++ b/sdk/core/System.ClientModel/tests/Message/PipelineMessageTests.cs @@ -17,7 +17,7 @@ public void ApplyAddsRequestHeaders() RequestOptions options = new RequestOptions(); options.AddHeader("MockHeader", "MockValue"); - message.Apply(options); + options.Apply(message); Assert.IsTrue(message.Request.Headers.TryGetValue("MockHeader", out string? value)); Assert.AreEqual("MockValue", value); @@ -34,7 +34,7 @@ public void ApplySetsCancellationToken() RequestOptions options = new RequestOptions(); options.CancellationToken = cts.Token; - message.Apply(options); + options.Apply(message); Assert.AreEqual(message.CancellationToken, cts.Token); Assert.IsFalse(message.CancellationToken.IsCancellationRequested); diff --git a/sdk/core/System.ClientModel/tests/Options/RequestOptionsTests.cs b/sdk/core/System.ClientModel/tests/Options/RequestOptionsTests.cs index b622b2b6f3ce6..a96a693e42620 100644 --- a/sdk/core/System.ClientModel/tests/Options/RequestOptionsTests.cs +++ b/sdk/core/System.ClientModel/tests/Options/RequestOptionsTests.cs @@ -18,7 +18,7 @@ public void CanAddRequestHeaders() RequestOptions options = new RequestOptions(); options.AddHeader("MockHeader", "MockValue"); - message.Apply(options); + options.Apply(message); Assert.IsTrue(message.Request.Headers.TryGetValue("MockHeader", out string? value)); Assert.AreEqual("MockValue", value); @@ -33,7 +33,7 @@ public void CanAddMultiValueRequestHeaders() RequestOptions options = new RequestOptions(); options.AddHeader("MockHeader", "MockValue1"); options.AddHeader("MockHeader", "MockValue2"); - message.Apply(options); + options.Apply(message); Assert.IsTrue(message.Request.Headers.TryGetValue("MockHeader", out string? value)); Assert.AreEqual("MockValue1,MockValue2", value); @@ -47,7 +47,7 @@ public void CanSetRequestHeaders() RequestOptions options = new RequestOptions(); options.SetHeader("MockHeader", "MockValue"); - message.Apply(options); + options.Apply(message); Assert.IsTrue(message.Request.Headers.TryGetValue("MockHeader", out string? value)); Assert.AreEqual("MockValue", value); @@ -62,7 +62,7 @@ public void SetReplacesHeaderAddedViaRequestOptions() RequestOptions options = new RequestOptions(); options.AddHeader("MockHeader", "MockValue1"); options.SetHeader("MockHeader", "MockValue2"); - message.Apply(options); + options.Apply(message); Assert.IsTrue(message.Request.Headers.TryGetValue("MockHeader", out string? value)); Assert.AreEqual("MockValue2", value); @@ -78,7 +78,7 @@ public void AddHeaderAddsValueToRequestMessageValue() RequestOptions options = new RequestOptions(); options.AddHeader("MockHeader", "RequestOptions Value"); - message.Apply(options); + options.Apply(message); Assert.IsTrue(message.Request.Headers.TryGetValue("MockHeader", out string? value)); Assert.AreEqual("Message.Request Value,RequestOptions Value", value); @@ -94,7 +94,7 @@ public void SetReplacesHeaderAddedViaMessageRequest() RequestOptions options = new RequestOptions(); options.SetHeader("MockHeader", "RequestOptions Value"); - message.Apply(options); + options.Apply(message); Assert.IsTrue(message.Request.Headers.TryGetValue("MockHeader", out string? value)); Assert.AreEqual("RequestOptions Value", value); @@ -120,7 +120,7 @@ public void CanInterleaveAddAndSetCalls() options.SetHeader("MockHeader2", "RequestOptions SetHeader Value 1"); options.AddHeader("MockHeader2", "RequestOptions AddHeader Value 2"); - message.Apply(options); + options.Apply(message); Assert.IsTrue(message.Request.Headers.TryGetValue("MockHeader1", out string? value1)); Assert.AreEqual("RequestOptions SetHeader Value 2", value1); @@ -139,7 +139,7 @@ public void CannotModifyOptionsAfterFrozen() PipelineMessage message = pipeline.CreateMessage(); RequestOptions options = new RequestOptions(); - message.Apply(options); + options.Apply(message); Assert.Throws(() => options.CancellationToken = CancellationToken.None); Assert.Throws(() => options.ErrorOptions = ClientErrorBehaviors.NoThrow); diff --git a/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineFunctionalTests.cs b/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineFunctionalTests.cs index 86e3bc1d831fe..66642472a073c 100644 --- a/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineFunctionalTests.cs +++ b/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineFunctionalTests.cs @@ -372,7 +372,7 @@ public async Task DoesntRetryClientCancellation() // Set CancellationToken on the message. RequestOptions options = new() { CancellationToken = cts.Token }; - message.Apply(options); + options.Apply(message); var task = Task.Run(() => pipeline.SendSyncOrAsync(message, IsAsync)); diff --git a/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineTests.cs b/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineTests.cs index 7ea521d2d3fb0..4c985c0aaf118 100644 --- a/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineTests.cs +++ b/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineTests.cs @@ -305,7 +305,7 @@ public async Task RequestOptionsCanCustomizePipeline() requestOptions.AddPolicy(new ObservablePolicy("B"), PipelinePosition.PerTry); PipelineMessage message = pipeline.CreateMessage(); - message.Apply(requestOptions); + requestOptions.Apply(message); await pipeline.SendSyncOrAsync(message, IsAsync); List observations = ObservablePolicy.GetData(message); diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineMessage.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineMessage.cs index 7d8430a4e016c..be1e8f288f9b9 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineMessage.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineMessage.cs @@ -1,12 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -using System; using System.ClientModel.Primitives; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; namespace ClientModel.Tests.Mocks; @@ -21,5 +16,5 @@ public MockPipelineMessage(PipelineRequest request) : base(request) } public void SetResponse(PipelineResponse response) - => this.Response = response; + => Response = response; } diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs index 1bc236899fcdd..6d8926db844ce 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs @@ -14,7 +14,7 @@ public class MockPipelineResponse : PipelineResponse private string _reasonPhrase; private Stream? _contentStream; - private readonly PipelineResponseHeaders _headers; + private readonly MockResponseHeaders _headers; private bool _disposed; @@ -52,6 +52,9 @@ public override Stream? ContentStream protected override PipelineResponseHeaders GetHeadersCore() => _headers; + public void AddHeader(string key, string value) + => _headers.AddHeader(key, value); + public sealed override void Dispose() { Dispose(true); diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockResponseHeaders.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockResponseHeaders.cs index 614997980beb4..41189df704a6c 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockResponseHeaders.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockResponseHeaders.cs @@ -16,15 +16,16 @@ public MockResponseHeaders() _headers = new Dictionary(); } + public void AddHeader(string key, string value) + => _headers.Add(key, value); + public override IEnumerator> GetEnumerator() { throw new NotImplementedException(); } public override bool TryGetValue(string name, out string? value) - { - return _headers.TryGetValue(name, out value); - } + => _headers.TryGetValue(name, out value); public override bool TryGetValues(string name, out IEnumerable? values) { diff --git a/sdk/core/System.ClientModel/tests/internal/Options/ChainingClassifierTests.cs b/sdk/core/System.ClientModel/tests/internal/Options/ChainingClassifierTests.cs new file mode 100644 index 0000000000000..08807dd182cc7 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/internal/Options/ChainingClassifierTests.cs @@ -0,0 +1,196 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Internal; +using System.ClientModel.Primitives; +using ClientModel.Tests.Mocks; +using NUnit.Framework; + +namespace System.ClientModel.Tests.Internal.Options; + +public class ChainingClassifierTests +{ + [Test] + public void ClassifiesUsingOnlyEndOfChain() + { + ChainingClassifier classifier = new ChainingClassifier( + statusCodes: null, + handlers: null, + HelperResponseClassifier.Instance); + + MockPipelineMessage message = new(); + + message.SetResponse(new MockPipelineResponse(204)); + Assert.IsFalse(classifier.IsErrorResponse(message)); + + message.SetResponse(new MockPipelineResponse(304)); + Assert.IsFalse(classifier.IsErrorResponse(message)); + + message.SetResponse(new MockPipelineResponse(404)); + Assert.IsTrue(classifier.IsErrorResponse(message)); + + message.SetResponse(new MockPipelineResponse(500)); + Assert.IsTrue(classifier.IsErrorResponse(message)); + } + + [Test] + public void ClassifiesUsingHandlersAndEndOfChain() + { + MessageClassificationHandler[] handlers = new MessageClassificationHandler[] + { + new HeaderClassificationHandler(204, "ErrorCode", "Error"), + new HeaderClassificationHandler(404, "ErrorCode", "NonError"), + }; + + ChainingClassifier classifier = new ChainingClassifier( + statusCodes: null, + handlers: handlers, + HelperResponseClassifier.Instance); + + MockPipelineMessage message = new(); + + var response = new MockPipelineResponse(204); + response.AddHeader("ErrorCode", "Error"); + message.Response = response; + Assert.IsTrue(classifier.IsErrorResponse(message)); + + response = new MockPipelineResponse(304); + response.AddHeader("ErrorCode", "Error"); + message.Response = response; + Assert.IsFalse(classifier.IsErrorResponse(message)); + + response = new MockPipelineResponse(404); + response.AddHeader("ErrorCode", "Error"); + message.Response = response; + Assert.IsFalse(classifier.IsErrorResponse(message)); + + response = new MockPipelineResponse(500); + response.AddHeader("ErrorCode", "Error"); + message.Response = response; + Assert.IsTrue(classifier.IsErrorResponse(message)); + } + + [Test] + public void ClassifiesUsingStatusCodesAndEndOfChain() + { + (int Status, bool IsError)[] classifications = new (int Status, bool IsError)[] + { + (204, true), + (404, false), + (500, false), + }; + + ChainingClassifier classifier = new ChainingClassifier( + statusCodes: classifications, + handlers: null, + HelperResponseClassifier.Instance); + + MockPipelineMessage message = new(); + + message.Response = new MockPipelineResponse(204); + Assert.IsTrue(classifier.IsErrorResponse(message)); + + message.Response = new MockPipelineResponse(304); + Assert.IsFalse(classifier.IsErrorResponse(message)); + + message.Response = new MockPipelineResponse(404); + Assert.IsFalse(classifier.IsErrorResponse(message)); + + message.Response = new MockPipelineResponse(500); + Assert.IsFalse(classifier.IsErrorResponse(message)); + } + + [Test] + public void ClassifiesUsingAllHandlersTakePrecedenceOverStatusCodes() + { + (int Status, bool IsError)[] classifications = new (int Status, bool IsError)[] + { + (204, false), + (500, false), + }; + + MessageClassificationHandler[] handlers = new MessageClassificationHandler[] + { + new HeaderClassificationHandler(204, "ErrorCode", "Error"), + new HeaderClassificationHandler(404, "ErrorCode", "NonError"), + }; + + ChainingClassifier classifier = new ChainingClassifier( + statusCodes: classifications, + handlers: handlers, + HelperResponseClassifier.Instance); + + MockPipelineMessage message = new(); + + // Handler takes precedence + var response = new MockPipelineResponse(204); + response.AddHeader("ErrorCode", "Error"); + message.Response = response; + Assert.IsTrue(classifier.IsErrorResponse(message)); + + // End of chain is reached + message.Response = new MockPipelineResponse(304); + Assert.IsFalse(classifier.IsErrorResponse(message)); + + // Handler takes precedence + response = new MockPipelineResponse(404); + response.AddHeader("ErrorCode", "Error"); + message.Response = response; + Assert.IsFalse(classifier.IsErrorResponse(message)); + + // Status code handler is reached + message.Response = new MockPipelineResponse(500); + Assert.IsFalse(classifier.IsErrorResponse(message)); + } + + #region Helpers + private sealed class HelperResponseClassifier : PipelineMessageClassifier + { + public static PipelineMessageClassifier Instance = new HelperResponseClassifier(); + public override bool IsErrorResponse(PipelineMessage message) + { + message.AssertResponse(); + + return message.Response!.Status switch + { + >= 100 and < 400 => false, + _ => true + }; + } + } + + public class HeaderClassificationHandler : MessageClassificationHandler + { + private readonly int _statusCode; + private readonly string _headerName; // e.g. "ErrorCode"; + private readonly string _headerValue; // e.g. "LeaseNotAquired"; + + public HeaderClassificationHandler(int statusCode, string headerName, string headerValue) + { + _statusCode = statusCode; + _headerName = headerName; + _headerValue = headerValue; + } + + public override bool TryClassify(PipelineMessage message, out bool isError) + { + isError = false; + + message.AssertResponse(); + + if (message.Response!.Status != _statusCode) + { + return false; + } + + if (message.Response.Headers.TryGetValue(_headerName, out string? value) && + _headerValue == value) + { + isError = true; + } + + return true; + } + } + #endregion +} From c34b6fd8c38c97a0e093deca1cfbbfa1f2271c24 Mon Sep 17 00:00:00 2001 From: Anne Thompson Date: Fri, 26 Jan 2024 07:50:52 -0800 Subject: [PATCH 2/6] simplify APIs --- .../api/System.ClientModel.net6.0.cs | 19 +- .../api/System.ClientModel.netstandard2.0.cs | 19 +- .../src/Internal/ChainingClassifier.cs | 113 ---------- .../src/Message/PipelineMessage.cs | 8 + .../Options/MessageClassificationHandler.cs | 35 ---- .../src/Options/PipelineMessageClassifier.cs | 98 +++++---- .../src/Options/RequestOptions.cs | 58 +----- .../src/Options/ResponseStatusClassifier.cs | 20 +- .../src/Pipeline/ClientRetryPolicy.cs | 58 +----- .../src/Pipeline/PipelineTransport.cs | 16 +- .../tests/Message/PipelineMessageTests.cs | 4 +- .../Options/PipelineMessageClassifierTests.cs | 106 +++++++++- .../tests/Options/RequestOptionsTests.cs | 16 +- .../Pipeline/ClientPipelineFunctionalTests.cs | 2 +- .../tests/Pipeline/ClientPipelineTests.cs | 2 +- .../Mocks/MockMessageClassifier.cs | 8 +- .../Mocks/MockPipelineMessage.cs | 7 +- .../Mocks/MockPipelineResponse.cs | 5 +- .../Mocks/MockResponseHeaders.cs | 7 +- .../Options/ChainingClassifierTests.cs | 196 ------------------ 20 files changed, 240 insertions(+), 557 deletions(-) delete mode 100644 sdk/core/System.ClientModel/src/Internal/ChainingClassifier.cs delete mode 100644 sdk/core/System.ClientModel/src/Options/MessageClassificationHandler.cs delete mode 100644 sdk/core/System.ClientModel/tests/internal/Options/ChainingClassifierTests.cs diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs index 92f8f4af64723..b9f2779b562f8 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs @@ -121,12 +121,6 @@ public partial interface IPersistableModel string GetFormatFromOptions(System.ClientModel.Primitives.ModelReaderWriterOptions options); System.BinaryData Write(System.ClientModel.Primitives.ModelReaderWriterOptions options); } - public partial class MessageClassificationHandler - { - public MessageClassificationHandler() { } - public virtual bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, out bool isError) { throw null; } - public virtual bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, System.Exception? exception, out bool isRetriable) { throw null; } - } public static partial class ModelReaderWriter { public static object? Read(System.BinaryData data, [System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembersAttribute(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.NonPublicConstructors | System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)] System.Type returnType, System.ClientModel.Primitives.ModelReaderWriterOptions? options = null) { throw null; } @@ -157,6 +151,7 @@ protected internal PipelineMessage(System.ClientModel.Primitives.PipelineRequest public System.TimeSpan? NetworkTimeout { get { throw null; } set { } } public System.ClientModel.Primitives.PipelineRequest Request { get { throw null; } } public System.ClientModel.Primitives.PipelineResponse? Response { get { throw null; } protected internal set { } } + public void Apply(System.ClientModel.Primitives.RequestOptions options) { } public void Dispose() { } protected virtual void Dispose(bool disposing) { } public void SetProperty(System.Type type, object value) { } @@ -164,12 +159,11 @@ public void SetProperty(System.Type type, object value) { } } public partial class PipelineMessageClassifier { - protected internal PipelineMessageClassifier() { } + protected PipelineMessageClassifier() { } + public static System.ClientModel.Primitives.PipelineMessageClassifier Default { get { throw null; } } public static System.ClientModel.Primitives.PipelineMessageClassifier Create(System.ReadOnlySpan successStatusCodes) { throw null; } - public virtual bool IsErrorResponse(System.ClientModel.Primitives.PipelineMessage message) { throw null; } - public virtual bool IsRetriable(System.ClientModel.Primitives.PipelineMessage message, System.Exception exception) { throw null; } - public virtual bool IsRetriableException(System.Exception exception) { throw null; } - public virtual bool IsRetriableResponse(System.ClientModel.Primitives.PipelineMessage message) { throw null; } + public virtual bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, out bool isError) { throw null; } + public virtual bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, System.Exception? exception, out bool isRetriable) { throw null; } } public abstract partial class PipelinePolicy { @@ -250,11 +244,8 @@ public partial class RequestOptions public RequestOptions() { } public System.Threading.CancellationToken CancellationToken { get { throw null; } set { } } public System.ClientModel.Primitives.ClientErrorBehaviors ErrorOptions { get { throw null; } set { } } - public void AddClassifier(System.ClientModel.Primitives.MessageClassificationHandler classifier) { } - public void AddClassifier(int statusCode, bool isError) { } public void AddHeader(string name, string value) { } public void AddPolicy(System.ClientModel.Primitives.PipelinePolicy policy, System.ClientModel.Primitives.PipelinePosition position) { } - public virtual void Apply(System.ClientModel.Primitives.PipelineMessage message) { } protected void AssertNotFrozen() { } public virtual void Freeze() { } public void SetHeader(string name, string value) { } diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs index e736c59f0c442..0c3250d3adc0b 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs @@ -121,12 +121,6 @@ public partial interface IPersistableModel string GetFormatFromOptions(System.ClientModel.Primitives.ModelReaderWriterOptions options); System.BinaryData Write(System.ClientModel.Primitives.ModelReaderWriterOptions options); } - public partial class MessageClassificationHandler - { - public MessageClassificationHandler() { } - public virtual bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, out bool isError) { throw null; } - public virtual bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, System.Exception? exception, out bool isRetriable) { throw null; } - } public static partial class ModelReaderWriter { public static object? Read(System.BinaryData data, System.Type returnType, System.ClientModel.Primitives.ModelReaderWriterOptions? options = null) { throw null; } @@ -156,6 +150,7 @@ protected internal PipelineMessage(System.ClientModel.Primitives.PipelineRequest public System.TimeSpan? NetworkTimeout { get { throw null; } set { } } public System.ClientModel.Primitives.PipelineRequest Request { get { throw null; } } public System.ClientModel.Primitives.PipelineResponse? Response { get { throw null; } protected internal set { } } + public void Apply(System.ClientModel.Primitives.RequestOptions options) { } public void Dispose() { } protected virtual void Dispose(bool disposing) { } public void SetProperty(System.Type type, object value) { } @@ -163,12 +158,11 @@ public void SetProperty(System.Type type, object value) { } } public partial class PipelineMessageClassifier { - protected internal PipelineMessageClassifier() { } + protected PipelineMessageClassifier() { } + public static System.ClientModel.Primitives.PipelineMessageClassifier Default { get { throw null; } } public static System.ClientModel.Primitives.PipelineMessageClassifier Create(System.ReadOnlySpan successStatusCodes) { throw null; } - public virtual bool IsErrorResponse(System.ClientModel.Primitives.PipelineMessage message) { throw null; } - public virtual bool IsRetriable(System.ClientModel.Primitives.PipelineMessage message, System.Exception exception) { throw null; } - public virtual bool IsRetriableException(System.Exception exception) { throw null; } - public virtual bool IsRetriableResponse(System.ClientModel.Primitives.PipelineMessage message) { throw null; } + public virtual bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, out bool isError) { throw null; } + public virtual bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, System.Exception? exception, out bool isRetriable) { throw null; } } public abstract partial class PipelinePolicy { @@ -249,11 +243,8 @@ public partial class RequestOptions public RequestOptions() { } public System.Threading.CancellationToken CancellationToken { get { throw null; } set { } } public System.ClientModel.Primitives.ClientErrorBehaviors ErrorOptions { get { throw null; } set { } } - public void AddClassifier(System.ClientModel.Primitives.MessageClassificationHandler classifier) { } - public void AddClassifier(int statusCode, bool isError) { } public void AddHeader(string name, string value) { } public void AddPolicy(System.ClientModel.Primitives.PipelinePolicy policy, System.ClientModel.Primitives.PipelinePosition position) { } - public virtual void Apply(System.ClientModel.Primitives.PipelineMessage message) { } protected void AssertNotFrozen() { } public virtual void Freeze() { } public void SetHeader(string name, string value) { } diff --git a/sdk/core/System.ClientModel/src/Internal/ChainingClassifier.cs b/sdk/core/System.ClientModel/src/Internal/ChainingClassifier.cs deleted file mode 100644 index a409c0c59a8ec..0000000000000 --- a/sdk/core/System.ClientModel/src/Internal/ChainingClassifier.cs +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.ClientModel.Primitives; - -namespace System.ClientModel.Internal; - -internal class ChainingClassifier : PipelineMessageClassifier -{ - private MessageClassificationHandler[]? _handlers; - private readonly PipelineMessageClassifier _endOfChain; - - public ChainingClassifier((int Status, bool IsError)[]? statusCodes, - MessageClassificationHandler[]? handlers, - PipelineMessageClassifier endOfChain) - { - if (handlers != null) - { - AddClassifiers(handlers); - } - - if (statusCodes != null) - { - StatusCodeHandler[] handler = { new StatusCodeHandler(statusCodes) }; - AddClassifiers(new ReadOnlySpan(handler)); - } - - _endOfChain = endOfChain; - } - - public override bool IsErrorResponse(PipelineMessage message) - { - if (_handlers != null) - { - foreach (var handler in _handlers) - { - if (handler.TryClassify(message, out bool isError)) - { - return isError; - } - } - } - - return _endOfChain.IsErrorResponse(message); - } - - public override bool IsRetriable(PipelineMessage message, Exception exception) - { - if (_handlers != null) - { - foreach (var handler in _handlers) - { - if (handler.TryClassify(message, exception, out bool isRetriable)) - { - return isRetriable; - } - } - } - - return _endOfChain.IsRetriable(message, exception); - } - - public override bool IsRetriableResponse(PipelineMessage message) - { - if (_handlers != null) - { - foreach (var handler in _handlers) - { - if (handler.TryClassify(message, default, out bool isRetriable)) - { - return isRetriable; - } - } - } - - return _endOfChain.IsRetriableResponse(message); - } - - private void AddClassifiers(ReadOnlySpan handlers) - { - int length = _handlers == null ? 0 : _handlers.Length; - Array.Resize(ref _handlers, length + handlers.Length); - Span target = new Span(_handlers, length, handlers.Length); - handlers.CopyTo(target); - } - - private class StatusCodeHandler : MessageClassificationHandler - { - private readonly (int Status, bool IsError)[] _statusCodes; - - public StatusCodeHandler((int Status, bool IsError)[] statusCodes) - { - _statusCodes = statusCodes; - } - - public override bool TryClassify(PipelineMessage message, out bool isError) - { - message.AssertResponse(); - - foreach (var classification in _statusCodes) - { - if (classification.Status == message.Response!.Status) - { - isError = classification.IsError; - return true; - } - } - - isError = false; - return false; - } - } -} diff --git a/sdk/core/System.ClientModel/src/Message/PipelineMessage.cs b/sdk/core/System.ClientModel/src/Message/PipelineMessage.cs index dec5bbbf8c041..736436adbee9a 100644 --- a/sdk/core/System.ClientModel/src/Message/PipelineMessage.cs +++ b/sdk/core/System.ClientModel/src/Message/PipelineMessage.cs @@ -66,6 +66,14 @@ public CancellationToken CancellationToken // at the end of the chain. public PipelineMessageClassifier? MessageClassifier { get; set; } + public void Apply(RequestOptions options) + { + // This design moves the client-author API (options.Apply) off the + // client-user type RequestOptions. Its only purpose is to call through to + // the internal options.Apply method. + options.Apply(this); + } + public bool TryGetProperty(Type type, out object? value) => _propertyBag.TryGetValue((ulong)type.TypeHandle.Value, out value); diff --git a/sdk/core/System.ClientModel/src/Options/MessageClassificationHandler.cs b/sdk/core/System.ClientModel/src/Options/MessageClassificationHandler.cs deleted file mode 100644 index b56f1e70258c8..0000000000000 --- a/sdk/core/System.ClientModel/src/Options/MessageClassificationHandler.cs +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -namespace System.ClientModel.Primitives; - -/// -/// A type that analyzes an HTTP message and determines if the response it holds -/// should be treated as an error response. A classifier of this type may use information -/// from the request, the response, or other message property to decide -/// whether and how to classify the message. -/// -/// This type's TryClassify method allows chaining together handlers before -/// applying default classifier logic. -/// If a handler in the chain returns false from TryClassify, -/// the next handler will be tried, and so on. The first handler that returns true -/// will determine whether the response is an error. -/// -public class MessageClassificationHandler -{ - public virtual bool TryClassify(PipelineMessage message, out bool isError) - { - isError = false; - - // Don't classify for errors unless overridden. - return false; - } - - public virtual bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable) - { - isRetriable = false; - - // Don't classify for retries unless overridden. - return false; - } -} diff --git a/sdk/core/System.ClientModel/src/Options/PipelineMessageClassifier.cs b/sdk/core/System.ClientModel/src/Options/PipelineMessageClassifier.cs index 8303f3af76a8d..07677034cc4d1 100644 --- a/sdk/core/System.ClientModel/src/Options/PipelineMessageClassifier.cs +++ b/sdk/core/System.ClientModel/src/Options/PipelineMessageClassifier.cs @@ -7,63 +7,85 @@ namespace System.ClientModel.Primitives; public class PipelineMessageClassifier { - internal static PipelineMessageClassifier Default { get; } = new PipelineMessageClassifier(); + public static PipelineMessageClassifier Default { get; } = new EndOfChainClassifier(); public static PipelineMessageClassifier Create(ReadOnlySpan successStatusCodes) => new ResponseStatusClassifier(successStatusCodes); - protected internal PipelineMessageClassifier() { } + protected PipelineMessageClassifier() { } - /// - /// Specifies if the response contained in the is not successful. - /// - public virtual bool IsErrorResponse(PipelineMessage message) + public virtual bool TryClassify(PipelineMessage message, out bool isError) { - message.AssertResponse(); + isError = false; + return false; + } - int statusKind = message.Response!.Status / 100; - return statusKind == 4 || statusKind == 5; + public virtual bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable) + { + isRetriable = false; + return false; } - public virtual bool IsRetriableResponse(PipelineMessage message) + internal class EndOfChainClassifier : PipelineMessageClassifier { - message.AssertResponse(); + public override bool TryClassify(PipelineMessage message, out bool isError) + { + message.AssertResponse(); + + int statusKind = message.Response!.Status / 100; + isError = statusKind == 4 || statusKind == 5; - return message.Response!.Status switch + // Always classify the message + return true; + } + + public override bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable) { - // Request Timeout - 408 => true, + isRetriable = exception is null ? + IsRetriable(message) : + IsRetriable(message, exception); - // Too Many Requests - 429 => true, + // Always classify the message + return true; + } - // Internal Server Error - 500 => true, + private static bool IsRetriable(PipelineMessage message) + { + message.AssertResponse(); - // Bad Gateway - 502 => true, + return message.Response!.Status switch + { + // Request Timeout + 408 => true, - // Service Unavailable - 503 => true, + // Too Many Requests + 429 => true, - // Gateway Timeout - 504 => true, + // Internal Server Error + 500 => true, - // Default case - _ => false - }; - } + // Bad Gateway + 502 => true, - public virtual bool IsRetriableException(Exception exception) - { - return (exception is IOException) || - (exception is ClientResultException requestFailed && requestFailed.Status == 0); - } + // Service Unavailable + 503 => true, - public virtual bool IsRetriable(PipelineMessage message, Exception exception) - { - return IsRetriableException(exception) || - // Retry non-user initiated cancellations - (exception is OperationCanceledException && !message.CancellationToken.IsCancellationRequested); + // Gateway Timeout + 504 => true, + + // Default case + _ => false + }; + } + + private static bool IsRetriable(PipelineMessage message, Exception exception) + => IsRetriable(exception) || + // Retry non-user initiated cancellations + (exception is OperationCanceledException && + !message.CancellationToken.IsCancellationRequested); + + private static bool IsRetriable(Exception exception) + => (exception is IOException) || + (exception is ClientResultException ex && ex.Status == 0); } } diff --git a/sdk/core/System.ClientModel/src/Options/RequestOptions.cs b/sdk/core/System.ClientModel/src/Options/RequestOptions.cs index 364ca3c8f22a3..01ded73811162 100644 --- a/sdk/core/System.ClientModel/src/Options/RequestOptions.cs +++ b/sdk/core/System.ClientModel/src/Options/RequestOptions.cs @@ -23,12 +23,6 @@ public class RequestOptions private PipelinePolicy[]? _perTryPolicies; private PipelinePolicy[]? _beforeTransportPolicies; - private (int Status, bool IsError)[]? _statusCodes; - internal (int Status, bool IsError)[]? StatusCodes => _statusCodes; - - private MessageClassificationHandler[]? _handlers; - internal MessageClassificationHandler[]? Handlers => _handlers; - private List? _headersUpdates; public RequestOptions() @@ -101,30 +95,8 @@ public void AddPolicy(PipelinePolicy policy, PipelinePosition position) } } - public void AddClassifier(int statusCode, bool isError) - { - Argument.AssertInRange(statusCode, 100, 599, nameof(statusCode)); - - AssertNotFrozen(); - - int length = _statusCodes == null ? 0 : _statusCodes.Length; - Array.Resize(ref _statusCodes, length + 1); - Array.Copy(_statusCodes, 0, _statusCodes, 1, length); - _statusCodes[0] = (statusCode, isError); - } - - public void AddClassifier(MessageClassificationHandler classifier) - { - AssertNotFrozen(); - - int length = _handlers == null ? 0 : _handlers.Length; - Array.Resize(ref _handlers, length + 1); - Array.Copy(_handlers, 0, _handlers, 1, length); - _handlers[0] = classifier; - } - // Set options on the message before sending it through the pipeline. - public virtual void Apply(PipelineMessage message) + internal void Apply(PipelineMessage message) { Freeze(); @@ -138,7 +110,7 @@ public virtual void Apply(PipelineMessage message) // This preserves any values set by the client author, and is also // needed for Azure.Core-based clients so we don't overwrite a default // Azure.Core ResponseClassifier. - message.MessageClassifier = ApplyClassifier(message.MessageClassifier ?? PipelineMessageClassifier.Default); + message.MessageClassifier ??= PipelineMessageClassifier.Default; // Copy custom pipeline policies to the message. message.PerCallPolicies = _perCallPolicies; @@ -166,32 +138,6 @@ public virtual void Apply(PipelineMessage message) } } - internal PipelineMessageClassifier ApplyClassifier(PipelineMessageClassifier classifier) - { - if (_statusCodes == null && _handlers == null) - { - return classifier; - } - - if (classifier is ResponseStatusClassifier statusCodeClassifier) - { - ResponseStatusClassifier clone = statusCodeClassifier.Clone(); - clone.Handlers = _handlers; - - if (_statusCodes != null) - { - foreach (var classification in _statusCodes) - { - clone.AddClassifier(classification.Status, classification.IsError); - } - } - - return clone; - } - - return new ChainingClassifier(_statusCodes, _handlers, classifier); - } - public virtual void Freeze() => _frozen = true; protected void AssertNotFrozen() diff --git a/sdk/core/System.ClientModel/src/Options/ResponseStatusClassifier.cs b/sdk/core/System.ClientModel/src/Options/ResponseStatusClassifier.cs index 8919b0bb9d088..7caf652b36374 100644 --- a/sdk/core/System.ClientModel/src/Options/ResponseStatusClassifier.cs +++ b/sdk/core/System.ClientModel/src/Options/ResponseStatusClassifier.cs @@ -9,8 +9,6 @@ internal class ResponseStatusClassifier : PipelineMessageClassifier { private BitVector640 _successCodes; - internal MessageClassificationHandler[]? Handlers { get; set; } - /// /// Creates a new instance of . /// @@ -26,23 +24,17 @@ public ResponseStatusClassifier(ReadOnlySpan successStatusCodes) } } - private ResponseStatusClassifier(BitVector640 successCodes, MessageClassificationHandler[]? handlers) - { - _successCodes = successCodes; - Handlers = handlers; - } - - public sealed override bool IsErrorResponse(PipelineMessage message) + public override bool TryClassify(PipelineMessage message, out bool isError) { message.AssertResponse(); - return !_successCodes[message.Response!.Status]; - } + isError = !_successCodes[message.Response!.Status]; - internal virtual ResponseStatusClassifier Clone() - => new(_successCodes, Handlers); + // BitVector-based classifiers should always end any composition chain. + return true; + } - internal void AddClassifier(int statusCode, bool isError) + private void AddClassifier(int statusCode, bool isError) { Argument.AssertInRange(statusCode, 0, 639, nameof(statusCode)); diff --git a/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs b/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs index a6690c30b3bce..ca2967e21e3c6 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs @@ -3,6 +3,7 @@ using System.ClientModel.Internal; using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Runtime.ExceptionServices; using System.Threading; @@ -165,9 +166,16 @@ protected virtual bool ShouldRetryCore(PipelineMessage message, Exception? excep return false; } - return exception is null ? - IsRetriable(message) : - IsRetriable(message, exception); + Debug.Assert(message.MessageClassifier is not null); + + if (message.MessageClassifier!.TryClassify(message, exception, out bool isRetriable)) + { + bool classified = PipelineMessageClassifier.Default.TryClassify(message, exception, out isRetriable); + + Debug.Assert(classified); + } + + return isRetriable; } protected virtual ValueTask ShouldRetryCoreAsync(PipelineMessage message, Exception? exception) @@ -200,48 +208,4 @@ protected virtual void WaitCore(TimeSpan time, CancellationToken cancellationTok CancellationHelper.ThrowIfCancellationRequested(cancellationToken); } } - - #region Retry Classifier - - // Overriding response-retriable classification will be added in a later ClientModel release. - private static bool IsRetriable(PipelineMessage message) - { - message.AssertResponse(); - - return message.Response!.Status switch - { - // Request Timeout - 408 => true, - - // Too Many Requests - 429 => true, - - // Internal Server Error - 500 => true, - - // Bad Gateway - 502 => true, - - // Service Unavailable - 503 => true, - - // Gateway Timeout - 504 => true, - - // Default case - _ => false - }; - } - - private static bool IsRetriable(PipelineMessage message, Exception exception) - => IsRetriable(exception) || - // Retry non-user initiated cancellations - (exception is OperationCanceledException && - !message.CancellationToken.IsCancellationRequested); - - private static bool IsRetriable(Exception exception) - => (exception is IOException) || - (exception is ClientResultException ex && ex.Status == 0); - - #endregion } diff --git a/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs b/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs index 43e974cb885e8..0ee923e05db5a 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs @@ -41,9 +41,19 @@ public async ValueTask ProcessAsync(PipelineMessage message) message.Response.SetIsError(ClassifyResponse(message)); } - private static bool ClassifyResponse(PipelineMessage message) => - message.MessageClassifier?.IsErrorResponse(message) ?? - PipelineMessageClassifier.Default.IsErrorResponse(message); + private static bool ClassifyResponse(PipelineMessage message) + { + var classifier = message.MessageClassifier ?? PipelineMessageClassifier.Default; + + if (!classifier.TryClassify(message, out bool isError)) + { + bool classified = PipelineMessageClassifier.Default.TryClassify(message, out isError); + + Debug.Assert(classified); + } + + return isError; + } protected abstract void ProcessCore(PipelineMessage message); diff --git a/sdk/core/System.ClientModel/tests/Message/PipelineMessageTests.cs b/sdk/core/System.ClientModel/tests/Message/PipelineMessageTests.cs index 734177c3ba2c2..28e632f32a321 100644 --- a/sdk/core/System.ClientModel/tests/Message/PipelineMessageTests.cs +++ b/sdk/core/System.ClientModel/tests/Message/PipelineMessageTests.cs @@ -17,7 +17,7 @@ public void ApplyAddsRequestHeaders() RequestOptions options = new RequestOptions(); options.AddHeader("MockHeader", "MockValue"); - options.Apply(message); + message.Apply(options); Assert.IsTrue(message.Request.Headers.TryGetValue("MockHeader", out string? value)); Assert.AreEqual("MockValue", value); @@ -34,7 +34,7 @@ public void ApplySetsCancellationToken() RequestOptions options = new RequestOptions(); options.CancellationToken = cts.Token; - options.Apply(message); + message.Apply(options); Assert.AreEqual(message.CancellationToken, cts.Token); Assert.IsFalse(message.CancellationToken.IsCancellationRequested); diff --git a/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs b/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs index a9b2d6e217e91..47c52de8b7bd6 100644 --- a/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs +++ b/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs @@ -1,9 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System.ClientModel.Primitives; +using System.Collections.Generic; using ClientModel.Tests.Mocks; using NUnit.Framework; -using System.ClientModel.Primitives; namespace System.ClientModel.Tests.Options; @@ -23,7 +24,8 @@ public void ClassifiesSingleCodeAsNonError() MockPipelineMessage message = new MockPipelineMessage(); message.SetResponse(new MockPipelineResponse(code)); - bool isNonError = !classifier.IsErrorResponse(message); + classifier.TryClassify(message, out bool isError); + bool isNonError = !isError; if (nonError == code) { @@ -50,6 +52,104 @@ public void ClassifiesMultipleCodesAsNonErrors(int code, bool isError) MockPipelineMessage message = new MockPipelineMessage(); message.SetResponse(new MockPipelineResponse(code)); - Assert.AreEqual(isError, classifier.IsErrorResponse(message)); + Assert.IsTrue(classifier.TryClassify(message, out bool error)); + Assert.AreEqual(isError, error); + } + + [Test] + public void CanComposeErrorClassifiers() + { + var last = PipelineMessageClassifier.Create(stackalloc ushort[] { 200, 201, 204 }); + + ChainingClassifier classifier = new ChainingClassifier(last); + classifier.AddClassifier(new SingleStatusCodeClassifier(403, isError: false)); + classifier.AddClassifier(new SingleStatusCodeClassifier(404, isError: false)); + classifier.AddClassifier(new SingleStatusCodeClassifier(201, isError: true)); + + MockPipelineMessage message = new(); + + message.SetResponse(new MockPipelineResponse(200)); + Assert.IsTrue(classifier.TryClassify(message, out bool isError)); + Assert.IsFalse(isError); + + message.SetResponse(new MockPipelineResponse(201)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(isError); + + message.SetResponse(new MockPipelineResponse(204)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsFalse(isError); + + message.SetResponse(new MockPipelineResponse(304)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(isError); + + message.SetResponse(new MockPipelineResponse(403)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsFalse(isError); + + message.SetResponse(new MockPipelineResponse(404)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsFalse(isError); + + message.SetResponse(new MockPipelineResponse(500)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(isError); + } + + #region Helpers + + internal class SingleStatusCodeClassifier : PipelineMessageClassifier + { + private readonly (int, bool) _code; + + public SingleStatusCodeClassifier(int code, bool isError) + { + _code = (code, isError); + } + + public override bool TryClassify(PipelineMessage message, out bool isError) + { + if (message.Response!.Status == _code.Item1) + { + isError = _code.Item2; + return true; + } + + isError = false; + return false; + } } + + internal class ChainingClassifier : PipelineMessageClassifier + { + private readonly List _classifiers; + private readonly PipelineMessageClassifier _endOfChain; + + public ChainingClassifier(PipelineMessageClassifier endOfChain) + { + _classifiers = new(); + _endOfChain = endOfChain; + } + + public void AddClassifier(PipelineMessageClassifier classifier) + { + _classifiers.Add(classifier); + } + + public override bool TryClassify(PipelineMessage message, out bool isError) + { + foreach (var classifier in _classifiers) + { + if (classifier.TryClassify(message, out isError)) + { + return true; + } + } + + return _endOfChain.TryClassify(message, out isError); + } + } + + #endregion } diff --git a/sdk/core/System.ClientModel/tests/Options/RequestOptionsTests.cs b/sdk/core/System.ClientModel/tests/Options/RequestOptionsTests.cs index a96a693e42620..b622b2b6f3ce6 100644 --- a/sdk/core/System.ClientModel/tests/Options/RequestOptionsTests.cs +++ b/sdk/core/System.ClientModel/tests/Options/RequestOptionsTests.cs @@ -18,7 +18,7 @@ public void CanAddRequestHeaders() RequestOptions options = new RequestOptions(); options.AddHeader("MockHeader", "MockValue"); - options.Apply(message); + message.Apply(options); Assert.IsTrue(message.Request.Headers.TryGetValue("MockHeader", out string? value)); Assert.AreEqual("MockValue", value); @@ -33,7 +33,7 @@ public void CanAddMultiValueRequestHeaders() RequestOptions options = new RequestOptions(); options.AddHeader("MockHeader", "MockValue1"); options.AddHeader("MockHeader", "MockValue2"); - options.Apply(message); + message.Apply(options); Assert.IsTrue(message.Request.Headers.TryGetValue("MockHeader", out string? value)); Assert.AreEqual("MockValue1,MockValue2", value); @@ -47,7 +47,7 @@ public void CanSetRequestHeaders() RequestOptions options = new RequestOptions(); options.SetHeader("MockHeader", "MockValue"); - options.Apply(message); + message.Apply(options); Assert.IsTrue(message.Request.Headers.TryGetValue("MockHeader", out string? value)); Assert.AreEqual("MockValue", value); @@ -62,7 +62,7 @@ public void SetReplacesHeaderAddedViaRequestOptions() RequestOptions options = new RequestOptions(); options.AddHeader("MockHeader", "MockValue1"); options.SetHeader("MockHeader", "MockValue2"); - options.Apply(message); + message.Apply(options); Assert.IsTrue(message.Request.Headers.TryGetValue("MockHeader", out string? value)); Assert.AreEqual("MockValue2", value); @@ -78,7 +78,7 @@ public void AddHeaderAddsValueToRequestMessageValue() RequestOptions options = new RequestOptions(); options.AddHeader("MockHeader", "RequestOptions Value"); - options.Apply(message); + message.Apply(options); Assert.IsTrue(message.Request.Headers.TryGetValue("MockHeader", out string? value)); Assert.AreEqual("Message.Request Value,RequestOptions Value", value); @@ -94,7 +94,7 @@ public void SetReplacesHeaderAddedViaMessageRequest() RequestOptions options = new RequestOptions(); options.SetHeader("MockHeader", "RequestOptions Value"); - options.Apply(message); + message.Apply(options); Assert.IsTrue(message.Request.Headers.TryGetValue("MockHeader", out string? value)); Assert.AreEqual("RequestOptions Value", value); @@ -120,7 +120,7 @@ public void CanInterleaveAddAndSetCalls() options.SetHeader("MockHeader2", "RequestOptions SetHeader Value 1"); options.AddHeader("MockHeader2", "RequestOptions AddHeader Value 2"); - options.Apply(message); + message.Apply(options); Assert.IsTrue(message.Request.Headers.TryGetValue("MockHeader1", out string? value1)); Assert.AreEqual("RequestOptions SetHeader Value 2", value1); @@ -139,7 +139,7 @@ public void CannotModifyOptionsAfterFrozen() PipelineMessage message = pipeline.CreateMessage(); RequestOptions options = new RequestOptions(); - options.Apply(message); + message.Apply(options); Assert.Throws(() => options.CancellationToken = CancellationToken.None); Assert.Throws(() => options.ErrorOptions = ClientErrorBehaviors.NoThrow); diff --git a/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineFunctionalTests.cs b/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineFunctionalTests.cs index 66642472a073c..86e3bc1d831fe 100644 --- a/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineFunctionalTests.cs +++ b/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineFunctionalTests.cs @@ -372,7 +372,7 @@ public async Task DoesntRetryClientCancellation() // Set CancellationToken on the message. RequestOptions options = new() { CancellationToken = cts.Token }; - options.Apply(message); + message.Apply(options); var task = Task.Run(() => pipeline.SendSyncOrAsync(message, IsAsync)); diff --git a/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineTests.cs b/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineTests.cs index 4c985c0aaf118..7ea521d2d3fb0 100644 --- a/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineTests.cs +++ b/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineTests.cs @@ -305,7 +305,7 @@ public async Task RequestOptionsCanCustomizePipeline() requestOptions.AddPolicy(new ObservablePolicy("B"), PipelinePosition.PerTry); PipelineMessage message = pipeline.CreateMessage(); - requestOptions.Apply(message); + message.Apply(requestOptions); await pipeline.SendSyncOrAsync(message, IsAsync); List observations = ObservablePolicy.GetData(message); diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockMessageClassifier.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockMessageClassifier.cs index de2ac4cabcf99..540a6f0049d18 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockMessageClassifier.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockMessageClassifier.cs @@ -29,7 +29,7 @@ public MockMessageClassifier(string id, int[]? successCodes) public string Id { get; set; } - public override bool IsErrorResponse(PipelineMessage message) + public override bool TryClassify(PipelineMessage message, out bool isError) { if (_successCodes is not null) { @@ -37,13 +37,15 @@ public override bool IsErrorResponse(PipelineMessage message) { if (message.Response!.Status == code) { + isError = true; return true; } } - return false; + isError = false; + return true; } - return base.IsErrorResponse(message); + return base.TryClassify(message, out isError); } } diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineMessage.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineMessage.cs index be1e8f288f9b9..7d8430a4e016c 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineMessage.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineMessage.cs @@ -1,7 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; namespace ClientModel.Tests.Mocks; @@ -16,5 +21,5 @@ public MockPipelineMessage(PipelineRequest request) : base(request) } public void SetResponse(PipelineResponse response) - => Response = response; + => this.Response = response; } diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs index 6d8926db844ce..1bc236899fcdd 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs @@ -14,7 +14,7 @@ public class MockPipelineResponse : PipelineResponse private string _reasonPhrase; private Stream? _contentStream; - private readonly MockResponseHeaders _headers; + private readonly PipelineResponseHeaders _headers; private bool _disposed; @@ -52,9 +52,6 @@ public override Stream? ContentStream protected override PipelineResponseHeaders GetHeadersCore() => _headers; - public void AddHeader(string key, string value) - => _headers.AddHeader(key, value); - public sealed override void Dispose() { Dispose(true); diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockResponseHeaders.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockResponseHeaders.cs index 41189df704a6c..614997980beb4 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockResponseHeaders.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockResponseHeaders.cs @@ -16,16 +16,15 @@ public MockResponseHeaders() _headers = new Dictionary(); } - public void AddHeader(string key, string value) - => _headers.Add(key, value); - public override IEnumerator> GetEnumerator() { throw new NotImplementedException(); } public override bool TryGetValue(string name, out string? value) - => _headers.TryGetValue(name, out value); + { + return _headers.TryGetValue(name, out value); + } public override bool TryGetValues(string name, out IEnumerable? values) { diff --git a/sdk/core/System.ClientModel/tests/internal/Options/ChainingClassifierTests.cs b/sdk/core/System.ClientModel/tests/internal/Options/ChainingClassifierTests.cs deleted file mode 100644 index 08807dd182cc7..0000000000000 --- a/sdk/core/System.ClientModel/tests/internal/Options/ChainingClassifierTests.cs +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.ClientModel.Internal; -using System.ClientModel.Primitives; -using ClientModel.Tests.Mocks; -using NUnit.Framework; - -namespace System.ClientModel.Tests.Internal.Options; - -public class ChainingClassifierTests -{ - [Test] - public void ClassifiesUsingOnlyEndOfChain() - { - ChainingClassifier classifier = new ChainingClassifier( - statusCodes: null, - handlers: null, - HelperResponseClassifier.Instance); - - MockPipelineMessage message = new(); - - message.SetResponse(new MockPipelineResponse(204)); - Assert.IsFalse(classifier.IsErrorResponse(message)); - - message.SetResponse(new MockPipelineResponse(304)); - Assert.IsFalse(classifier.IsErrorResponse(message)); - - message.SetResponse(new MockPipelineResponse(404)); - Assert.IsTrue(classifier.IsErrorResponse(message)); - - message.SetResponse(new MockPipelineResponse(500)); - Assert.IsTrue(classifier.IsErrorResponse(message)); - } - - [Test] - public void ClassifiesUsingHandlersAndEndOfChain() - { - MessageClassificationHandler[] handlers = new MessageClassificationHandler[] - { - new HeaderClassificationHandler(204, "ErrorCode", "Error"), - new HeaderClassificationHandler(404, "ErrorCode", "NonError"), - }; - - ChainingClassifier classifier = new ChainingClassifier( - statusCodes: null, - handlers: handlers, - HelperResponseClassifier.Instance); - - MockPipelineMessage message = new(); - - var response = new MockPipelineResponse(204); - response.AddHeader("ErrorCode", "Error"); - message.Response = response; - Assert.IsTrue(classifier.IsErrorResponse(message)); - - response = new MockPipelineResponse(304); - response.AddHeader("ErrorCode", "Error"); - message.Response = response; - Assert.IsFalse(classifier.IsErrorResponse(message)); - - response = new MockPipelineResponse(404); - response.AddHeader("ErrorCode", "Error"); - message.Response = response; - Assert.IsFalse(classifier.IsErrorResponse(message)); - - response = new MockPipelineResponse(500); - response.AddHeader("ErrorCode", "Error"); - message.Response = response; - Assert.IsTrue(classifier.IsErrorResponse(message)); - } - - [Test] - public void ClassifiesUsingStatusCodesAndEndOfChain() - { - (int Status, bool IsError)[] classifications = new (int Status, bool IsError)[] - { - (204, true), - (404, false), - (500, false), - }; - - ChainingClassifier classifier = new ChainingClassifier( - statusCodes: classifications, - handlers: null, - HelperResponseClassifier.Instance); - - MockPipelineMessage message = new(); - - message.Response = new MockPipelineResponse(204); - Assert.IsTrue(classifier.IsErrorResponse(message)); - - message.Response = new MockPipelineResponse(304); - Assert.IsFalse(classifier.IsErrorResponse(message)); - - message.Response = new MockPipelineResponse(404); - Assert.IsFalse(classifier.IsErrorResponse(message)); - - message.Response = new MockPipelineResponse(500); - Assert.IsFalse(classifier.IsErrorResponse(message)); - } - - [Test] - public void ClassifiesUsingAllHandlersTakePrecedenceOverStatusCodes() - { - (int Status, bool IsError)[] classifications = new (int Status, bool IsError)[] - { - (204, false), - (500, false), - }; - - MessageClassificationHandler[] handlers = new MessageClassificationHandler[] - { - new HeaderClassificationHandler(204, "ErrorCode", "Error"), - new HeaderClassificationHandler(404, "ErrorCode", "NonError"), - }; - - ChainingClassifier classifier = new ChainingClassifier( - statusCodes: classifications, - handlers: handlers, - HelperResponseClassifier.Instance); - - MockPipelineMessage message = new(); - - // Handler takes precedence - var response = new MockPipelineResponse(204); - response.AddHeader("ErrorCode", "Error"); - message.Response = response; - Assert.IsTrue(classifier.IsErrorResponse(message)); - - // End of chain is reached - message.Response = new MockPipelineResponse(304); - Assert.IsFalse(classifier.IsErrorResponse(message)); - - // Handler takes precedence - response = new MockPipelineResponse(404); - response.AddHeader("ErrorCode", "Error"); - message.Response = response; - Assert.IsFalse(classifier.IsErrorResponse(message)); - - // Status code handler is reached - message.Response = new MockPipelineResponse(500); - Assert.IsFalse(classifier.IsErrorResponse(message)); - } - - #region Helpers - private sealed class HelperResponseClassifier : PipelineMessageClassifier - { - public static PipelineMessageClassifier Instance = new HelperResponseClassifier(); - public override bool IsErrorResponse(PipelineMessage message) - { - message.AssertResponse(); - - return message.Response!.Status switch - { - >= 100 and < 400 => false, - _ => true - }; - } - } - - public class HeaderClassificationHandler : MessageClassificationHandler - { - private readonly int _statusCode; - private readonly string _headerName; // e.g. "ErrorCode"; - private readonly string _headerValue; // e.g. "LeaseNotAquired"; - - public HeaderClassificationHandler(int statusCode, string headerName, string headerValue) - { - _statusCode = statusCode; - _headerName = headerName; - _headerValue = headerValue; - } - - public override bool TryClassify(PipelineMessage message, out bool isError) - { - isError = false; - - message.AssertResponse(); - - if (message.Response!.Status != _statusCode) - { - return false; - } - - if (message.Response.Headers.TryGetValue(_headerName, out string? value) && - _headerValue == value) - { - isError = true; - } - - return true; - } - } - #endregion -} From 1314c215c1b2c0fbf19e90539bac9340adb8c828 Mon Sep 17 00:00:00 2001 From: Anne Thompson Date: Fri, 26 Jan 2024 08:22:40 -0800 Subject: [PATCH 3/6] fix --- .../api/System.ClientModel.net6.0.cs | 2 +- .../api/System.ClientModel.netstandard2.0.cs | 2 +- .../src/Message/PipelineMessage.cs | 3 +- .../src/Options/RequestOptions.cs | 6 - .../src/Pipeline/ClientRetryPolicy.cs | 5 +- .../src/Pipeline/PipelineTransport.cs | 4 +- .../Options/PipelineMessageClassifierTests.cs | 148 +++++++++++++++++- 7 files changed, 149 insertions(+), 21 deletions(-) diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs index b9f2779b562f8..f4627ed0d5f28 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs @@ -147,7 +147,7 @@ public partial class PipelineMessage : System.IDisposable protected internal PipelineMessage(System.ClientModel.Primitives.PipelineRequest request) { } public bool BufferResponse { get { throw null; } set { } } public System.Threading.CancellationToken CancellationToken { get { throw null; } protected internal set { } } - public System.ClientModel.Primitives.PipelineMessageClassifier? MessageClassifier { get { throw null; } set { } } + public System.ClientModel.Primitives.PipelineMessageClassifier MessageClassifier { get { throw null; } set { } } public System.TimeSpan? NetworkTimeout { get { throw null; } set { } } public System.ClientModel.Primitives.PipelineRequest Request { get { throw null; } } public System.ClientModel.Primitives.PipelineResponse? Response { get { throw null; } protected internal set { } } diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs index 0c3250d3adc0b..ef49adc323eb1 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs @@ -146,7 +146,7 @@ public partial class PipelineMessage : System.IDisposable protected internal PipelineMessage(System.ClientModel.Primitives.PipelineRequest request) { } public bool BufferResponse { get { throw null; } set { } } public System.Threading.CancellationToken CancellationToken { get { throw null; } protected internal set { } } - public System.ClientModel.Primitives.PipelineMessageClassifier? MessageClassifier { get { throw null; } set { } } + public System.ClientModel.Primitives.PipelineMessageClassifier MessageClassifier { get { throw null; } set { } } public System.TimeSpan? NetworkTimeout { get { throw null; } set { } } public System.ClientModel.Primitives.PipelineRequest Request { get { throw null; } } public System.ClientModel.Primitives.PipelineResponse? Response { get { throw null; } protected internal set { } } diff --git a/sdk/core/System.ClientModel/src/Message/PipelineMessage.cs b/sdk/core/System.ClientModel/src/Message/PipelineMessage.cs index 736436adbee9a..d2be344d7ca52 100644 --- a/sdk/core/System.ClientModel/src/Message/PipelineMessage.cs +++ b/sdk/core/System.ClientModel/src/Message/PipelineMessage.cs @@ -19,6 +19,7 @@ protected internal PipelineMessage(PipelineRequest request) _propertyBag = new ArrayBackedPropertyBag(); BufferResponse = true; + MessageClassifier = PipelineMessageClassifier.Default; } public PipelineRequest Request { get; } @@ -64,7 +65,7 @@ public CancellationToken CancellationToken // the client-provided classifier or compose a chain of classification // handlers that preserve the functionality of the client-provided classifier // at the end of the chain. - public PipelineMessageClassifier? MessageClassifier { get; set; } + public PipelineMessageClassifier MessageClassifier { get; set; } public void Apply(RequestOptions options) { diff --git a/sdk/core/System.ClientModel/src/Options/RequestOptions.cs b/sdk/core/System.ClientModel/src/Options/RequestOptions.cs index 01ded73811162..71e7f8f56c2e1 100644 --- a/sdk/core/System.ClientModel/src/Options/RequestOptions.cs +++ b/sdk/core/System.ClientModel/src/Options/RequestOptions.cs @@ -106,12 +106,6 @@ internal void Apply(PipelineMessage message) // cancellation token will be set again in HttpPipeline.Send. message.CancellationToken = CancellationToken; - // We don't overwrite the classifier on the message if it's already set. - // This preserves any values set by the client author, and is also - // needed for Azure.Core-based clients so we don't overwrite a default - // Azure.Core ResponseClassifier. - message.MessageClassifier ??= PipelineMessageClassifier.Default; - // Copy custom pipeline policies to the message. message.PerCallPolicies = _perCallPolicies; message.PerTryPolicies = _perTryPolicies; diff --git a/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs b/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs index ca2967e21e3c6..7f5ed0368ccac 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs @@ -4,7 +4,6 @@ using System.ClientModel.Internal; using System.Collections.Generic; using System.Diagnostics; -using System.IO; using System.Runtime.ExceptionServices; using System.Threading; using System.Threading.Tasks; @@ -166,9 +165,7 @@ protected virtual bool ShouldRetryCore(PipelineMessage message, Exception? excep return false; } - Debug.Assert(message.MessageClassifier is not null); - - if (message.MessageClassifier!.TryClassify(message, exception, out bool isRetriable)) + if (message.MessageClassifier.TryClassify(message, exception, out bool isRetriable)) { bool classified = PipelineMessageClassifier.Default.TryClassify(message, exception, out isRetriable); diff --git a/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs b/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs index 0ee923e05db5a..a590c400d2893 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs @@ -43,9 +43,7 @@ public async ValueTask ProcessAsync(PipelineMessage message) private static bool ClassifyResponse(PipelineMessage message) { - var classifier = message.MessageClassifier ?? PipelineMessageClassifier.Default; - - if (!classifier.TryClassify(message, out bool isError)) + if (!message.MessageClassifier.TryClassify(message, out bool isError)) { bool classified = PipelineMessageClassifier.Default.TryClassify(message, out isError); diff --git a/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs b/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs index 47c52de8b7bd6..edc99665e8cea 100644 --- a/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs +++ b/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs @@ -3,6 +3,7 @@ using System.ClientModel.Primitives; using System.Collections.Generic; +using System.Diagnostics; using ClientModel.Tests.Mocks; using NUnit.Framework; @@ -62,9 +63,9 @@ public void CanComposeErrorClassifiers() var last = PipelineMessageClassifier.Create(stackalloc ushort[] { 200, 201, 204 }); ChainingClassifier classifier = new ChainingClassifier(last); - classifier.AddClassifier(new SingleStatusCodeClassifier(403, isError: false)); - classifier.AddClassifier(new SingleStatusCodeClassifier(404, isError: false)); - classifier.AddClassifier(new SingleStatusCodeClassifier(201, isError: true)); + classifier.AddClassifier(new ErrorStatusCodeClassifier(403, isError: false)); + classifier.AddClassifier(new ErrorStatusCodeClassifier(404, isError: false)); + classifier.AddClassifier(new ErrorStatusCodeClassifier(201, isError: true)); MockPipelineMessage message = new(); @@ -97,13 +98,109 @@ public void CanComposeErrorClassifiers() Assert.IsTrue(isError); } + [Test] + public void CanComposeRetryClassifiers() + { + var last = PipelineMessageClassifier.Create(stackalloc ushort[] { 200, 201, 204 }); + + ChainingClassifier classifier = new ChainingClassifier(last); + classifier.AddClassifier(new RetriableStatussCodeClassifier(403, isRetriable: false)); + classifier.AddClassifier(new RetriableStatussCodeClassifier(404, isRetriable: false)); + classifier.AddClassifier(new RetriableStatussCodeClassifier(201, isRetriable: true)); + + MockPipelineMessage message = new(); + + message.SetResponse(new MockPipelineResponse(200)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out bool isRetriable)); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(201)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsTrue(isRetriable); + + message.SetResponse(new MockPipelineResponse(204)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(304)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(403)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(404)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(500)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsTrue(isRetriable); + } + + [Test] + public void CanComposeErrorAndRetryClassifiers() + { + var last = PipelineMessageClassifier.Create(stackalloc ushort[] { 200, 201, 204 }); + + ChainingClassifier classifier = new ChainingClassifier(last); + classifier.AddClassifier(new RetriableStatussCodeClassifier(429, isRetriable: false)); + classifier.AddClassifier(new ErrorStatusCodeClassifier(404, isError: false)); + classifier.AddClassifier(new ErrorStatusCodeClassifier(201, isError: true)); + + MockPipelineMessage message = new(); + + message.SetResponse(new MockPipelineResponse(200)); + Assert.IsTrue(classifier.TryClassify(message, out bool isError)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out bool isRetriable)); + Assert.IsFalse(isError); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(201)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsTrue(isError); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(204)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsFalse(isError); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(304)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsTrue(isError); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(404)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsFalse(isError); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(429)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsTrue(isError); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(500)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsTrue(isError); + Assert.IsTrue(isRetriable); + } + #region Helpers - internal class SingleStatusCodeClassifier : PipelineMessageClassifier + internal class ErrorStatusCodeClassifier : PipelineMessageClassifier { private readonly (int, bool) _code; - public SingleStatusCodeClassifier(int code, bool isError) + public ErrorStatusCodeClassifier(int code, bool isError) { _code = (code, isError); } @@ -121,6 +218,28 @@ public override bool TryClassify(PipelineMessage message, out bool isError) } } + internal class RetriableStatussCodeClassifier : PipelineMessageClassifier + { + private readonly (int, bool) _code; + + public RetriableStatussCodeClassifier(int code, bool isRetriable) + { + _code = (code, isRetriable); + } + + public override bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable) + { + if (message.Response!.Status == _code.Item1) + { + isRetriable = _code.Item2; + return true; + } + + isRetriable = false; + return false; + } + } + internal class ChainingClassifier : PipelineMessageClassifier { private readonly List _classifiers; @@ -149,6 +268,25 @@ public override bool TryClassify(PipelineMessage message, out bool isError) return _endOfChain.TryClassify(message, out isError); } + + public override bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable) + { + foreach (var classifier in _classifiers) + { + if (classifier.TryClassify(message, exception, out isRetriable)) + { + return true; + } + } + + if (!_endOfChain.TryClassify(message, exception, out isRetriable)) + { + bool classified = Default.TryClassify(message, exception, out isRetriable); + Debug.Assert(classified); + } + + return true; + } } #endregion From ccd192bc4a63ec6b35db6e7baa54a288faf21606 Mon Sep 17 00:00:00 2001 From: Anne Thompson Date: Fri, 26 Jan 2024 09:03:24 -0800 Subject: [PATCH 4/6] fix --- sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs b/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs index 7f5ed0368ccac..2a8cd155c9bd1 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs @@ -165,7 +165,7 @@ protected virtual bool ShouldRetryCore(PipelineMessage message, Exception? excep return false; } - if (message.MessageClassifier.TryClassify(message, exception, out bool isRetriable)) + if (!message.MessageClassifier.TryClassify(message, exception, out bool isRetriable)) { bool classified = PipelineMessageClassifier.Default.TryClassify(message, exception, out isRetriable); From 105822a54e337fb03511db8ba71be6fd0c6d3e93 Mon Sep 17 00:00:00 2001 From: Anne Thompson Date: Fri, 26 Jan 2024 09:07:50 -0800 Subject: [PATCH 5/6] nits --- .../tests/Options/PipelineMessageClassifierTests.cs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs b/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs index edc99665e8cea..ee0b701088409 100644 --- a/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs +++ b/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs @@ -104,9 +104,9 @@ public void CanComposeRetryClassifiers() var last = PipelineMessageClassifier.Create(stackalloc ushort[] { 200, 201, 204 }); ChainingClassifier classifier = new ChainingClassifier(last); - classifier.AddClassifier(new RetriableStatussCodeClassifier(403, isRetriable: false)); - classifier.AddClassifier(new RetriableStatussCodeClassifier(404, isRetriable: false)); - classifier.AddClassifier(new RetriableStatussCodeClassifier(201, isRetriable: true)); + classifier.AddClassifier(new RetriableStatusCodeClassifier(403, isRetriable: false)); + classifier.AddClassifier(new RetriableStatusCodeClassifier(404, isRetriable: false)); + classifier.AddClassifier(new RetriableStatusCodeClassifier(201, isRetriable: true)); MockPipelineMessage message = new(); @@ -145,7 +145,7 @@ public void CanComposeErrorAndRetryClassifiers() var last = PipelineMessageClassifier.Create(stackalloc ushort[] { 200, 201, 204 }); ChainingClassifier classifier = new ChainingClassifier(last); - classifier.AddClassifier(new RetriableStatussCodeClassifier(429, isRetriable: false)); + classifier.AddClassifier(new RetriableStatusCodeClassifier(429, isRetriable: false)); classifier.AddClassifier(new ErrorStatusCodeClassifier(404, isError: false)); classifier.AddClassifier(new ErrorStatusCodeClassifier(201, isError: true)); @@ -218,11 +218,11 @@ public override bool TryClassify(PipelineMessage message, out bool isError) } } - internal class RetriableStatussCodeClassifier : PipelineMessageClassifier + internal class RetriableStatusCodeClassifier : PipelineMessageClassifier { private readonly (int, bool) _code; - public RetriableStatussCodeClassifier(int code, bool isRetriable) + public RetriableStatusCodeClassifier(int code, bool isRetriable) { _code = (code, isRetriable); } From 24f9fde30b619fafcd31853dd952564ee0ab7a1f Mon Sep 17 00:00:00 2001 From: Anne Thompson Date: Fri, 26 Jan 2024 09:59:47 -0800 Subject: [PATCH 6/6] Make TryClassify methods abstract --- .../api/System.ClientModel.net6.0.cs | 6 +++--- .../api/System.ClientModel.netstandard2.0.cs | 6 +++--- .../src/Options/PipelineMessageClassifier.cs | 14 +++----------- .../src/Options/ResponseStatusClassifier.cs | 13 ++++++++++++- .../Options/PipelineMessageClassifierTests.cs | 12 ++++++++++++ .../TestFramework/Mocks/MockMessageClassifier.cs | 8 +++++++- 6 files changed, 40 insertions(+), 19 deletions(-) diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs index f4627ed0d5f28..19ea53a04a078 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs @@ -157,13 +157,13 @@ protected virtual void Dispose(bool disposing) { } public void SetProperty(System.Type type, object value) { } public bool TryGetProperty(System.Type type, out object? value) { throw null; } } - public partial class PipelineMessageClassifier + public abstract partial class PipelineMessageClassifier { protected PipelineMessageClassifier() { } public static System.ClientModel.Primitives.PipelineMessageClassifier Default { get { throw null; } } public static System.ClientModel.Primitives.PipelineMessageClassifier Create(System.ReadOnlySpan successStatusCodes) { throw null; } - public virtual bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, out bool isError) { throw null; } - public virtual bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, System.Exception? exception, out bool isRetriable) { throw null; } + public abstract bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, out bool isError); + public abstract bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, System.Exception? exception, out bool isRetriable); } public abstract partial class PipelinePolicy { diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs index ef49adc323eb1..1d5b1cb1f78b6 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs @@ -156,13 +156,13 @@ protected virtual void Dispose(bool disposing) { } public void SetProperty(System.Type type, object value) { } public bool TryGetProperty(System.Type type, out object? value) { throw null; } } - public partial class PipelineMessageClassifier + public abstract partial class PipelineMessageClassifier { protected PipelineMessageClassifier() { } public static System.ClientModel.Primitives.PipelineMessageClassifier Default { get { throw null; } } public static System.ClientModel.Primitives.PipelineMessageClassifier Create(System.ReadOnlySpan successStatusCodes) { throw null; } - public virtual bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, out bool isError) { throw null; } - public virtual bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, System.Exception? exception, out bool isRetriable) { throw null; } + public abstract bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, out bool isError); + public abstract bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, System.Exception? exception, out bool isRetriable); } public abstract partial class PipelinePolicy { diff --git a/sdk/core/System.ClientModel/src/Options/PipelineMessageClassifier.cs b/sdk/core/System.ClientModel/src/Options/PipelineMessageClassifier.cs index 07677034cc4d1..4fd20bcb75700 100644 --- a/sdk/core/System.ClientModel/src/Options/PipelineMessageClassifier.cs +++ b/sdk/core/System.ClientModel/src/Options/PipelineMessageClassifier.cs @@ -5,7 +5,7 @@ namespace System.ClientModel.Primitives; -public class PipelineMessageClassifier +public abstract class PipelineMessageClassifier { public static PipelineMessageClassifier Default { get; } = new EndOfChainClassifier(); @@ -14,17 +14,9 @@ public static PipelineMessageClassifier Create(ReadOnlySpan successStatu protected PipelineMessageClassifier() { } - public virtual bool TryClassify(PipelineMessage message, out bool isError) - { - isError = false; - return false; - } + public abstract bool TryClassify(PipelineMessage message, out bool isError); - public virtual bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable) - { - isRetriable = false; - return false; - } + public abstract bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable); internal class EndOfChainClassifier : PipelineMessageClassifier { diff --git a/sdk/core/System.ClientModel/src/Options/ResponseStatusClassifier.cs b/sdk/core/System.ClientModel/src/Options/ResponseStatusClassifier.cs index 7caf652b36374..6ea5a7e3205fe 100644 --- a/sdk/core/System.ClientModel/src/Options/ResponseStatusClassifier.cs +++ b/sdk/core/System.ClientModel/src/Options/ResponseStatusClassifier.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System.ClientModel.Internal; +using System.Diagnostics; namespace System.ClientModel.Primitives; @@ -30,7 +31,17 @@ public override bool TryClassify(PipelineMessage message, out bool isError) isError = !_successCodes[message.Response!.Status]; - // BitVector-based classifiers should always end any composition chain. + // BitVector-based classifiers should always end any chain. + return true; + } + + public override bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable) + { + bool classified = Default.TryClassify(message, exception, out isRetriable); + + Debug.Assert(classified); + + // BitVector-based classifiers should always end any chain. return true; } diff --git a/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs b/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs index ee0b701088409..0f53404e2ad89 100644 --- a/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs +++ b/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs @@ -216,6 +216,12 @@ public override bool TryClassify(PipelineMessage message, out bool isError) isError = false; return false; } + + public override bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable) + { + isRetriable = false; + return false; + } } internal class RetriableStatusCodeClassifier : PipelineMessageClassifier @@ -227,6 +233,12 @@ public RetriableStatusCodeClassifier(int code, bool isRetriable) _code = (code, isRetriable); } + public override bool TryClassify(PipelineMessage message, out bool isError) + { + isError = false; + return false; + } + public override bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable) { if (message.Response!.Status == _code.Item1) diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockMessageClassifier.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockMessageClassifier.cs index 540a6f0049d18..36f929a6aa2cc 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockMessageClassifier.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockMessageClassifier.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; using System.ClientModel.Primitives; namespace ClientModel.Tests.Mocks; @@ -46,6 +47,11 @@ public override bool TryClassify(PipelineMessage message, out bool isError) return true; } - return base.TryClassify(message, out isError); + return Default.TryClassify(message, out isError); + } + + public override bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable) + { + return Default.TryClassify(message, exception, out isRetriable); } }