Skip to content

Commit

Permalink
chore(typing): stricter decorator type checking for leaky bucket (#74687
Browse files Browse the repository at this point in the history
)

Adds stricter decorator type checking for leaky bucket
  • Loading branch information
Bartek Ogryczak authored Jul 23, 2024
1 parent 62d415c commit 0066487
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
11 changes: 7 additions & 4 deletions src/sentry/ratelimits/leaky_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
from collections.abc import Callable
from dataclasses import dataclass
from time import time
from typing import Any
from typing import Any, ParamSpec, TypeVar

from django.conf import settings

from sentry.exceptions import InvalidConfiguration
from sentry.utils import redis

P = ParamSpec("P")
R = TypeVar("R")

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -113,9 +116,9 @@ def get_bucket_state(self, key: str | None = None) -> LeakyBucketLimitInfo:
def decorator(
self,
key_override: str | None = None,
limited_handler: Callable[[LeakyBucketLimitInfo, dict[str, Any]], Any] | None = None,
limited_handler: Callable[[LeakyBucketLimitInfo, dict[str, Any]], R] | None = None,
raise_exception: bool = False,
) -> Callable[[Any], Any]:
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Decorator to limit the rate of requests
Expand Down Expand Up @@ -181,7 +184,7 @@ def my_function():
"""

def decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
def decorator(func: Callable[P, R]) -> Callable[P, R]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
try:
Expand Down
24 changes: 12 additions & 12 deletions tests/sentry/ratelimits/test_leaky_bucket.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Never
from unittest import mock

import pytest
Expand Down Expand Up @@ -64,7 +64,7 @@ def test_drip_rate(self) -> None:

def test_decorator(self) -> None:
@self.limiter("foo")
def foo() -> None:
def foo() -> Never:
assert False, "This should not be executed when limited"

with freeze_time("2077-09-13"):
Expand All @@ -75,7 +75,7 @@ def foo() -> None:
assert foo() is None

@self.limiter("bar", raise_exception=True)
def bar() -> None:
def bar() -> Never:
assert False, "This should not be executed when limited"

with freeze_time("2077-09-13"):
Expand All @@ -88,23 +88,23 @@ def bar() -> None:

last_info: list[LeakyBucketLimitInfo] = []

def callback(info: LeakyBucketLimitInfo, context: dict[str, Any]) -> LeakyBucketLimitInfo:
def callback(info: LeakyBucketLimitInfo, context: dict[str, Any]) -> str:
last_info.append(info)
return info
return "rate limited"

@self.limiter("baz", limited_handler=callback)
def baz() -> bool:
return True
def baz() -> str:
return "normal value"

with freeze_time("2077-09-13"):
for i in range(5):
assert baz() is True
assert baz() == "normal value"
assert len(last_info) == 0

info = baz()
assert info
baz_rv = baz()
assert baz_rv == "rate limited"
assert len(last_info) == 1
assert last_info[0] == info
info = last_info[0]
assert info.wait_time > 0
assert info.current_level == 5

Expand All @@ -114,7 +114,7 @@ def test_decorator_default_key(self) -> None:
with mock.patch.object(limiter, "_redis_key", wraps=limiter._redis_key) as _redis_key_spy:

@limiter()
def foo() -> None:
def foo() -> Any:
pass

foo()
Expand Down

0 comments on commit 0066487

Please sign in to comment.