Skip to content
This repository has been archived by the owner on Feb 22, 2023. It is now read-only.

Commit

Permalink
Add throttle exemptions (#568)
Browse files Browse the repository at this point in the history
* Add throttle exemptions

* Add more docstrings

* Cover edge case for unauthed requests

* Remove dev dependency from prod deps

* Make ThrottleExemption a true ABC

Props to @AetherUnbound

* Fix various linting errors

* Put back terminal history destroying logging
  • Loading branch information
sarayourfriend authored Mar 16, 2022
1 parent f233598 commit 188b6a0
Show file tree
Hide file tree
Showing 6 changed files with 585 additions and 164 deletions.
1 change: 1 addition & 0 deletions api/Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ sphinx-autobuild = "*"
furo = "*"
myst-parser = "*"
factory-boy = "*"
fakeredis = "*"

[packages]
aws-requests-auth = "*"
Expand Down
424 changes: 269 additions & 155 deletions api/Pipfile.lock

Large diffs are not rendered by default.

93 changes: 84 additions & 9 deletions api/catalog/api/utils/throttle.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
import logging

from catalog.api.utils.oauth2_helper import get_token_info
Expand All @@ -8,23 +9,98 @@
log = logging.getLogger(__name__)


def _from_internal_network(ip):
redis = get_redis_connection("default")
return redis.sismember("ip-whitelist", ip)
class ThrottleExemption(abc.ABC):
"""
Abstract class describing a given throttle exemption.
To be included in an iterable of ``ThrottleExemption``s
in children of ``ExemptionAwareThrottle``.
"""

def __init__(self, throttle_class, request):
"""
:param throttle_class: The throttle class the exemption modifies.
:param request: The current request against which to evaluate the exemption.
"""
self.throttle_class = throttle_class
self.request = request

@abc.abstractmethod
def is_exempt(self) -> bool:
"""
Whether the current request is exempt from throttling.
:return: ``True`` if exempt, ``False`` otherwise.
"""
...


class AnonRateThrottle(SimpleRateThrottle):
class ExemptionAwareThrottle(SimpleRateThrottle):
"""
An throttle exemption aware base throttle.
Classes in ``exemption_classes`` are evaluated for each
request. If any of them detect an exempted request then
the request will not be throttled.
"""

exemption_classes = []

def allow_request(self, request, view):
"""
Short circuit ``allow_request`` if _any_ exemption
declares the request to be exempt from the throttle.
"""
for exemption_class in self.exemption_classes:
if exemption_class(self, request).is_exempt():
return True

return super().allow_request(request, view)


class InternalNetworkExemption(ThrottleExemption):
redis_set_name = "ip-whitelist"

def is_exempt(self):
"""
Exempts requests coming from within Openverse's own
network. In practical terms this prevents the Nuxt server
from being rate-limited when server-side-rendering.
"""
ip = self.throttle_class.get_ident(self.request)
redis = get_redis_connection("default", write=False)
return redis.sismember(self.redis_set_name, ip)


class ApiKeyExemption(ThrottleExemption):
redis_set_name = "client-id-allowlist"

def is_exempt(self):
"""
Exempt certain API keys from throttling. In practical
terms this is used to prevent large consumers of
Openverse's API like WordPress.com and Jetpack from
being rate-limited.
"""
client_id, _, _ = get_token_info(str(self.request.auth))
if not client_id:
return False

redis = get_redis_connection("default")
return redis.sismember(self.redis_set_name, client_id)


class AnonRateThrottle(ExemptionAwareThrottle):
"""
Limits the rate of API calls that may be made by a anonymous users.
The IP address of the request will be used as the unique cache key.
"""

scope = "anon"
exemption_classes = [InternalNetworkExemption, ApiKeyExemption]

def get_cache_key(self, request, view):
if _from_internal_network(self.get_ident(request)):
return None
# Do not throttle requests with a valid access token.
if request.auth:
client_id, _, verified = get_token_info(str(request.auth))
Expand Down Expand Up @@ -61,7 +137,7 @@ class OnePerSecond(AnonRateThrottle):
rate = "1/second"


class OAuth2IdThrottleRate(SimpleRateThrottle):
class OAuth2IdThrottleRate(ExemptionAwareThrottle):
"""
Limits the rate of API calls that may be made by a given user's Oauth2
client ID. Can be configured to apply to either standard or enhanced
Expand All @@ -70,10 +146,9 @@ class OAuth2IdThrottleRate(SimpleRateThrottle):

scope = "oauth2_client_credentials"
applies_to_rate_limit_model = "standard"
exemption_classes = [InternalNetworkExemption, ApiKeyExemption]

def get_cache_key(self, request, view):
if _from_internal_network(self.get_ident(request)):
return None
# Find the client ID associated with the access token.
auth = str(request.auth)
client_id, rate_limit_model, verified = get_token_info(auth)
Expand Down
6 changes: 6 additions & 0 deletions api/test/factory/faker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
from faker.utils.distribution import choices_distribution


class ChoiceProvider(BaseProvider):
def random_choice_field(self, choices):
return self.random_element(elements=[choice[0] for choice in choices])


class WaveformProvider(BaseProvider):
_float_space = [x / 100.0 for x in range(101)] * 20

Expand All @@ -14,4 +19,5 @@ def waveform(self) -> list[float]:
return WaveformProvider.generate_waveform()


Faker.add_provider(ChoiceProvider)
Faker.add_provider(WaveformProvider)
27 changes: 27 additions & 0 deletions api/test/factory/models/oauth2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from test.factory.faker import Faker

import factory
from catalog.api.models.oauth import ThrottledApplication
from factory.django import DjangoModelFactory
from oauth2_provider.models import AccessToken


class ThrottledApplicationFactory(DjangoModelFactory):
class Meta:
model = ThrottledApplication

client_type = Faker(
"random_choice_field", choices=ThrottledApplication.CLIENT_TYPES
)
authorization_grant_type = Faker(
"random_choice_field", choices=ThrottledApplication.GRANT_TYPES
)


class AccessTokenFactory(DjangoModelFactory):
class Meta:
model = AccessToken

token = Faker("uuid4")
expires = Faker("date_time_between", start_date="+1y", end_date="+2y")
application = factory.SubFactory(ThrottledApplicationFactory)
198 changes: 198 additions & 0 deletions api/test/unit/utils/throttle_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
from test.factory.models.oauth2 import AccessTokenFactory

import pytest
from catalog.api.utils.oauth2_helper import get_token_info
from catalog.api.utils.throttle import (
ApiKeyExemption,
ExemptionAwareThrottle,
InternalNetworkExemption,
ThrottleExemption,
)
from fakeredis import FakeRedis
from rest_framework.response import Response
from rest_framework.test import APIRequestFactory, force_authenticate
from rest_framework.views import APIView


class HardThrottle(ExemptionAwareThrottle):
"""
A test throttle that denies all requests.
This is helpful for testing whether the exemptions
are working.
"""

rate = "0/second"
scope = "test"

def get_cache_key(self, request, view):
return {
"scope": self.scope,
"ident": self.get_ident(request),
}


class MockInternalNetworkExemptThrottle(HardThrottle):
exemption_classes = (InternalNetworkExemption,)


class MockApiKeyExemptThrottle(HardThrottle):
exemption_classes = (ApiKeyExemption,)


class FooRouteExemption(ThrottleExemption):
def is_exempt(self):
return self.request.path.startswith("/foo")


class MockMultipleExemptionThrottle(HardThrottle):
exemption_classes = (InternalNetworkExemption, ApiKeyExemption, FooRouteExemption)


def get_throttled_view(throttle_class):
class MockView(APIView):
throttle_classes = (throttle_class,)

def get(self, request):
return Response("foo")

return MockView().as_view()


@pytest.fixture(autouse=True)
def redis(monkeypatch) -> FakeRedis:
fake_redis = FakeRedis()

def get_redis_connection(*args, **kwargs):
return fake_redis

monkeypatch.setattr(
"catalog.api.utils.throttle.get_redis_connection", get_redis_connection
)

yield fake_redis
fake_redis.client().close()


@pytest.fixture
def request_factory() -> APIRequestFactory():
request_factory = APIRequestFactory(defaults={"REMOTE_ADDR": "192.0.2.1"})

return request_factory


@pytest.fixture
def access_token():
return AccessTokenFactory.create()


@pytest.fixture
def authed_request(access_token, request_factory):
request = request_factory.get("/")

force_authenticate(request, token=access_token.token)

return request


def assert_view_consistent_status_code(view, request, expected_status, times=4):
for _ in range(times):
assert view(request).status_code == expected_status


def assert_throttles(view, request, times=4):
assert_view_consistent_status_code(view, request, expected_status=429, times=times)


def assert_does_not_throttle(view, request, times=4):
assert_view_consistent_status_code(view, request, expected_status=200, times=times)


def test_hard_throttle_denies_requests(request_factory):
view = get_throttled_view(HardThrottle)
request = request_factory.get("/")
assert_throttles(view, request)


def test_internal_network_exemption_passes_when_ip_in_allowlist(redis, request_factory):
view = get_throttled_view(MockInternalNetworkExemptThrottle)
request = request_factory.get("/")
redis.sadd(InternalNetworkExemption.redis_set_name, request.META["REMOTE_ADDR"])
assert_does_not_throttle(view, request)


def test_internal_network_exemption_throttles_when_ip_not_in_allowlist(
redis, request_factory
):
view = get_throttled_view(MockInternalNetworkExemptThrottle)
request = request_factory.get("/")
assert not redis.sismember(
InternalNetworkExemption.redis_set_name, request.META["REMOTE_ADDR"]
)
assert_throttles(view, request)


@pytest.mark.django_db
def test_api_key_exemption_passes_when_token_in_allowlist(
redis, access_token, authed_request
):
view = get_throttled_view(MockApiKeyExemptThrottle)
client_id, _, _ = get_token_info(access_token.token)
redis.sadd(ApiKeyExemption.redis_set_name, client_id)
assert_does_not_throttle(view, authed_request)


@pytest.mark.django_db
def test_api_key_exemption_throttles_when_token_not_in_allowlist(
redis, access_token, authed_request
):
view = get_throttled_view(MockApiKeyExemptThrottle)
client_id, _, _ = get_token_info(access_token.token)
assert not redis.sismember(ApiKeyExemption.redis_set_name, client_id)
assert_throttles(view, authed_request)


@pytest.mark.django_db
def test_api_key_exemption_throttles_with_unauthed_request(request_factory):
request = request_factory.get("/")
view = get_throttled_view(MockApiKeyExemptThrottle)
assert_throttles(view, request)


@pytest.mark.django_db
def test_multiple_exemptions_allows_if_one_passes_api_key(
redis, access_token, authed_request
):
view = get_throttled_view(MockMultipleExemptionThrottle)
client_id, _, _ = get_token_info(access_token.token)
redis.sadd(ApiKeyExemption.redis_set_name, client_id)
assert not redis.sismember(
InternalNetworkExemption.redis_set_name, authed_request.META["REMOTE_ADDR"]
)
assert_does_not_throttle(view, authed_request)


@pytest.mark.django_db
def test_multiple_exemptions_allows_if_one_passes_internal_network(
redis, access_token, authed_request
):
view = get_throttled_view(MockMultipleExemptionThrottle)
client_id, _, _ = get_token_info(access_token.token)
redis.sadd(
InternalNetworkExemption.redis_set_name, authed_request.META["REMOTE_ADDR"]
)
assert not redis.sismember(ApiKeyExemption.redis_set_name, client_id)
assert_does_not_throttle(view, authed_request)


@pytest.mark.django_db
def test_multiple_exemptions_throttles_if_none_pass(
redis, access_token, authed_request
):
view = get_throttled_view(MockMultipleExemptionThrottle)
client_id, _, _ = get_token_info(access_token.token)
assert not redis.sismember(
InternalNetworkExemption.redis_set_name, authed_request.META["REMOTE_ADDR"]
)
assert not redis.sismember(ApiKeyExemption.redis_set_name, client_id)
assert_throttles(view, authed_request)

0 comments on commit 188b6a0

Please sign in to comment.