Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(utils): Add helpers for circuit breaker and circuit breaker tests #74559

Merged
merged 3 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 124 additions & 2 deletions src/sentry/utils/circuit_breaker2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
"""

import logging
import time
from enum import Enum
from typing import NotRequired, TypedDict
from typing import Any, Literal, NotRequired, TypedDict, overload

from django.conf import settings

from sentry.ratelimits.sliding_windows import Quota, RedisSlidingWindowRateLimiter
from sentry.ratelimits.sliding_windows import Quota, RedisSlidingWindowRateLimiter, RequestedQuota

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -180,3 +181,124 @@ def __init__(self, key: str, config: CircuitBreakerConfig):
default_recovery_duration,
)
self.recovery_duration = default_recovery_duration

def _get_from_redis(self, keys: list[str]) -> Any:
for key in keys:
self.redis_pipeline.get(key)
return self.redis_pipeline.execute()

def _set_in_redis(self, keys_values_and_timeouts: list[tuple[str, Any, int]]) -> None:
for key, value, timeout in keys_values_and_timeouts:
self.redis_pipeline.set(key, value, timeout)
self.redis_pipeline.execute()

def _get_state_and_remaining_time(
self,
) -> tuple[CircuitBreakerState, int | None]:
"""
Return the current state of the breaker (OK, BROKEN, or in RECOVERY), along with the
number of seconds until that state expires (or `None` when in OK state, as it has no
expiry).
"""
now = int(time.time())

try:
broken_state_expiry, recovery_state_expiry = self._get_from_redis(
[self.broken_state_key, self.recovery_state_key]
)
except Exception:
logger.exception("Couldn't get state from redis for circuit breaker '%s'", self.key)

# Default to letting traffic through so the breaker doesn't become a single point of failure
return (CircuitBreakerState.OK, None)

# The BROKEN state key should always expire before the RECOVERY state one, so check it first
if broken_state_expiry is not None:
broken_state_seconds_left = int(broken_state_expiry) - now

# In theory there should always be time left (the key should have expired otherwise),
# but race conditions/caching/etc means we should check, just to be sure
if broken_state_seconds_left > 0:
return (CircuitBreakerState.BROKEN, broken_state_seconds_left)

if recovery_state_expiry is not None:
recovery_state_seconds_left = int(recovery_state_expiry) - now
if recovery_state_seconds_left > 0:
return (CircuitBreakerState.RECOVERY, recovery_state_seconds_left)

return (CircuitBreakerState.OK, None)

@overload
def _get_controlling_quota(
self, state: Literal[CircuitBreakerState.OK, CircuitBreakerState.RECOVERY]
) -> Quota:
...

@overload
def _get_controlling_quota(self, state: Literal[CircuitBreakerState.BROKEN]) -> None:
...

@overload
def _get_controlling_quota(self) -> Quota | None:
...

def _get_controlling_quota(self, state: CircuitBreakerState | None = None) -> Quota | None:
"""
Return the Quota corresponding to the given breaker state (or the current breaker state, if
no state is provided). If the state is question is the BROKEN state, return None.
"""
controlling_quota_by_state = {
CircuitBreakerState.OK: self.primary_quota,
CircuitBreakerState.BROKEN: None,
CircuitBreakerState.RECOVERY: self.recovery_quota,
}

_state = state or self._get_state_and_remaining_time()[0]

return controlling_quota_by_state[_state]

@overload
def _get_remaining_error_quota(self, quota: None, window_end: int | None) -> None:
...

@overload
def _get_remaining_error_quota(self, quota: Quota, window_end: int | None) -> int:
...

@overload
def _get_remaining_error_quota(self, quota: None) -> None:
...

@overload
def _get_remaining_error_quota(self, quota: Quota) -> int:
...

@overload
def _get_remaining_error_quota(self) -> int | None:
...

def _get_remaining_error_quota(
self, quota: Quota | None = None, window_end: int | None = None
) -> int | None:
"""
Get the number of allowable errors remaining in the given quota for the time window ending
at the given time.
If no quota is given, in OK and RECOVERY states, return the current controlling quota's
remaining errors. In BROKEN state, return None.
If no time window end is given, return the current amount of quota remaining.
"""
if not quota:
quota = self._get_controlling_quota()
if quota is None: # BROKEN state
return None

now = int(time.time())
window_end = window_end or now

_, result = self.limiter.check_within_quotas(
[RequestedQuota(self.key, quota.limit, [quota])], window_end
)

return result[0].granted
209 changes: 204 additions & 5 deletions tests/sentry/utils/test_circuit_breaker2.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import time
from typing import Any
from unittest import TestCase
from unittest.mock import ANY, MagicMock, patch

from django.conf import settings
from redis.client import Pipeline

from sentry.ratelimits.sliding_windows import Quota, RedisSlidingWindowRateLimiter
from sentry.ratelimits.sliding_windows import (
GrantedQuota,
Quota,
RedisSlidingWindowRateLimiter,
RequestedQuota,
)
from sentry.testutils.helpers.datetime import freeze_time
from sentry.utils.circuit_breaker2 import CircuitBreaker, CircuitBreakerConfig
from sentry.utils.circuit_breaker2 import CircuitBreaker, CircuitBreakerConfig, CircuitBreakerState

# Note: These need to be relatively big. If the limit is too low, the RECOVERY quota isn't big
# enough to be useful, and if the window is too short, redis (which doesn't seem to listen to the
Expand All @@ -18,11 +25,203 @@
}


class MockCircuitBreaker(CircuitBreaker):
"""
A circuit breaker with extra methods useful for mocking state.
To understand the methods below, it helps to understand the `RedisSlidingWindowRateLimiter`
which powers the circuit breaker. Details can be found in
https://github.com/getsentry/sentry-redis-tools/blob/d4f3dc883b1137d82b6b7a92f4b5b41991c1fc8a/sentry_redis_tools/sliding_windows_rate_limiter.py,
(which is the implementation behind the rate limiter) but TL;DR, quota usage during the time
window is tallied in buckets ("granules"), and as time passes the window slides forward one
granule at a time. To be able to mimic this, most of the methods here operate at the granule
level.
"""

def _set_breaker_state(
self, state: CircuitBreakerState, seconds_left: int | None = None
) -> None:
"""
Adjust redis keys to force the breaker into the given state. If no remaining seconds are
given, puts the breaker at the beginning of its time in the given state.
"""
now = int(time.time())

if state == CircuitBreakerState.OK:
self._delete_from_redis([self.broken_state_key, self.recovery_state_key])

elif state == CircuitBreakerState.BROKEN:
broken_state_timeout = seconds_left or self.broken_state_duration
broken_state_end = now + broken_state_timeout
recovery_timeout = broken_state_timeout + self.recovery_duration
recovery_end = now + recovery_timeout

self._set_in_redis(
[
(self.broken_state_key, broken_state_end, broken_state_timeout),
(self.recovery_state_key, recovery_end, recovery_timeout),
]
)

elif state == CircuitBreakerState.RECOVERY:
recovery_timeout = seconds_left or self.recovery_duration
recovery_end = now + recovery_timeout

self._delete_from_redis([self.broken_state_key])
self._set_in_redis([(self.recovery_state_key, recovery_end, recovery_timeout)])

assert self._get_state_and_remaining_time() == (
state,
(
None
if state == CircuitBreakerState.OK
else (
broken_state_timeout
if state == CircuitBreakerState.BROKEN
else recovery_timeout
)
),
)

def _add_quota_usage(
self,
quota: Quota,
amount_used: int,
granule_or_window_end: int | None = None,
) -> None:
"""
Add to the usage total of the given quota, in the granule or window ending at the given
time. If a window (rather than a granule) end time is given, usage will be added to the
final granule.
If no end time is given, the current time will be used.
"""
now = int(time.time())
window_end_time = granule_or_window_end or now

self.limiter.use_quotas(
[RequestedQuota(self.key, amount_used, [quota])],
[GrantedQuota(self.key, amount_used, [])],
window_end_time,
)

def _clear_quota(self, quota: Quota, window_end: int | None = None) -> list[int]:
"""
Clear usage of the given quota up until the end of the given time window. If no window end
is given, clear the quota up to the present.
Returns the list of granule values which were cleared.
"""
now = int(time.time())
window_end_time = window_end or now
granule_end_times = self._get_granule_end_times(quota, window_end_time)
num_granules = len(granule_end_times)
previous_granule_values = [0] * num_granules

current_total_quota_used = quota.limit - self._get_remaining_error_quota(
quota, window_end_time
)
if current_total_quota_used != 0:
# Empty the granules one by one, starting with the oldest.
#
# To empty each granule, we need to add negative quota usage, which means we need to
# know how much usage is currently in each granule. Unfortunately, the limiter will only
# report quota usage at the window level, not the granule level. To get around this, we
# start with a window ending with the oldest granule. Any granules before it will have
# expired, so the window usage will equal the granule usage.ending in that granule will
# have a total usage equal to that of the granule.
#
# Once we zero-out the granule, we can move the window one granule forward. It will now
# consist of expired granules, the granule we just set to 0, and the granule we care
# about. Thus the window usage will again match the granule usage, which we can use to
# empty the granule. We then just repeat the pattern until we've reached the end of the
# window we want to clear.
for i, granule_end_time in enumerate(granule_end_times):
granule_quota_used = quota.limit - self._get_remaining_error_quota(
quota, granule_end_time
)
previous_granule_values[i] = granule_quota_used
self._add_quota_usage(quota, -granule_quota_used, granule_end_time)

new_total_quota_used = quota.limit - self._get_remaining_error_quota(
quota, window_end_time
)
assert new_total_quota_used == 0

return previous_granule_values

def _get_granule_end_times(
self, quota: Quota, window_end: int, newest_first: bool = False
) -> list[int]:
"""
Given a quota and the end of the time window it's covering, return the timestamps
corresponding to the end of each granule.
"""
window_duration = quota.window_seconds
granule_duration = quota.granularity_seconds
num_granules = window_duration // granule_duration

# Walk backwards through the granules
end_times_newest_first = [
window_end - num_granules_ago * granule_duration
for num_granules_ago in range(num_granules)
]

return end_times_newest_first if newest_first else list(reversed(end_times_newest_first))

def _set_granule_values(
self,
quota: Quota,
values: list[int | None],
window_end: int | None = None,
) -> None:
"""
Set the usage in each granule of the given quota, for the time window ending at the given
time.
If no ending time is given, the current time is used.
The list of values should be ordered from oldest to newest and must contain the same number
of elements as the window has granules. To only change some of the values, pass `None` in
the spot of any value which should remain unchanged. (For example, in a two-granule window,
to only change the older granule, pass `[3, None]`.)
"""
window_duration = quota.window_seconds
granule_duration = quota.granularity_seconds
num_granules = window_duration // granule_duration

if len(values) != num_granules:
raise Exception(
f"Exactly {num_granules} granule values must be provided. "
+ "To leave an existing value as is, include `None` in its spot."
)

now = int(time.time())
window_end_time = window_end or now

previous_values = self._clear_quota(quota, window_end_time)

for i, granule_end_time, value in zip(
range(num_granules), self._get_granule_end_times(quota, window_end_time), values
):
# When we cleared the quota above, we set each granule's value to 0, so here "adding"
# usage is actually setting usage
if value is not None:
self._add_quota_usage(quota, value, granule_end_time)
else:
self._add_quota_usage(quota, previous_values[i], granule_end_time)

def _delete_from_redis(self, keys: list[str]) -> Any:
for key in keys:
self.redis_pipeline.delete(key)
return self.redis_pipeline.execute()


@freeze_time()
class CircuitBreakerTest(TestCase):
def setUp(self) -> None:
self.config = DEFAULT_CONFIG
self.breaker = CircuitBreaker("dogs_are_great", self.config)
self.breaker = MockCircuitBreaker("dogs_are_great", self.config)

# Clear all existing keys from redis
self.breaker.redis_pipeline.flushall()
Expand Down Expand Up @@ -78,7 +277,7 @@ def test_fixes_too_loose_recovery_limit(self, mock_logger: MagicMock):
(False, mock_logger.warning),
]:
settings.DEBUG = settings_debug_value
breaker = CircuitBreaker("dogs_are_great", config)
breaker = MockCircuitBreaker("dogs_are_great", config)

expected_log_function.assert_called_with(
"Circuit breaker '%s' has a recovery error limit (%d) greater than or equal"
Expand All @@ -104,7 +303,7 @@ def test_fixes_mismatched_state_durations(self, mock_logger: MagicMock):
(False, mock_logger.warning),
]:
settings.DEBUG = settings_debug_value
breaker = CircuitBreaker("dogs_are_great", config)
breaker = MockCircuitBreaker("dogs_are_great", config)

expected_log_function.assert_called_with(
"Circuit breaker '%s' has BROKEN and RECOVERY state durations (%d and %d sec, respectively)"
Expand Down
Loading