Skip to content

Commit

Permalink
Add middleware to log application name and verification status
Browse files Browse the repository at this point in the history
Also add authorization, app name and verified status to nginx logs
  • Loading branch information
sarayourfriend committed Nov 19, 2023
1 parent 3986e71 commit 75ed5ef
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 34 deletions.
16 changes: 16 additions & 0 deletions api/api/middleware/client_application_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from api.utils.oauth2_helper import get_token_info


def client_application_middleware(get_response):
def middleware(request):
response = get_response(request)

if hasattr(request, "auth") and request.auth:
token_info = get_token_info(str(request.auth))
if token_info:
response["x-ov-client-application-name"] = token_info.application_name
response["x-ov-client-application-verified"] = token_info.verified

return response

return middleware
33 changes: 24 additions & 9 deletions api/api/utils/oauth2_helper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime as dt
import logging
from dataclasses import dataclass

from oauth2_provider.models import AccessToken

Expand All @@ -8,10 +9,22 @@

parent_logger = logging.getLogger(__name__)

_no_result = (None, None, None)

@dataclass
class TokenInfo:
"""Extracted ``models.ThrottledApplication`` metadata."""

def get_token_info(token: str):
client_id: str
rate_limit_model: str
verified: bool
application_name: str

@property
def valid(self):
return self.client_id and self.verified


def get_token_info(token: str) -> None | TokenInfo:
"""
Recover an OAuth2 application client ID and rate limit model from an access token.
Expand All @@ -24,7 +37,7 @@ def get_token_info(token: str):
try:
token = AccessToken.objects.get(token=token)
except AccessToken.DoesNotExist:
return _no_result
return None

try:
application = models.ThrottledApplication.objects.get(accesstoken=token)
Expand All @@ -33,7 +46,7 @@ def get_token_info(token: str):
# In practice should never occur so long as the preceding
# operation to retrieve the access token was successful.
logger.critical("Failed to find application associated with access token.")
return _no_result
return None

expired = token.expires < dt.datetime.now(token.expires.tzinfo)
if expired:
Expand All @@ -42,9 +55,11 @@ def get_token_info(token: str):
f"application.name={application.name} "
f"application.client_id={application.client_id} "
)
return _no_result
return None

client_id = str(application.client_id)
rate_limit_model = application.rate_limit_model
verified = application.verified
return client_id, rate_limit_model, verified
return TokenInfo(
client_id=str(application.client_id),
rate_limit_model=application.rate_limit_model,
verified=application.verified,
application_name=application.name,
)
17 changes: 8 additions & 9 deletions api/api/utils/throttle.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def headers(self):
contains the limit and the number of requests left in the limit. Since multiple
rate limits can apply concurrently, the suffix identifies each pair uniquely.
"""

prefix = "X-RateLimit"
suffix = self.scope or self.__class__.__name__.lower()
if hasattr(self, "history"):
Expand All @@ -53,8 +52,8 @@ def get_cache_key(self, request, view):
logger = self.logger.getChild("get_cache_key")
# Do not apply anonymous throttle to request with valid tokens.
if request.auth:
client_id, _, verified = get_token_info(str(request.auth))
if client_id and verified:
token_info = get_token_info(str(request.auth))
if token_info and token_info.valid:
return None

ident = self.get_ident(request)
Expand Down Expand Up @@ -113,14 +112,14 @@ class AbstractOAuth2IdRateThrottle(SimpleRateThrottleHeader, metaclass=abc.ABCMe
def get_cache_key(self, request, view):
# Find the client ID associated with the access token.
auth = str(request.auth)
client_id, rate_limit_model, verified = get_token_info(auth)
if client_id and rate_limit_model == self.applies_to_rate_limit_model:
ident = client_id
else:
# Return None, fallback to the anonymous rate limiting
token_info = get_token_info(auth)
if not (token_info and token_info.valid):
return None

if token_info.rate_limit_model != self.applies_to_rate_limit_model:
return None

return self.cache_format % {"scope": self.scope, "ident": ident}
return self.cache_format % {"scope": self.scope, "ident": token_info.client_id}


class OAuth2IdThumbnailRateThrottle(AbstractOAuth2IdRateThrottle):
Expand Down
13 changes: 10 additions & 3 deletions api/api/views/oauth2_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,19 @@ def get(self, request, format=None):
return Response(status=403, data="Forbidden")

access_token = str(request.auth)
client_id, rate_limit_model, verified = get_token_info(access_token)
token_info = get_token_info(access_token)

if not token_info:
# This shouldn't happen if `request.auth` was true above,
# but better safe than sorry
return Response(status=403, data="Forbidden")

client_id = token_info.client_id

if not client_id:
return Response(status=403, data="Forbidden")

throttle_type = rate_limit_model
throttle_type = token_info.rate_limit_model
throttle_key = "throttle_{scope}_{client_id}"
if throttle_type == "standard":
sustained_throttle_key = throttle_key.format(
Expand Down Expand Up @@ -223,7 +230,7 @@ def get(self, request, format=None):
"requests_this_minute": burst_requests,
"requests_today": sustained_requests,
"rate_limit_model": throttle_type,
"verified": verified,
"verified": token_info.verified,
}
)
return Response(status=200, data=response_data.data)
1 change: 1 addition & 0 deletions api/conf/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"django.contrib.auth.middleware.AuthenticationMiddleware",
"django.contrib.messages.middleware.MessageMiddleware",
"django.middleware.clickjacking.XFrameOptionsMiddleware",
"api.middleware.client_application_middleware.client_application_middleware",
]

