diff --git a/src/sentry/ratelimits/leaky_bucket.py b/src/sentry/ratelimits/leaky_bucket.py index 22d5f9d3f4e35a..d105e78d2623f7 100644 --- a/src/sentry/ratelimits/leaky_bucket.py +++ b/src/sentry/ratelimits/leaky_bucket.py @@ -5,13 +5,16 @@ from collections.abc import Callable from dataclasses import dataclass from time import time -from typing import Any +from typing import Any, ParamSpec, TypeVar from django.conf import settings from sentry.exceptions import InvalidConfiguration from sentry.utils import redis +P = ParamSpec("P") +R = TypeVar("R") + logger = logging.getLogger(__name__) @@ -113,9 +116,9 @@ def get_bucket_state(self, key: str | None = None) -> LeakyBucketLimitInfo: def decorator( self, key_override: str | None = None, - limited_handler: Callable[[LeakyBucketLimitInfo, dict[str, Any]], Any] | None = None, + limited_handler: Callable[[LeakyBucketLimitInfo, dict[str, Any]], R] | None = None, raise_exception: bool = False, - ) -> Callable[[Any], Any]: + ) -> Callable[[Callable[P, R]], Callable[P, R]]: """ Decorator to limit the rate of requests @@ -181,7 +184,7 @@ def my_function(): """ - def decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]: + def decorator(func: Callable[P, R]) -> Callable[P, R]: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: try: diff --git a/tests/sentry/ratelimits/test_leaky_bucket.py b/tests/sentry/ratelimits/test_leaky_bucket.py index 54ec64b19fa1d4..c8c0d9c62375f7 100644 --- a/tests/sentry/ratelimits/test_leaky_bucket.py +++ b/tests/sentry/ratelimits/test_leaky_bucket.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Never from unittest import mock import pytest @@ -64,7 +64,7 @@ def test_drip_rate(self) -> None: def test_decorator(self) -> None: @self.limiter("foo") - def foo() -> None: + def foo() -> Never: assert False, "This should not be executed when limited" with freeze_time("2077-09-13"): @@ -75,7 +75,7 @@ def foo() -> None: assert foo() is None @self.limiter("bar", raise_exception=True) - def bar() -> None: + def bar() -> Never: assert False, "This should not be executed when limited" with freeze_time("2077-09-13"): @@ -88,23 +88,23 @@ def bar() -> None: last_info: list[LeakyBucketLimitInfo] = [] - def callback(info: LeakyBucketLimitInfo, context: dict[str, Any]) -> LeakyBucketLimitInfo: + def callback(info: LeakyBucketLimitInfo, context: dict[str, Any]) -> str: last_info.append(info) - return info + return "rate limited" @self.limiter("baz", limited_handler=callback) - def baz() -> bool: - return True + def baz() -> str: + return "normal value" with freeze_time("2077-09-13"): for i in range(5): - assert baz() is True + assert baz() == "normal value" assert len(last_info) == 0 - info = baz() - assert info + baz_rv = baz() + assert baz_rv == "rate limited" assert len(last_info) == 1 - assert last_info[0] == info + info = last_info[0] assert info.wait_time > 0 assert info.current_level == 5 @@ -114,7 +114,7 @@ def test_decorator_default_key(self) -> None: with mock.patch.object(limiter, "_redis_key", wraps=limiter._redis_key) as _redis_key_spy: @limiter() - def foo() -> None: + def foo() -> Any: pass foo()