diff --git a/src/main/java/com/google/devtools/build/lib/remote/Retrier.java b/src/main/java/com/google/devtools/build/lib/remote/Retrier.java index 457880268764d5..6647b155465cb7 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/Retrier.java +++ b/src/main/java/com/google/devtools/build/lib/remote/Retrier.java @@ -23,6 +23,7 @@ import com.google.common.util.concurrent.MoreExecutors; import com.google.devtools.build.lib.remote.Retrier.CircuitBreaker.State; import java.io.IOException; +import java.util.Objects; import java.util.concurrent.Callable; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; @@ -100,7 +101,7 @@ enum State { State state(); /** Called after an execution failed. */ - void recordFailure(Exception e); + void recordFailure(); /** Called after an execution succeeded. */ void recordSuccess(); @@ -130,7 +131,7 @@ public State state() { } @Override - public void recordFailure(Exception e) {} + public void recordFailure() {} @Override public void recordSuccess() {} @@ -245,12 +246,14 @@ public T execute(Callable call, Backoff backoff) throws Exception { circuitBreaker.recordSuccess(); return r; } catch (Exception e) { - circuitBreaker.recordFailure(e); Throwables.throwIfInstanceOf(e, InterruptedException.class); - if (State.TRIAL_CALL.equals(circuitState)) { + if (!shouldRetry.test(e)) { + // A non-retriable error doesn't represent server failure. + circuitBreaker.recordSuccess(); throw e; } - if (!shouldRetry.test(e)) { + circuitBreaker.recordFailure(); + if (Objects.equals(circuitState, State.TRIAL_CALL)) { throw e; } final long delayMillis = backoff.nextDelayMillis(e); @@ -297,11 +300,11 @@ public ListenableFuture executeAsync(AsyncCallable call, Backoff backo private ListenableFuture onExecuteAsyncFailure( Exception t, AsyncCallable call, Backoff backoff, State circuitState) { - circuitBreaker.recordFailure(t); - if (circuitState.equals(State.TRIAL_CALL)) { - return Futures.immediateFailedFuture(t); - } if (isRetriable(t)) { + circuitBreaker.recordFailure(); + if (circuitState.equals(State.TRIAL_CALL)) { + return Futures.immediateFailedFuture(t); + } long waitMillis = backoff.nextDelayMillis(t); if (waitMillis >= 0) { try { @@ -315,6 +318,10 @@ private ListenableFuture onExecuteAsyncFailure( return Futures.immediateFailedFuture(t); } } else { + // gRPC Errors NOT_FOUND, OUT_OF_RANGE, ALREADY_EXISTS etc. are non-retriable error, and they + // don't represent an + // issue in Server. So treating these errors as successful api call. + circuitBreaker.recordSuccess(); return Futures.immediateFailedFuture(t); } } diff --git a/src/main/java/com/google/devtools/build/lib/remote/circuitbreaker/CircuitBreakerFactory.java b/src/main/java/com/google/devtools/build/lib/remote/circuitbreaker/CircuitBreakerFactory.java index 6ab6b4258d2cbd..7781440880c885 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/circuitbreaker/CircuitBreakerFactory.java +++ b/src/main/java/com/google/devtools/build/lib/remote/circuitbreaker/CircuitBreakerFactory.java @@ -13,16 +13,11 @@ // limitations under the License. package com.google.devtools.build.lib.remote.circuitbreaker; -import com.google.common.collect.ImmutableSet; import com.google.devtools.build.lib.remote.Retrier; -import com.google.devtools.build.lib.remote.common.CacheNotFoundException; import com.google.devtools.build.lib.remote.options.RemoteOptions; /** Factory for {@link Retrier.CircuitBreaker} */ public class CircuitBreakerFactory { - - public static final ImmutableSet> DEFAULT_IGNORED_ERRORS = - ImmutableSet.of(CacheNotFoundException.class); public static final int DEFAULT_MIN_CALL_COUNT_TO_COMPUTE_FAILURE_RATE = 100; private CircuitBreakerFactory() {} diff --git a/src/main/java/com/google/devtools/build/lib/remote/circuitbreaker/FailureCircuitBreaker.java b/src/main/java/com/google/devtools/build/lib/remote/circuitbreaker/FailureCircuitBreaker.java index 2baeba4ed07f3d..66ebe2270f2992 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/circuitbreaker/FailureCircuitBreaker.java +++ b/src/main/java/com/google/devtools/build/lib/remote/circuitbreaker/FailureCircuitBreaker.java @@ -13,7 +13,6 @@ // limitations under the License. package com.google.devtools.build.lib.remote.circuitbreaker; -import com.google.common.collect.ImmutableSet; import com.google.devtools.build.lib.remote.Retrier; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; @@ -32,12 +31,10 @@ public class FailureCircuitBreaker implements Retrier.CircuitBreaker { private State state; private final AtomicInteger successes; private final AtomicInteger failures; - private final AtomicInteger ignoredFailures; private final int failureRateThreshold; private final int slidingWindowSize; private final int minCallCountToComputeFailureRate; private final ScheduledExecutorService scheduledExecutor; - private final ImmutableSet> ignoredErrors; /** * Creates a {@link FailureCircuitBreaker}. @@ -49,14 +46,12 @@ public class FailureCircuitBreaker implements Retrier.CircuitBreaker { public FailureCircuitBreaker(int failureRateThreshold, int slidingWindowSize) { this.failures = new AtomicInteger(0); this.successes = new AtomicInteger(0); - this.ignoredFailures = new AtomicInteger(0); this.failureRateThreshold = failureRateThreshold; this.slidingWindowSize = slidingWindowSize; this.minCallCountToComputeFailureRate = CircuitBreakerFactory.DEFAULT_MIN_CALL_COUNT_TO_COMPUTE_FAILURE_RATE; this.state = State.ACCEPT_CALLS; this.scheduledExecutor = slidingWindowSize > 0 ? Executors.newSingleThreadScheduledExecutor() : null; - this.ignoredErrors = CircuitBreakerFactory.DEFAULT_IGNORED_ERRORS; } @Override @@ -65,31 +60,24 @@ public State state() { } @Override - public void recordFailure(Exception e) { - if (!ignoredErrors.contains(e.getClass())) { - int failureCount = failures.incrementAndGet(); - int totalCallCount = successes.get() + failureCount + ignoredFailures.get(); - if (slidingWindowSize > 0) { - var unused = - scheduledExecutor.schedule( - failures::decrementAndGet, slidingWindowSize, TimeUnit.MILLISECONDS); - } + public void recordFailure() { + int failureCount = failures.incrementAndGet(); + int totalCallCount = successes.get() + failureCount; + if (slidingWindowSize > 0) { + var unused = + scheduledExecutor.schedule( + failures::decrementAndGet, slidingWindowSize, TimeUnit.MILLISECONDS); + } - if (totalCallCount < minCallCountToComputeFailureRate) { - // The remote call count is below the threshold required to calculate the failure rate. - return; - } - double failureRate = (failureCount * 100.0) / totalCallCount; + if (totalCallCount < minCallCountToComputeFailureRate) { + // The remote call count is below the threshold required to calculate the failure rate. + return; + } + double failureRate = (failureCount * 100.0) / totalCallCount; - // Since the state can only be changed to the open state, synchronization is not required. - if (failureRate > this.failureRateThreshold) { - this.state = State.REJECT_CALLS; - } - } else { - ignoredFailures.incrementAndGet(); - if (slidingWindowSize > 0) { - scheduledExecutor.schedule(ignoredFailures::decrementAndGet, slidingWindowSize, TimeUnit.MILLISECONDS); - } + // Since the state can only be changed to the open state, synchronization is not required. + if (failureRate > this.failureRateThreshold) { + this.state = State.REJECT_CALLS; } } diff --git a/src/test/java/com/google/devtools/build/lib/remote/RetrierTest.java b/src/test/java/com/google/devtools/build/lib/remote/RetrierTest.java index 7c30e1bf6eddc3..0889ae9ce52f11 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/RetrierTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/RetrierTest.java @@ -16,7 +16,6 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -31,6 +30,10 @@ import com.google.devtools.build.lib.remote.Retrier.CircuitBreakerException; import com.google.devtools.build.lib.remote.Retrier.ZeroBackoff; import com.google.devtools.build.lib.testutil.TestUtils; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import java.util.Arrays; +import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -94,7 +97,7 @@ public void retryShouldWork_failure() throws Exception { assertThat(e).hasMessageThat().isEqualTo("call failed"); assertThat(numCalls.get()).isEqualTo(3); - verify(alwaysOpen, times(3)).recordFailure(any(Exception.class)); + verify(alwaysOpen, times(3)).recordFailure(); verify(alwaysOpen, never()).recordSuccess(); } @@ -118,8 +121,8 @@ public void retryShouldWorkNoRetries_failure() throws Exception { assertThat(e).hasMessageThat().isEqualTo("call failed"); assertThat(numCalls.get()).isEqualTo(1); - verify(alwaysOpen, times(1)).recordFailure(e); - verify(alwaysOpen, never()).recordSuccess(); + verify(alwaysOpen, never()).recordFailure(); + verify(alwaysOpen, times(1)).recordSuccess(); } @Test @@ -139,7 +142,7 @@ public void retryShouldWork_success() throws Exception { }); assertThat(val).isEqualTo(1); - verify(alwaysOpen, times(2)).recordFailure(any(Exception.class)); + verify(alwaysOpen, times(2)).recordFailure(); verify(alwaysOpen, times(1)).recordSuccess(); } @@ -332,6 +335,46 @@ public void asyncRetryEmptyError() throws Exception { assertThat(e).hasCauseThat().hasMessageThat().isEqualTo(""); } + @Test + public void testCircuitBreakerFailureAndSuccessCallOnDifferentGrpcError() { + int maxRetries = 2; + Supplier s = () -> new ZeroBackoff(maxRetries); + List retriableGrpcError = + Arrays.asList(Status.ABORTED, Status.UNKNOWN, Status.DEADLINE_EXCEEDED); + List nonRetriableGrpcError = + Arrays.asList(Status.NOT_FOUND, Status.OUT_OF_RANGE, Status.ALREADY_EXISTS); + TripAfterNCircuitBreaker cb = + new TripAfterNCircuitBreaker(retriableGrpcError.size() * (maxRetries + 1)); + Retrier r = new Retrier(s, RemoteRetrier.RETRIABLE_GRPC_ERRORS, retryService, cb); + + int expectedConsecutiveFailures = 0; + + for (Status status : retriableGrpcError) { + ListenableFuture res = + r.executeAsync( + () -> { + throw new StatusRuntimeException(status); + }); + expectedConsecutiveFailures += maxRetries + 1; + assertThrows(ExecutionException.class, res::get); + assertThat(cb.consecutiveFailures).isEqualTo(expectedConsecutiveFailures); + } + + assertThat(cb.state).isEqualTo(State.REJECT_CALLS); + cb.trialCall(); + + for (Status status : nonRetriableGrpcError) { + ListenableFuture res = + r.executeAsync( + () -> { + throw new StatusRuntimeException(status); + }); + assertThat(cb.consecutiveFailures).isEqualTo(0); + assertThrows(ExecutionException.class, res::get); + } + assertThat(cb.state).isEqualTo(State.ACCEPT_CALLS); + } + /** Simple circuit breaker that trips after N consecutive failures. */ @ThreadSafe private static class TripAfterNCircuitBreaker implements CircuitBreaker { @@ -351,7 +394,7 @@ public synchronized State state() { } @Override - public synchronized void recordFailure(Exception e) { + public synchronized void recordFailure() { consecutiveFailures++; if (consecutiveFailures >= maxConsecutiveFailures) { state = State.REJECT_CALLS; diff --git a/src/test/java/com/google/devtools/build/lib/remote/circuitbreaker/FailureCircuitBreakerTest.java b/src/test/java/com/google/devtools/build/lib/remote/circuitbreaker/FailureCircuitBreakerTest.java index 502c49ddff6225..73bc48327d9d58 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/circuitbreaker/FailureCircuitBreakerTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/circuitbreaker/FailureCircuitBreakerTest.java @@ -15,10 +15,7 @@ import static com.google.common.truth.Truth.assertThat; -import build.bazel.remote.execution.v2.Digest; import com.google.devtools.build.lib.remote.Retrier.CircuitBreaker.State; -import com.google.devtools.build.lib.remote.common.CacheNotFoundException; - import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -33,36 +30,37 @@ public class FailureCircuitBreakerTest { @Test - public void testRecordFailure_withIgnoredErrors() throws InterruptedException { + public void testRecordFailure_circuitTrips() throws InterruptedException { final int failureRateThreshold = 10; final int windowInterval = 100; FailureCircuitBreaker failureCircuitBreaker = new FailureCircuitBreaker(failureRateThreshold, windowInterval); - List listOfExceptionThrownOnFailure = new ArrayList<>(); + List listOfSuccessAndFailureCalls = new ArrayList<>(); for (int index = 0; index < failureRateThreshold; index++) { - listOfExceptionThrownOnFailure.add(new Exception()); + listOfSuccessAndFailureCalls.add(failureCircuitBreaker::recordFailure); } + for (int index = 0; index < failureRateThreshold * 9; index++) { - listOfExceptionThrownOnFailure.add(new CacheNotFoundException(Digest.newBuilder().build())); + listOfSuccessAndFailureCalls.add(failureCircuitBreaker::recordSuccess); } - Collections.shuffle(listOfExceptionThrownOnFailure); + Collections.shuffle(listOfSuccessAndFailureCalls); // make calls equals to threshold number of not ignored failure calls in parallel. - listOfExceptionThrownOnFailure.stream().parallel().forEach(failureCircuitBreaker::recordFailure); + listOfSuccessAndFailureCalls.stream().parallel().forEach(Runnable::run); assertThat(failureCircuitBreaker.state()).isEqualTo(State.ACCEPT_CALLS); // Sleep for windowInterval + 1ms. Thread.sleep(windowInterval + 1 /*to compensate any delay*/); // make calls equals to threshold number of not ignored failure calls in parallel. - listOfExceptionThrownOnFailure.stream().parallel().forEach(failureCircuitBreaker::recordFailure); + listOfSuccessAndFailureCalls.stream().parallel().forEach(Runnable::run); assertThat(failureCircuitBreaker.state()).isEqualTo(State.ACCEPT_CALLS); // Sleep for less than windowInterval. Thread.sleep(windowInterval - 5); - failureCircuitBreaker.recordFailure(new Exception()); + failureCircuitBreaker.recordFailure(); assertThat(failureCircuitBreaker.state()).isEqualTo(State.REJECT_CALLS); } @@ -74,15 +72,19 @@ public void testRecordFailure_minCallCriteriaNotMet() throws InterruptedExceptio FailureCircuitBreaker failureCircuitBreaker = new FailureCircuitBreaker(failureRateThreshold, windowInterval); - // make half failure call, half success call and number of total call less than minCallToComputeFailure. - IntStream.range(0, minCallToComputeFailure >> 1).parallel() - .forEach(i -> failureCircuitBreaker.recordFailure(new Exception())); - IntStream.range(0, minCallToComputeFailure >> 1).parallel().forEach(i -> failureCircuitBreaker.recordSuccess()); + // make half failure call, half success call and number of total call less than + // minCallToComputeFailure. + IntStream.range(0, minCallToComputeFailure >> 1) + .parallel() + .forEach(i -> failureCircuitBreaker.recordFailure()); + IntStream.range(0, minCallToComputeFailure >> 1) + .parallel() + .forEach(i -> failureCircuitBreaker.recordSuccess()); assertThat(failureCircuitBreaker.state()).isEqualTo(State.ACCEPT_CALLS); // Sleep for less than windowInterval. - Thread.sleep(windowInterval - 20); - failureCircuitBreaker.recordFailure(new Exception()); + Thread.sleep(windowInterval - 50); + failureCircuitBreaker.recordFailure(); assertThat(failureCircuitBreaker.state()).isEqualTo(State.REJECT_CALLS); } }