diff --git a/src/EventStore.Client/EventStoreClientBase.cs b/src/EventStore.Client/EventStoreClientBase.cs index a091d826d..ba01d112a 100644 --- a/src/EventStore.Client/EventStoreClientBase.cs +++ b/src/EventStore.Client/EventStoreClientBase.cs @@ -21,7 +21,7 @@ public abstract class EventStoreClientBase : private readonly IDictionary> _exceptionMap; private readonly CancellationTokenSource _cts; private readonly ChannelCache _channelCache; - private readonly SharingProvider _channelInfoProvider; + private readonly SharingProvider _channelInfoProvider; /// /// The name of the connection. @@ -48,27 +48,31 @@ protected EventStoreClientBase(EventStoreClientSettings? settings, ConnectionName = Settings.ConnectionName ?? $"ES-{Guid.NewGuid()}"; var channelSelector = new ChannelSelector(Settings, _channelCache); - _channelInfoProvider = new SharingProvider( - factory: (endPoint, onBroken) => GetChannelInfoExpensive(endPoint, onBroken, channelSelector, _cts.Token), - initialInput: null); + _channelInfoProvider = new SharingProvider( + factory: (endPoint, onBroken) => + GetChannelInfoExpensive(endPoint, onBroken, channelSelector, _cts.Token), + initialInput: ReconnectionRequired.Rediscover.Instance); } // Select a channel and query its capabilities. This is an expensive call that // we don't want to do often. private async Task GetChannelInfoExpensive( - DnsEndPoint? endPoint, - Action onBroken, + ReconnectionRequired reconnectionRequired, + Action onReconnectionRequired, IChannelSelector channelSelector, CancellationToken cancellationToken) { - var channel = endPoint is null - ? await channelSelector.SelectChannelAsync(cancellationToken).ConfigureAwait(false) - : channelSelector.SelectChannel(endPoint); + var channel = reconnectionRequired switch { + ReconnectionRequired.Rediscover => await channelSelector.SelectChannelAsync(cancellationToken) + .ConfigureAwait(false), + ReconnectionRequired.NewLeader (var endPoint) => channelSelector.SelectChannel(endPoint), + _ => throw new ArgumentException(null, nameof(reconnectionRequired)) + }; var invoker = channel.CreateCallInvoker() .Intercept(new TypedExceptionInterceptor(_exceptionMap)) .Intercept(new ConnectionNameInterceptor(ConnectionName)) - .Intercept(new ReportLeaderInterceptor(onBroken)); + .Intercept(new ReportLeaderInterceptor(onReconnectionRequired)); if (Settings.Interceptors is not null) { foreach (var interceptor in Settings.Interceptors) { @@ -92,6 +96,7 @@ protected async ValueTask GetChannelInfo(CancellationToken cancella // in cases where the server doesn't yet let the client know that it needs to. // see EventStoreClientExtensions.WarmUpWith. // note if rediscovery is already in progress it will continue, not restart. + // ReSharper disable once UnusedMember.Local private void Rediscover() { _channelInfoProvider.Reset(); } diff --git a/src/EventStore.Client/Interceptors/ReportLeaderInterceptor.cs b/src/EventStore.Client/Interceptors/ReportLeaderInterceptor.cs index 58bdfbe66..45841c61e 100644 --- a/src/EventStore.Client/Interceptors/ReportLeaderInterceptor.cs +++ b/src/EventStore.Client/Interceptors/ReportLeaderInterceptor.cs @@ -1,5 +1,4 @@ using System; -using System.Net; using System.Threading; using System.Threading.Tasks; using Grpc.Core; @@ -10,13 +9,13 @@ namespace EventStore.Client.Interceptors { // this has become more general than just detecting leader changes. // triggers the action on any rpc exception with StatusCode.Unavailable internal class ReportLeaderInterceptor : Interceptor { - private readonly Action _onError; + private readonly Action _onReconnectionRequired; private const TaskContinuationOptions ContinuationOptions = TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.OnlyOnFaulted; - internal ReportLeaderInterceptor(Action onError) { - _onError = onError; + internal ReportLeaderInterceptor(Action onReconnectionRequired) { + _onReconnectionRequired = onReconnectionRequired; } public override AsyncUnaryCall AsyncUnaryCall(TRequest request, @@ -24,7 +23,7 @@ public override AsyncUnaryCall AsyncUnaryCall(TR AsyncUnaryCallContinuation continuation) { var response = continuation(request, context); - response.ResponseAsync.ContinueWith(ReportNewLeader, ContinuationOptions); + response.ResponseAsync.ContinueWith(OnReconnectionRequired, ContinuationOptions); return new AsyncUnaryCall(response.ResponseAsync, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); @@ -35,7 +34,7 @@ public override AsyncClientStreamingCall AsyncClientStreami AsyncClientStreamingCallContinuation continuation) { var response = continuation(context); - response.ResponseAsync.ContinueWith(ReportNewLeader, ContinuationOptions); + response.ResponseAsync.ContinueWith(OnReconnectionRequired, ContinuationOptions); return new AsyncClientStreamingCall(response.RequestStream, response.ResponseAsync, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); @@ -47,7 +46,8 @@ public override AsyncDuplexStreamingCall AsyncDuplexStreami var response = continuation(context); return new AsyncDuplexStreamingCall(response.RequestStream, - new StreamReader(response.ResponseStream, ReportNewLeader), response.ResponseHeadersAsync, + new StreamReader(response.ResponseStream, OnReconnectionRequired), + response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); } @@ -57,20 +57,23 @@ public override AsyncServerStreamingCall AsyncServerStreamingCall( - new StreamReader(response.ResponseStream, ReportNewLeader), response.ResponseHeadersAsync, + new StreamReader(response.ResponseStream, OnReconnectionRequired), + response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); } - private void ReportNewLeader(Task task) { - if (task.Exception?.InnerException is NotLeaderException ex) { - _onError(ex.LeaderEndpoint); - } else if (task.Exception?.InnerException is RpcException { - StatusCode: StatusCode.Unavailable or - // StatusCode.Unknown or TODO: use RPC exceptions on server - StatusCode.Aborted - }) { - _onError(null); - } + private void OnReconnectionRequired(Task task) { + ReconnectionRequired reconnectionRequired = task.Exception?.InnerException switch { + NotLeaderException ex => new ReconnectionRequired.NewLeader(ex.LeaderEndpoint), + RpcException { + StatusCode: StatusCode.Unavailable + // or StatusCode.Unknown or TODO: use RPC exceptions on server + } => ReconnectionRequired.Rediscover.Instance, + _ => ReconnectionRequired.None.Instance + }; + + if (reconnectionRequired is not ReconnectionRequired.None) + _onReconnectionRequired(reconnectionRequired); } private class StreamReader : IAsyncStreamReader { diff --git a/src/EventStore.Client/ReconnectionRequired.cs b/src/EventStore.Client/ReconnectionRequired.cs new file mode 100644 index 000000000..bf448971d --- /dev/null +++ b/src/EventStore.Client/ReconnectionRequired.cs @@ -0,0 +1,15 @@ +using System.Net; + +namespace EventStore.Client { + internal abstract record ReconnectionRequired { + public record None : ReconnectionRequired { + public static None Instance = new(); + } + + public record Rediscover : ReconnectionRequired { + public static Rediscover Instance = new(); + } + + public record NewLeader(DnsEndPoint EndPoint) : ReconnectionRequired; + } +} diff --git a/test/EventStore.Client.Tests/Interceptors/ReportLeaderInterceptorTests.cs b/test/EventStore.Client.Tests/Interceptors/ReportLeaderInterceptorTests.cs index f36599e2b..c69237a3d 100644 --- a/test/EventStore.Client.Tests/Interceptors/ReportLeaderInterceptorTests.cs +++ b/test/EventStore.Client.Tests/Interceptors/ReportLeaderInterceptorTests.cs @@ -15,7 +15,6 @@ public class ReportLeaderInterceptorTests { private static readonly Marshaller _marshaller = new(_ => Array.Empty(), _ => new object()); private static readonly StatusCode[] ForcesRediscoveryStatusCodes = { - StatusCode.Aborted, //StatusCode.Unknown, TODO: use RPC exceptions on server StatusCode.Unavailable }; @@ -32,12 +31,12 @@ private static IEnumerable GrpcCalls() { [Theory, MemberData(nameof(ReportsNewLeaderCases))] public async Task ReportsNewLeader(GrpcCall call) { - EndPoint actual = default; - var sut = new ReportLeaderInterceptor(ep => actual = ep); + ReconnectionRequired actual = default; + var sut = new ReportLeaderInterceptor(result => actual = result); var result = await Assert.ThrowsAsync(() => call(sut, Task.FromException(new NotLeaderException("a.host", 2112)))); - Assert.Equal(result.LeaderEndpoint, actual); + Assert.Equal(new ReconnectionRequired.NewLeader(result.LeaderEndpoint), actual); } public static IEnumerable ForcesRediscoveryCases() => from call in GrpcCalls() @@ -46,18 +45,12 @@ from statusCode in ForcesRediscoveryStatusCodes [Theory, MemberData(nameof(ForcesRediscoveryCases))] public async Task ForcesRediscovery(GrpcCall call, StatusCode statusCode) { - EndPoint actual = default; - bool invoked = false; + ReconnectionRequired actual = default; + var sut = new ReportLeaderInterceptor(result => actual = result); - var sut = new ReportLeaderInterceptor(ep => { - invoked = true; - actual = ep; - }); - - var result = await Assert.ThrowsAsync(() => call(sut, + await Assert.ThrowsAsync(() => call(sut, Task.FromException(new RpcException(new Status(statusCode, "oops"))))); - Assert.Null(actual); - Assert.True(invoked); + Assert.Equal(ReconnectionRequired.Rediscover.Instance, actual); } public static IEnumerable DoesNotForceRediscoveryCases() => from call in GrpcCalls() @@ -68,12 +61,12 @@ from statusCode in Enum.GetValues(typeof(StatusCode)) [Theory, MemberData(nameof(DoesNotForceRediscoveryCases))] public async Task DoesNotForceRediscovery(GrpcCall call, StatusCode statusCode) { - bool invoked = false; - var sut = new ReportLeaderInterceptor(ep => invoked = true); + ReconnectionRequired actual = ReconnectionRequired.None.Instance; + var sut = new ReportLeaderInterceptor(result => actual = result); - var result = await Assert.ThrowsAsync(() => call(sut, + await Assert.ThrowsAsync(() => call(sut, Task.FromException(new RpcException(new Status(statusCode, "oops"))))); - Assert.False(invoked); + Assert.Equal(ReconnectionRequired.None.Instance, actual); }