# Storage
Expand Down
8 changes: 6 additions & 2 deletions api/nginx.conf.template
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@ log_format json_combined escape=json
'"host_header": "$host",'
'"body_bytes_sent":$body_bytes_sent,'
'"request_time":"$request_time",'
'"upstream_response_time":$upstream_response_time,'
'"http_referrer":"$http_referer",'
'"http_user_agent":"$http_user_agent",'
'"upstream_response_time":$upstream_response_time,'
'"http_x_forwarded_for":"$http_x_forwarded_for"'
'"http_x_forwarded_for":"$http_x_forwarded_for",'
'"http_authorization":"$http_authorization",'
'"request_id":"$sent_http_x_request_id",'
'"client_application_name":"$sent_http_x_ov_client_application_name",'
'"client_application_verified":"$sent_http_x_ov_client_application_verified"'
'}';

access_log /var/log/nginx/access.log json_combined;
Expand Down
60 changes: 49 additions & 11 deletions api/test/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,23 +62,32 @@ def test_auth_token_exchange_unsupported_method(client):
assert res.json()["detail"] == 'Method "GET" not allowed.'


def _integration_verify_most_recent_token(client):
verify = OAuth2Verification.objects.last()
code = verify.code
path = reverse("verify-email", args=[code])
return client.get(path)


@pytest.mark.django_db
@pytest.mark.parametrize(
"rate_limit_model",
[x[0] for x in ThrottledApplication.RATE_LIMIT_MODELS],
)
@pytest.mark.skipif(
API_URL != "http://localhost:8000",
reason=(
"This test needs to cheat by looking in the database,"
" so it needs to skip in non-local environments where"
" that isn't possible."
),
)
def test_auth_email_verification(client, rate_limit_model, test_auth_token_exchange):
# This test needs to cheat by looking in the database, so it will be
# skipped in non-local environments.
if API_URL == "http://localhost:8000":
verify = OAuth2Verification.objects.last()
code = verify.code
path = reverse("verify-email", args=[code])
res = client.get(path)
assert res.status_code == 200
test_auth_rate_limit_reporting(
client, rate_limit_model, test_auth_token_exchange, verified=True
)
res = _integration_verify_most_recent_token(client)
assert res.status_code == 200
test_auth_rate_limit_reporting(
client, rate_limit_model, test_auth_token_exchange, verified=True
)


@pytest.mark.django_db
Expand Down Expand Up @@ -106,6 +115,35 @@ def test_auth_rate_limit_reporting(
assert res_data["verified"] is False


@pytest.mark.django_db
@pytest.mark.parametrize(
"verified",
(True, False),
)
def test_auth_response_headers(
client, verified, test_auth_tokens_registration, test_auth_token_exchange
):
if verified:
_integration_verify_most_recent_token(client)

token = test_auth_token_exchange["access_token"]

res = client.get("/v1/images/", HTTP_AUTHORIZATION=f"Bearer {token}")

assert (
res.headers["x-ov-client-application-name"]
== test_auth_tokens_registration["name"]
)
assert res.headers["x-ov-client-application-verified"] == str(verified)


def test_unauthed_response_headers(client):
res = client.get("/v1/images")

assert "x-ov-client-application-name" not in res.headers
assert "x-ov-client-application-verified" not in res.headers


@pytest.mark.django_db
@pytest.mark.parametrize(
"sort_dir, exp_indexed_on",
Expand Down

0 comments on commit 75ed5ef

Please sign in to comment.