From 0d530a626517f11694d9fb966d9189eb1ee5f16c Mon Sep 17 00:00:00 2001 From: Katie Byers Date: Tue, 23 Jul 2024 11:07:22 -0700 Subject: [PATCH] feat(utils): Add helpers for circuit breaker and circuit breaker tests (#74559) This is a follow-up to https://github.com/getsentry/sentry/pull/74557, which added the beginnings of a rate-limit-based circuit breaker, in the form of a new `CircuitBreaker` class. In this PR, various helpers, for checking the state of the breaker and the underlying rate limiters and for communicating with redis, have been added to the class. It also adds a `MockCircuitBreaker` subclass for use in tests, which contains a number of methods for mocking both circuit breaker and rate limiter state. Note that though these helpers don't have accompanying tests, they are tested (indirectly) in the final PR in the series[1], as part of testing the methods which use them. [1] https://github.com/getsentry/sentry/pull/74560 --- src/sentry/utils/circuit_breaker2.py | 126 +++++++++++- tests/sentry/utils/test_circuit_breaker2.py | 209 +++++++++++++++++++- 2 files changed, 328 insertions(+), 7 deletions(-) diff --git a/src/sentry/utils/circuit_breaker2.py b/src/sentry/utils/circuit_breaker2.py index 002c2f121403a0..108349f3189a9f 100644 --- a/src/sentry/utils/circuit_breaker2.py +++ b/src/sentry/utils/circuit_breaker2.py @@ -6,12 +6,13 @@ """ import logging +import time from enum import Enum -from typing import NotRequired, TypedDict +from typing import Any, Literal, NotRequired, TypedDict, overload from django.conf import settings -from sentry.ratelimits.sliding_windows import Quota, RedisSlidingWindowRateLimiter +from sentry.ratelimits.sliding_windows import Quota, RedisSlidingWindowRateLimiter, RequestedQuota logger = logging.getLogger(__name__) @@ -180,3 +181,124 @@ def __init__(self, key: str, config: CircuitBreakerConfig): default_recovery_duration, ) self.recovery_duration = default_recovery_duration + + def _get_from_redis(self, keys: list[str]) -> Any: + for key in keys: + self.redis_pipeline.get(key) + return self.redis_pipeline.execute() + + def _set_in_redis(self, keys_values_and_timeouts: list[tuple[str, Any, int]]) -> None: + for key, value, timeout in keys_values_and_timeouts: + self.redis_pipeline.set(key, value, timeout) + self.redis_pipeline.execute() + + def _get_state_and_remaining_time( + self, + ) -> tuple[CircuitBreakerState, int | None]: + """ + Return the current state of the breaker (OK, BROKEN, or in RECOVERY), along with the + number of seconds until that state expires (or `None` when in OK state, as it has no + expiry). + """ + now = int(time.time()) + + try: + broken_state_expiry, recovery_state_expiry = self._get_from_redis( + [self.broken_state_key, self.recovery_state_key] + ) + except Exception: + logger.exception("Couldn't get state from redis for circuit breaker '%s'", self.key) + + # Default to letting traffic through so the breaker doesn't become a single point of failure + return (CircuitBreakerState.OK, None) + + # The BROKEN state key should always expire before the RECOVERY state one, so check it first + if broken_state_expiry is not None: + broken_state_seconds_left = int(broken_state_expiry) - now + + # In theory there should always be time left (the key should have expired otherwise), + # but race conditions/caching/etc means we should check, just to be sure + if broken_state_seconds_left > 0: + return (CircuitBreakerState.BROKEN, broken_state_seconds_left) + + if recovery_state_expiry is not None: + recovery_state_seconds_left = int(recovery_state_expiry) - now + if recovery_state_seconds_left > 0: + return (CircuitBreakerState.RECOVERY, recovery_state_seconds_left) + + return (CircuitBreakerState.OK, None) + + @overload + def _get_controlling_quota( + self, state: Literal[CircuitBreakerState.OK, CircuitBreakerState.RECOVERY] + ) -> Quota: + ... + + @overload + def _get_controlling_quota(self, state: Literal[CircuitBreakerState.BROKEN]) -> None: + ... + + @overload + def _get_controlling_quota(self) -> Quota | None: + ... + + def _get_controlling_quota(self, state: CircuitBreakerState | None = None) -> Quota | None: + """ + Return the Quota corresponding to the given breaker state (or the current breaker state, if + no state is provided). If the state is question is the BROKEN state, return None. + """ + controlling_quota_by_state = { + CircuitBreakerState.OK: self.primary_quota, + CircuitBreakerState.BROKEN: None, + CircuitBreakerState.RECOVERY: self.recovery_quota, + } + + _state = state or self._get_state_and_remaining_time()[0] + + return controlling_quota_by_state[_state] + + @overload + def _get_remaining_error_quota(self, quota: None, window_end: int | None) -> None: + ... + + @overload + def _get_remaining_error_quota(self, quota: Quota, window_end: int | None) -> int: + ... + + @overload + def _get_remaining_error_quota(self, quota: None) -> None: + ... + + @overload + def _get_remaining_error_quota(self, quota: Quota) -> int: + ... + + @overload + def _get_remaining_error_quota(self) -> int | None: + ... + + def _get_remaining_error_quota( + self, quota: Quota | None = None, window_end: int | None = None + ) -> int | None: + """ + Get the number of allowable errors remaining in the given quota for the time window ending + at the given time. + + If no quota is given, in OK and RECOVERY states, return the current controlling quota's + remaining errors. In BROKEN state, return None. + + If no time window end is given, return the current amount of quota remaining. + """ + if not quota: + quota = self._get_controlling_quota() + if quota is None: # BROKEN state + return None + + now = int(time.time()) + window_end = window_end or now + + _, result = self.limiter.check_within_quotas( + [RequestedQuota(self.key, quota.limit, [quota])], window_end + ) + + return result[0].granted diff --git a/tests/sentry/utils/test_circuit_breaker2.py b/tests/sentry/utils/test_circuit_breaker2.py index 09d41eb80db929..5cb4370b09e304 100644 --- a/tests/sentry/utils/test_circuit_breaker2.py +++ b/tests/sentry/utils/test_circuit_breaker2.py @@ -1,12 +1,19 @@ +import time +from typing import Any from unittest import TestCase from unittest.mock import ANY, MagicMock, patch from django.conf import settings from redis.client import Pipeline -from sentry.ratelimits.sliding_windows import Quota, RedisSlidingWindowRateLimiter +from sentry.ratelimits.sliding_windows import ( + GrantedQuota, + Quota, + RedisSlidingWindowRateLimiter, + RequestedQuota, +) from sentry.testutils.helpers.datetime import freeze_time -from sentry.utils.circuit_breaker2 import CircuitBreaker, CircuitBreakerConfig +from sentry.utils.circuit_breaker2 import CircuitBreaker, CircuitBreakerConfig, CircuitBreakerState # Note: These need to be relatively big. If the limit is too low, the RECOVERY quota isn't big # enough to be useful, and if the window is too short, redis (which doesn't seem to listen to the @@ -18,11 +25,203 @@ } +class MockCircuitBreaker(CircuitBreaker): + """ + A circuit breaker with extra methods useful for mocking state. + + To understand the methods below, it helps to understand the `RedisSlidingWindowRateLimiter` + which powers the circuit breaker. Details can be found in + https://github.com/getsentry/sentry-redis-tools/blob/d4f3dc883b1137d82b6b7a92f4b5b41991c1fc8a/sentry_redis_tools/sliding_windows_rate_limiter.py, + (which is the implementation behind the rate limiter) but TL;DR, quota usage during the time + window is tallied in buckets ("granules"), and as time passes the window slides forward one + granule at a time. To be able to mimic this, most of the methods here operate at the granule + level. + """ + + def _set_breaker_state( + self, state: CircuitBreakerState, seconds_left: int | None = None + ) -> None: + """ + Adjust redis keys to force the breaker into the given state. If no remaining seconds are + given, puts the breaker at the beginning of its time in the given state. + """ + now = int(time.time()) + + if state == CircuitBreakerState.OK: + self._delete_from_redis([self.broken_state_key, self.recovery_state_key]) + + elif state == CircuitBreakerState.BROKEN: + broken_state_timeout = seconds_left or self.broken_state_duration + broken_state_end = now + broken_state_timeout + recovery_timeout = broken_state_timeout + self.recovery_duration + recovery_end = now + recovery_timeout + + self._set_in_redis( + [ + (self.broken_state_key, broken_state_end, broken_state_timeout), + (self.recovery_state_key, recovery_end, recovery_timeout), + ] + ) + + elif state == CircuitBreakerState.RECOVERY: + recovery_timeout = seconds_left or self.recovery_duration + recovery_end = now + recovery_timeout + + self._delete_from_redis([self.broken_state_key]) + self._set_in_redis([(self.recovery_state_key, recovery_end, recovery_timeout)]) + + assert self._get_state_and_remaining_time() == ( + state, + ( + None + if state == CircuitBreakerState.OK + else ( + broken_state_timeout + if state == CircuitBreakerState.BROKEN + else recovery_timeout + ) + ), + ) + + def _add_quota_usage( + self, + quota: Quota, + amount_used: int, + granule_or_window_end: int | None = None, + ) -> None: + """ + Add to the usage total of the given quota, in the granule or window ending at the given + time. If a window (rather than a granule) end time is given, usage will be added to the + final granule. + + If no end time is given, the current time will be used. + """ + now = int(time.time()) + window_end_time = granule_or_window_end or now + + self.limiter.use_quotas( + [RequestedQuota(self.key, amount_used, [quota])], + [GrantedQuota(self.key, amount_used, [])], + window_end_time, + ) + + def _clear_quota(self, quota: Quota, window_end: int | None = None) -> list[int]: + """ + Clear usage of the given quota up until the end of the given time window. If no window end + is given, clear the quota up to the present. + + Returns the list of granule values which were cleared. + """ + now = int(time.time()) + window_end_time = window_end or now + granule_end_times = self._get_granule_end_times(quota, window_end_time) + num_granules = len(granule_end_times) + previous_granule_values = [0] * num_granules + + current_total_quota_used = quota.limit - self._get_remaining_error_quota( + quota, window_end_time + ) + if current_total_quota_used != 0: + # Empty the granules one by one, starting with the oldest. + # + # To empty each granule, we need to add negative quota usage, which means we need to + # know how much usage is currently in each granule. Unfortunately, the limiter will only + # report quota usage at the window level, not the granule level. To get around this, we + # start with a window ending with the oldest granule. Any granules before it will have + # expired, so the window usage will equal the granule usage.ending in that granule will + # have a total usage equal to that of the granule. + # + # Once we zero-out the granule, we can move the window one granule forward. It will now + # consist of expired granules, the granule we just set to 0, and the granule we care + # about. Thus the window usage will again match the granule usage, which we can use to + # empty the granule. We then just repeat the pattern until we've reached the end of the + # window we want to clear. + for i, granule_end_time in enumerate(granule_end_times): + granule_quota_used = quota.limit - self._get_remaining_error_quota( + quota, granule_end_time + ) + previous_granule_values[i] = granule_quota_used + self._add_quota_usage(quota, -granule_quota_used, granule_end_time) + + new_total_quota_used = quota.limit - self._get_remaining_error_quota( + quota, window_end_time + ) + assert new_total_quota_used == 0 + + return previous_granule_values + + def _get_granule_end_times( + self, quota: Quota, window_end: int, newest_first: bool = False + ) -> list[int]: + """ + Given a quota and the end of the time window it's covering, return the timestamps + corresponding to the end of each granule. + """ + window_duration = quota.window_seconds + granule_duration = quota.granularity_seconds + num_granules = window_duration // granule_duration + + # Walk backwards through the granules + end_times_newest_first = [ + window_end - num_granules_ago * granule_duration + for num_granules_ago in range(num_granules) + ] + + return end_times_newest_first if newest_first else list(reversed(end_times_newest_first)) + + def _set_granule_values( + self, + quota: Quota, + values: list[int | None], + window_end: int | None = None, + ) -> None: + """ + Set the usage in each granule of the given quota, for the time window ending at the given + time. + + If no ending time is given, the current time is used. + + The list of values should be ordered from oldest to newest and must contain the same number + of elements as the window has granules. To only change some of the values, pass `None` in + the spot of any value which should remain unchanged. (For example, in a two-granule window, + to only change the older granule, pass `[3, None]`.) + """ + window_duration = quota.window_seconds + granule_duration = quota.granularity_seconds + num_granules = window_duration // granule_duration + + if len(values) != num_granules: + raise Exception( + f"Exactly {num_granules} granule values must be provided. " + + "To leave an existing value as is, include `None` in its spot." + ) + + now = int(time.time()) + window_end_time = window_end or now + + previous_values = self._clear_quota(quota, window_end_time) + + for i, granule_end_time, value in zip( + range(num_granules), self._get_granule_end_times(quota, window_end_time), values + ): + # When we cleared the quota above, we set each granule's value to 0, so here "adding" + # usage is actually setting usage + if value is not None: + self._add_quota_usage(quota, value, granule_end_time) + else: + self._add_quota_usage(quota, previous_values[i], granule_end_time) + + def _delete_from_redis(self, keys: list[str]) -> Any: + for key in keys: + self.redis_pipeline.delete(key) + return self.redis_pipeline.execute() + + @freeze_time() class CircuitBreakerTest(TestCase): def setUp(self) -> None: self.config = DEFAULT_CONFIG - self.breaker = CircuitBreaker("dogs_are_great", self.config) + self.breaker = MockCircuitBreaker("dogs_are_great", self.config) # Clear all existing keys from redis self.breaker.redis_pipeline.flushall() @@ -78,7 +277,7 @@ def test_fixes_too_loose_recovery_limit(self, mock_logger: MagicMock): (False, mock_logger.warning), ]: settings.DEBUG = settings_debug_value - breaker = CircuitBreaker("dogs_are_great", config) + breaker = MockCircuitBreaker("dogs_are_great", config) expected_log_function.assert_called_with( "Circuit breaker '%s' has a recovery error limit (%d) greater than or equal" @@ -104,7 +303,7 @@ def test_fixes_mismatched_state_durations(self, mock_logger: MagicMock): (False, mock_logger.warning), ]: settings.DEBUG = settings_debug_value - breaker = CircuitBreaker("dogs_are_great", config) + breaker = MockCircuitBreaker("dogs_are_great", config) expected_log_function.assert_called_with( "Circuit breaker '%s' has BROKEN and RECOVERY state durations (%d and %d sec, respectively)"