From ad9610cc6c27f1c047da1c9cf0d04dd249b64b7f Mon Sep 17 00:00:00 2001 From: Anne Thompson Date: Fri, 26 Jan 2024 13:25:55 -0800 Subject: [PATCH] ClientModel: Add retry classification to MessageClassifier (#41586) * Add IsRetriable to MessageClassifier * simplify APIs * fix * fix * nits * Make TryClassify methods abstract --- .../api/System.ClientModel.net6.0.cs | 10 +- .../api/System.ClientModel.netstandard2.0.cs | 10 +- .../src/Message/PipelineMessage.cs | 3 +- .../src/Options/PipelineMessageClassifier.cs | 78 +++++- .../src/Options/RequestOptions.cs | 6 - .../src/Options/ResponseStatusClassifier.cs | 67 +++-- .../src/Pipeline/ClientRetryPolicy.cs | 57 +--- .../src/Pipeline/PipelineTransport.cs | 14 +- .../Options/PipelineMessageClassifierTests.cs | 256 +++++++++++++++++- .../Mocks/MockMessageClassifier.cs | 14 +- 10 files changed, 406 insertions(+), 109 deletions(-) diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs index e21f7d222a0fc..19ea53a04a078 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs @@ -147,7 +147,7 @@ public partial class PipelineMessage : System.IDisposable protected internal PipelineMessage(System.ClientModel.Primitives.PipelineRequest request) { } public bool BufferResponse { get { throw null; } set { } } public System.Threading.CancellationToken CancellationToken { get { throw null; } protected internal set { } } - public System.ClientModel.Primitives.PipelineMessageClassifier? MessageClassifier { get { throw null; } set { } } + public System.ClientModel.Primitives.PipelineMessageClassifier MessageClassifier { get { throw null; } set { } } public System.TimeSpan? NetworkTimeout { get { throw null; } set { } } public System.ClientModel.Primitives.PipelineRequest Request { get { throw null; } } public System.ClientModel.Primitives.PipelineResponse? Response { get { throw null; } protected internal set { } } @@ -157,11 +157,13 @@ protected virtual void Dispose(bool disposing) { } public void SetProperty(System.Type type, object value) { } public bool TryGetProperty(System.Type type, out object? value) { throw null; } } - public partial class PipelineMessageClassifier + public abstract partial class PipelineMessageClassifier { - protected internal PipelineMessageClassifier() { } + protected PipelineMessageClassifier() { } + public static System.ClientModel.Primitives.PipelineMessageClassifier Default { get { throw null; } } public static System.ClientModel.Primitives.PipelineMessageClassifier Create(System.ReadOnlySpan successStatusCodes) { throw null; } - public virtual bool IsErrorResponse(System.ClientModel.Primitives.PipelineMessage message) { throw null; } + public abstract bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, out bool isError); + public abstract bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, System.Exception? exception, out bool isRetriable); } public abstract partial class PipelinePolicy { diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs index cf21ffff8255e..1d5b1cb1f78b6 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs @@ -146,7 +146,7 @@ public partial class PipelineMessage : System.IDisposable protected internal PipelineMessage(System.ClientModel.Primitives.PipelineRequest request) { } public bool BufferResponse { get { throw null; } set { } } public System.Threading.CancellationToken CancellationToken { get { throw null; } protected internal set { } } - public System.ClientModel.Primitives.PipelineMessageClassifier? MessageClassifier { get { throw null; } set { } } + public System.ClientModel.Primitives.PipelineMessageClassifier MessageClassifier { get { throw null; } set { } } public System.TimeSpan? NetworkTimeout { get { throw null; } set { } } public System.ClientModel.Primitives.PipelineRequest Request { get { throw null; } } public System.ClientModel.Primitives.PipelineResponse? Response { get { throw null; } protected internal set { } } @@ -156,11 +156,13 @@ protected virtual void Dispose(bool disposing) { } public void SetProperty(System.Type type, object value) { } public bool TryGetProperty(System.Type type, out object? value) { throw null; } } - public partial class PipelineMessageClassifier + public abstract partial class PipelineMessageClassifier { - protected internal PipelineMessageClassifier() { } + protected PipelineMessageClassifier() { } + public static System.ClientModel.Primitives.PipelineMessageClassifier Default { get { throw null; } } public static System.ClientModel.Primitives.PipelineMessageClassifier Create(System.ReadOnlySpan successStatusCodes) { throw null; } - public virtual bool IsErrorResponse(System.ClientModel.Primitives.PipelineMessage message) { throw null; } + public abstract bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, out bool isError); + public abstract bool TryClassify(System.ClientModel.Primitives.PipelineMessage message, System.Exception? exception, out bool isRetriable); } public abstract partial class PipelinePolicy { diff --git a/sdk/core/System.ClientModel/src/Message/PipelineMessage.cs b/sdk/core/System.ClientModel/src/Message/PipelineMessage.cs index 736436adbee9a..d2be344d7ca52 100644 --- a/sdk/core/System.ClientModel/src/Message/PipelineMessage.cs +++ b/sdk/core/System.ClientModel/src/Message/PipelineMessage.cs @@ -19,6 +19,7 @@ protected internal PipelineMessage(PipelineRequest request) _propertyBag = new ArrayBackedPropertyBag(); BufferResponse = true; + MessageClassifier = PipelineMessageClassifier.Default; } public PipelineRequest Request { get; } @@ -64,7 +65,7 @@ public CancellationToken CancellationToken // the client-provided classifier or compose a chain of classification // handlers that preserve the functionality of the client-provided classifier // at the end of the chain. - public PipelineMessageClassifier? MessageClassifier { get; set; } + public PipelineMessageClassifier MessageClassifier { get; set; } public void Apply(RequestOptions options) { diff --git a/sdk/core/System.ClientModel/src/Options/PipelineMessageClassifier.cs b/sdk/core/System.ClientModel/src/Options/PipelineMessageClassifier.cs index 6bbe849041063..4fd20bcb75700 100644 --- a/sdk/core/System.ClientModel/src/Options/PipelineMessageClassifier.cs +++ b/sdk/core/System.ClientModel/src/Options/PipelineMessageClassifier.cs @@ -1,25 +1,83 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System.IO; + namespace System.ClientModel.Primitives; -public class PipelineMessageClassifier +public abstract class PipelineMessageClassifier { - internal static PipelineMessageClassifier Default { get; } = new PipelineMessageClassifier(); + public static PipelineMessageClassifier Default { get; } = new EndOfChainClassifier(); public static PipelineMessageClassifier Create(ReadOnlySpan successStatusCodes) => new ResponseStatusClassifier(successStatusCodes); - protected internal PipelineMessageClassifier() { } + protected PipelineMessageClassifier() { } + + public abstract bool TryClassify(PipelineMessage message, out bool isError); - /// - /// Specifies if the response contained in the is not successful. - /// - public virtual bool IsErrorResponse(PipelineMessage message) + public abstract bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable); + + internal class EndOfChainClassifier : PipelineMessageClassifier { - message.AssertResponse(); + public override bool TryClassify(PipelineMessage message, out bool isError) + { + message.AssertResponse(); + + int statusKind = message.Response!.Status / 100; + isError = statusKind == 4 || statusKind == 5; + + // Always classify the message + return true; + } + + public override bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable) + { + isRetriable = exception is null ? + IsRetriable(message) : + IsRetriable(message, exception); + + // Always classify the message + return true; + } + + private static bool IsRetriable(PipelineMessage message) + { + message.AssertResponse(); + + return message.Response!.Status switch + { + // Request Timeout + 408 => true, + + // Too Many Requests + 429 => true, + + // Internal Server Error + 500 => true, + + // Bad Gateway + 502 => true, + + // Service Unavailable + 503 => true, + + // Gateway Timeout + 504 => true, + + // Default case + _ => false + }; + } + + private static bool IsRetriable(PipelineMessage message, Exception exception) + => IsRetriable(exception) || + // Retry non-user initiated cancellations + (exception is OperationCanceledException && + !message.CancellationToken.IsCancellationRequested); - int statusKind = message.Response!.Status / 100; - return statusKind == 4 || statusKind == 5; + private static bool IsRetriable(Exception exception) + => (exception is IOException) || + (exception is ClientResultException ex && ex.Status == 0); } } diff --git a/sdk/core/System.ClientModel/src/Options/RequestOptions.cs b/sdk/core/System.ClientModel/src/Options/RequestOptions.cs index 01ded73811162..71e7f8f56c2e1 100644 --- a/sdk/core/System.ClientModel/src/Options/RequestOptions.cs +++ b/sdk/core/System.ClientModel/src/Options/RequestOptions.cs @@ -106,12 +106,6 @@ internal void Apply(PipelineMessage message) // cancellation token will be set again in HttpPipeline.Send. message.CancellationToken = CancellationToken; - // We don't overwrite the classifier on the message if it's already set. - // This preserves any values set by the client author, and is also - // needed for Azure.Core-based clients so we don't overwrite a default - // Azure.Core ResponseClassifier. - message.MessageClassifier ??= PipelineMessageClassifier.Default; - // Copy custom pipeline policies to the message. message.PerCallPolicies = _perCallPolicies; message.PerTryPolicies = _perTryPolicies; diff --git a/sdk/core/System.ClientModel/src/Options/ResponseStatusClassifier.cs b/sdk/core/System.ClientModel/src/Options/ResponseStatusClassifier.cs index 44138c8d89fb4..6ea5a7e3205fe 100644 --- a/sdk/core/System.ClientModel/src/Options/ResponseStatusClassifier.cs +++ b/sdk/core/System.ClientModel/src/Options/ResponseStatusClassifier.cs @@ -2,40 +2,53 @@ // Licensed under the MIT License. using System.ClientModel.Internal; +using System.Diagnostics; -namespace System.ClientModel.Primitives +namespace System.ClientModel.Primitives; + +internal class ResponseStatusClassifier : PipelineMessageClassifier { - internal class ResponseStatusClassifier : PipelineMessageClassifier + private BitVector640 _successCodes; + + /// + /// Creates a new instance of . + /// + /// The status codes that this classifier + /// will consider not to be errors. + public ResponseStatusClassifier(ReadOnlySpan successStatusCodes) { - private BitVector640 _successCodes; - - /// - /// Creates a new instance of . - /// - /// The status codes that this classifier - /// will consider not to be errors. - public ResponseStatusClassifier(ReadOnlySpan successStatusCodes) - { - _successCodes = new(); + _successCodes = new(); - foreach (int statusCode in successStatusCodes) - { - AddClassifier(statusCode, isError: false); - } + foreach (int statusCode in successStatusCodes) + { + AddClassifier(statusCode, isError: false); } + } - public sealed override bool IsErrorResponse(PipelineMessage message) - { - message.AssertResponse(); + public override bool TryClassify(PipelineMessage message, out bool isError) + { + message.AssertResponse(); - return !_successCodes[message.Response!.Status]; - } + isError = !_successCodes[message.Response!.Status]; - private void AddClassifier(int statusCode, bool isError) - { - Argument.AssertInRange(statusCode, 0, 639, nameof(statusCode)); + // BitVector-based classifiers should always end any chain. + return true; + } - _successCodes[statusCode] = !isError; - } + public override bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable) + { + bool classified = Default.TryClassify(message, exception, out isRetriable); + + Debug.Assert(classified); + + // BitVector-based classifiers should always end any chain. + return true; + } + + private void AddClassifier(int statusCode, bool isError) + { + Argument.AssertInRange(statusCode, 0, 639, nameof(statusCode)); + + _successCodes[statusCode] = !isError; } -} \ No newline at end of file +} diff --git a/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs b/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs index a6690c30b3bce..2a8cd155c9bd1 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs @@ -3,7 +3,7 @@ using System.ClientModel.Internal; using System.Collections.Generic; -using System.IO; +using System.Diagnostics; using System.Runtime.ExceptionServices; using System.Threading; using System.Threading.Tasks; @@ -165,9 +165,14 @@ protected virtual bool ShouldRetryCore(PipelineMessage message, Exception? excep return false; } - return exception is null ? - IsRetriable(message) : - IsRetriable(message, exception); + if (!message.MessageClassifier.TryClassify(message, exception, out bool isRetriable)) + { + bool classified = PipelineMessageClassifier.Default.TryClassify(message, exception, out isRetriable); + + Debug.Assert(classified); + } + + return isRetriable; } protected virtual ValueTask ShouldRetryCoreAsync(PipelineMessage message, Exception? exception) @@ -200,48 +205,4 @@ protected virtual void WaitCore(TimeSpan time, CancellationToken cancellationTok CancellationHelper.ThrowIfCancellationRequested(cancellationToken); } } - - #region Retry Classifier - - // Overriding response-retriable classification will be added in a later ClientModel release. - private static bool IsRetriable(PipelineMessage message) - { - message.AssertResponse(); - - return message.Response!.Status switch - { - // Request Timeout - 408 => true, - - // Too Many Requests - 429 => true, - - // Internal Server Error - 500 => true, - - // Bad Gateway - 502 => true, - - // Service Unavailable - 503 => true, - - // Gateway Timeout - 504 => true, - - // Default case - _ => false - }; - } - - private static bool IsRetriable(PipelineMessage message, Exception exception) - => IsRetriable(exception) || - // Retry non-user initiated cancellations - (exception is OperationCanceledException && - !message.CancellationToken.IsCancellationRequested); - - private static bool IsRetriable(Exception exception) - => (exception is IOException) || - (exception is ClientResultException ex && ex.Status == 0); - - #endregion } diff --git a/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs b/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs index 43e974cb885e8..a590c400d2893 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs @@ -41,9 +41,17 @@ public async ValueTask ProcessAsync(PipelineMessage message) message.Response.SetIsError(ClassifyResponse(message)); } - private static bool ClassifyResponse(PipelineMessage message) => - message.MessageClassifier?.IsErrorResponse(message) ?? - PipelineMessageClassifier.Default.IsErrorResponse(message); + private static bool ClassifyResponse(PipelineMessage message) + { + if (!message.MessageClassifier.TryClassify(message, out bool isError)) + { + bool classified = PipelineMessageClassifier.Default.TryClassify(message, out isError); + + Debug.Assert(classified); + } + + return isError; + } protected abstract void ProcessCore(PipelineMessage message); diff --git a/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs b/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs index a9b2d6e217e91..0f53404e2ad89 100644 --- a/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs +++ b/sdk/core/System.ClientModel/tests/Options/PipelineMessageClassifierTests.cs @@ -1,9 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Diagnostics; using ClientModel.Tests.Mocks; using NUnit.Framework; -using System.ClientModel.Primitives; namespace System.ClientModel.Tests.Options; @@ -23,7 +25,8 @@ public void ClassifiesSingleCodeAsNonError() MockPipelineMessage message = new MockPipelineMessage(); message.SetResponse(new MockPipelineResponse(code)); - bool isNonError = !classifier.IsErrorResponse(message); + classifier.TryClassify(message, out bool isError); + bool isNonError = !isError; if (nonError == code) { @@ -50,6 +53,253 @@ public void ClassifiesMultipleCodesAsNonErrors(int code, bool isError) MockPipelineMessage message = new MockPipelineMessage(); message.SetResponse(new MockPipelineResponse(code)); - Assert.AreEqual(isError, classifier.IsErrorResponse(message)); + Assert.IsTrue(classifier.TryClassify(message, out bool error)); + Assert.AreEqual(isError, error); + } + + [Test] + public void CanComposeErrorClassifiers() + { + var last = PipelineMessageClassifier.Create(stackalloc ushort[] { 200, 201, 204 }); + + ChainingClassifier classifier = new ChainingClassifier(last); + classifier.AddClassifier(new ErrorStatusCodeClassifier(403, isError: false)); + classifier.AddClassifier(new ErrorStatusCodeClassifier(404, isError: false)); + classifier.AddClassifier(new ErrorStatusCodeClassifier(201, isError: true)); + + MockPipelineMessage message = new(); + + message.SetResponse(new MockPipelineResponse(200)); + Assert.IsTrue(classifier.TryClassify(message, out bool isError)); + Assert.IsFalse(isError); + + message.SetResponse(new MockPipelineResponse(201)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(isError); + + message.SetResponse(new MockPipelineResponse(204)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsFalse(isError); + + message.SetResponse(new MockPipelineResponse(304)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(isError); + + message.SetResponse(new MockPipelineResponse(403)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsFalse(isError); + + message.SetResponse(new MockPipelineResponse(404)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsFalse(isError); + + message.SetResponse(new MockPipelineResponse(500)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(isError); } + + [Test] + public void CanComposeRetryClassifiers() + { + var last = PipelineMessageClassifier.Create(stackalloc ushort[] { 200, 201, 204 }); + + ChainingClassifier classifier = new ChainingClassifier(last); + classifier.AddClassifier(new RetriableStatusCodeClassifier(403, isRetriable: false)); + classifier.AddClassifier(new RetriableStatusCodeClassifier(404, isRetriable: false)); + classifier.AddClassifier(new RetriableStatusCodeClassifier(201, isRetriable: true)); + + MockPipelineMessage message = new(); + + message.SetResponse(new MockPipelineResponse(200)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out bool isRetriable)); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(201)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsTrue(isRetriable); + + message.SetResponse(new MockPipelineResponse(204)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(304)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(403)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(404)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(500)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsTrue(isRetriable); + } + + [Test] + public void CanComposeErrorAndRetryClassifiers() + { + var last = PipelineMessageClassifier.Create(stackalloc ushort[] { 200, 201, 204 }); + + ChainingClassifier classifier = new ChainingClassifier(last); + classifier.AddClassifier(new RetriableStatusCodeClassifier(429, isRetriable: false)); + classifier.AddClassifier(new ErrorStatusCodeClassifier(404, isError: false)); + classifier.AddClassifier(new ErrorStatusCodeClassifier(201, isError: true)); + + MockPipelineMessage message = new(); + + message.SetResponse(new MockPipelineResponse(200)); + Assert.IsTrue(classifier.TryClassify(message, out bool isError)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out bool isRetriable)); + Assert.IsFalse(isError); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(201)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsTrue(isError); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(204)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsFalse(isError); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(304)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsTrue(isError); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(404)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsFalse(isError); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(429)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsTrue(isError); + Assert.IsFalse(isRetriable); + + message.SetResponse(new MockPipelineResponse(500)); + Assert.IsTrue(classifier.TryClassify(message, out isError)); + Assert.IsTrue(classifier.TryClassify(message, exception: default, out isRetriable)); + Assert.IsTrue(isError); + Assert.IsTrue(isRetriable); + } + + #region Helpers + + internal class ErrorStatusCodeClassifier : PipelineMessageClassifier + { + private readonly (int, bool) _code; + + public ErrorStatusCodeClassifier(int code, bool isError) + { + _code = (code, isError); + } + + public override bool TryClassify(PipelineMessage message, out bool isError) + { + if (message.Response!.Status == _code.Item1) + { + isError = _code.Item2; + return true; + } + + isError = false; + return false; + } + + public override bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable) + { + isRetriable = false; + return false; + } + } + + internal class RetriableStatusCodeClassifier : PipelineMessageClassifier + { + private readonly (int, bool) _code; + + public RetriableStatusCodeClassifier(int code, bool isRetriable) + { + _code = (code, isRetriable); + } + + public override bool TryClassify(PipelineMessage message, out bool isError) + { + isError = false; + return false; + } + + public override bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable) + { + if (message.Response!.Status == _code.Item1) + { + isRetriable = _code.Item2; + return true; + } + + isRetriable = false; + return false; + } + } + + internal class ChainingClassifier : PipelineMessageClassifier + { + private readonly List _classifiers; + private readonly PipelineMessageClassifier _endOfChain; + + public ChainingClassifier(PipelineMessageClassifier endOfChain) + { + _classifiers = new(); + _endOfChain = endOfChain; + } + + public void AddClassifier(PipelineMessageClassifier classifier) + { + _classifiers.Add(classifier); + } + + public override bool TryClassify(PipelineMessage message, out bool isError) + { + foreach (var classifier in _classifiers) + { + if (classifier.TryClassify(message, out isError)) + { + return true; + } + } + + return _endOfChain.TryClassify(message, out isError); + } + + public override bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable) + { + foreach (var classifier in _classifiers) + { + if (classifier.TryClassify(message, exception, out isRetriable)) + { + return true; + } + } + + if (!_endOfChain.TryClassify(message, exception, out isRetriable)) + { + bool classified = Default.TryClassify(message, exception, out isRetriable); + Debug.Assert(classified); + } + + return true; + } + } + + #endregion } diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockMessageClassifier.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockMessageClassifier.cs index de2ac4cabcf99..36f929a6aa2cc 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockMessageClassifier.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockMessageClassifier.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; using System.ClientModel.Primitives; namespace ClientModel.Tests.Mocks; @@ -29,7 +30,7 @@ public MockMessageClassifier(string id, int[]? successCodes) public string Id { get; set; } - public override bool IsErrorResponse(PipelineMessage message) + public override bool TryClassify(PipelineMessage message, out bool isError) { if (_successCodes is not null) { @@ -37,13 +38,20 @@ public override bool IsErrorResponse(PipelineMessage message) { if (message.Response!.Status == code) { + isError = true; return true; } } - return false; + isError = false; + return true; } - return base.IsErrorResponse(message); + return Default.TryClassify(message, out isError); + } + + public override bool TryClassify(PipelineMessage message, Exception? exception, out bool isRetriable) + { + return Default.TryClassify(message, exception, out isRetriable); } }