Commit eb27147

Merge branch 'main' into add/accesstoken-throttledapplication-adminpanel

samshahriari authored Feb 27, 2024
2 parents 852dd0a + 2cffcb9 commit eb27147
Showing 14 changed files with 98 additions and 101 deletions.
17 changes: 10 additions & 7 deletions api/api/middleware/response_headers_middleware.py
@@ -1,4 +1,6 @@
-from api.utils.oauth2_helper import get_token_info
+from rest_framework.request import Request
+
+from api.models.oauth import ThrottledApplication
 
 
 def response_headers_middleware(get_response):
@@ -11,14 +13,15 @@ def response_headers_middleware(get_response):
     to identify malicious requesters or request patterns.
     """
 
-    def middleware(request):
+    def middleware(request: Request):
         response = get_response(request)
 
-        if hasattr(request, "auth") and request.auth:
-            token_info = get_token_info(str(request.auth))
-            if token_info:
-                response["x-ov-client-application-name"] = token_info.application_name
-                response["x-ov-client-application-verified"] = token_info.verified
+        if not (hasattr(request, "auth") and hasattr(request.auth, "application")):
+            return response
+
+        application: ThrottledApplication = request.auth.application
+        response["x-ov-client-application-name"] = application.name
+        response["x-ov-client-application-verified"] = application.verified
 
         return response
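For illustration, a sketch of what the new middleware contract implies: any response served with a token tied to a `ThrottledApplication` should carry the two `x-ov-*` headers sourced directly from that application. This test is not part of the commit; the `AccessTokenFactory` name, the `db` fixture, and the URL are assumptions, and note that Django coerces header values to strings, so the boolean `verified` arrives as `"True"`.

    from rest_framework.test import APIClient


    def test_headers_reflect_application(db):
        # Assumed factory, mirroring the one used elsewhere in api/test.
        token = AccessTokenFactory.create()
        token.application.name = "my-app"
        token.application.verified = True
        token.application.save()

        client = APIClient()
        response = client.get("/v1/images/", HTTP_AUTHORIZATION=f"Bearer {token.token}")

        # The middleware copies these straight off `request.auth.application`.
        assert response["x-ov-client-application-name"] == "my-app"
        assert response["x-ov-client-application-verified"] == "True"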
4 changes: 2 additions & 2 deletions api/api/serializers/media_serializers.py
@@ -54,7 +54,7 @@ class PaginatedRequestSerializer(serializers.Serializer):

     def validate_page_size(self, value):
         request = self.context.get("request")
-        is_anonymous = bool(request and request.user and request.user.is_anonymous)
+        is_anonymous = getattr(request, "auth", None) is None
         max_value = (
             settings.MAX_ANONYMOUS_PAGE_SIZE
             if is_anonymous
@@ -247,7 +247,7 @@ class MediaSearchRequestSerializer(PaginatedRequestSerializer):

     def is_request_anonymous(self):
         request = self.context.get("request")
-        return bool(request and request.user and request.user.is_anonymous)
+        return getattr(request, "auth", None) is None
 
     @staticmethod
     def _truncate(value):
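The switch from `request.user.is_anonymous` to checking `request.auth` matters because an OAuth2-authenticated request can carry an anonymous Django user alongside a non-None token. A rough sketch of the distinction (the token stand-in is an assumption, not from this commit):

    from django.contrib.auth.models import AnonymousUser
    from rest_framework.test import APIRequestFactory

    request = APIRequestFactory().get("/v1/images/")
    request.user = AnonymousUser()  # common even for token-authed requests
    request.auth = object()         # stands in for a DRF-set AccessToken

    assert bool(request.user.is_anonymous) is True            # old check: "anonymous"
    assert (getattr(request, "auth", None) is None) is False  # new check: authenticated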
65 changes: 0 additions & 65 deletions api/api/utils/oauth2_helper.py

This file was deleted.

20 changes: 11 additions & 9 deletions api/api/utils/throttle.py
@@ -5,8 +5,6 @@

 from redis.exceptions import ConnectionError
 
-from api.utils.oauth2_helper import get_token_info
-
 
 parent_logger = logging.getLogger(__name__)

@@ -47,8 +45,11 @@ def has_valid_token(self, request):
         if not request.auth:
             return False
 
-        token_info = get_token_info(str(request.auth))
-        return token_info and token_info.valid
+        application = getattr(request.auth, "application", None)
+        if application is None:
+            return False
+
+        return application.client_id and application.verified
 
     def get_cache_key(self, request, view):
         return self.cache_format % {
@@ -146,15 +147,16 @@ class AbstractOAuth2IdRateThrottle(SimpleRateThrottle, metaclass=abc.ABCMeta):

     def get_cache_key(self, request, view):
         # Find the client ID associated with the access token.
-        auth = str(request.auth)
-        token_info = get_token_info(auth)
-        if not (token_info and token_info.valid):
+        if not self.has_valid_token(request):
             return None
 
-        if token_info.rate_limit_model not in self.applies_to_rate_limit_model:
+        # `self.has_valid_token` call earlier ensures accessing `application` will not fail
+        application = request.auth.application
+
+        if application.rate_limit_model not in self.applies_to_rate_limit_model:
             return None
 
-        return self.cache_format % {"scope": self.scope, "ident": token_info.client_id}
+        return self.cache_format % {"scope": self.scope, "ident": application.client_id}
 
 
 class OAuth2IdThumbnailRateThrottle(AbstractOAuth2IdRateThrottle):
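Keying the throttle on `application.client_id` rather than on the individual access token means every token issued to one application shares a single rate-limit bucket. A sketch of the resulting cache key, using DRF's `SimpleRateThrottle` format string (the scope name and client id below are invented):

    cache_format = "throttle_%(scope)s_%(ident)s"  # SimpleRateThrottle default

    key = cache_format % {
        "scope": "oauth2_client_credentials",  # assumed scope name
        "ident": "DhesjMFkfkhsdfds",           # application.client_id
    }
    # -> "throttle_oauth2_client_credentials_DhesjMFkfkhsdfds"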
20 changes: 7 additions & 13 deletions api/api/views/oauth2_views.py
@@ -7,6 +7,7 @@
 from django.conf import settings
 from django.core.cache import cache
 from django.core.mail import send_mail
+from rest_framework.request import Request
 from rest_framework.response import Response
 from rest_framework.reverse import reverse
 from rest_framework.views import APIView
@@ -22,7 +23,6 @@
     OAuth2KeyInfoSerializer,
     OAuth2RegistrationSerializer,
 )
-from api.utils.oauth2_helper import get_token_info
 from api.utils.throttle import OnePerSecond, TenPerDay


@@ -169,7 +169,7 @@ class CheckRates(APIView):
     throttle_classes = (OnePerSecond,)
 
     @key_info
-    def get(self, request, format=None):
+    def get(self, request: Request, format=None):
         """
         Get information about your API key.
 
@@ -181,23 +181,17 @@ def get(self, request, format=None):
"""

# TODO: Replace 403 responses with DRF `authentication_classes`.
if not request.auth:
if not request.auth or not hasattr(request.auth, "application"):
return Response(status=403, data="Forbidden")

access_token = str(request.auth)
token_info = get_token_info(access_token)
application: ThrottledApplication = request.auth.application

if not token_info:
# This shouldn't happen if `request.auth` was true above,
# but better safe than sorry
return Response(status=403, data="Forbidden")

client_id = token_info.client_id
client_id = application.client_id

if not client_id:
return Response(status=403, data="Forbidden")

throttle_type = token_info.rate_limit_model
throttle_type = application.rate_limit_model
throttle_key = "throttle_{scope}_{client_id}"
if throttle_type == "standard":
sustained_throttle_key = throttle_key.format(
@@ -242,7 +236,7 @@ def get(self, request, format=None):
"requests_this_minute": burst_requests,
"requests_today": sustained_requests,
"rate_limit_model": throttle_type,
"verified": token_info.verified,
"verified": application.verified,
}
)
return Response(status=status, data=response_data.data)
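For reference, a plausible shape of the data this endpoint returns after the change, with invented counts; `verified` now comes from `application.verified` instead of the deleted token-info helper:

    # Sketch (invented values) of CheckRates' response data:
    {
        "requests_this_minute": 3,
        "requests_today": 1024,
        "rate_limit_model": "standard",
        "verified": True,
    }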
@@ -172,7 +172,6 @@ def test_create_search_query_q_search_with_filters(image_media_type_config):
                 }
             },
             {"rank_feature": {"boost": 10000, "field": "standardized_popularity"}},
-            {"rank_feature": {"boost": 25000, "field": "authority_boost"}},
         ],
     }

1 change: 1 addition & 0 deletions api/test/unit/utils/test_throttle.py
@@ -71,6 +71,7 @@ def enable_throttles(settings):
 def access_token():
     token = AccessTokenFactory.create()
     token.application.verified = True
+    token.application.client_id = 123
     token.application.save()
     return token

8 changes: 7 additions & 1 deletion catalog/dags/common/loader/s3.py
@@ -42,12 +42,15 @@ def copy_file_to_s3(
     s3_prefix,
     aws_conn_id,
     ti,
+    extra_args=None,
 ):
     """
     Copy a TSV file to S3 with the given prefix.
 
     The TSV's version is pushed to the `tsv_version` XCom, and the constructed
     S3 key is pushed to the `s3_key` XCom.
 
     The TSV is removed after the upload is complete.
+
+    ``extra_args`` refers to the S3Hook argument.
     """
     if tsv_file_path is None:
         raise FileNotFoundError("No TSV file path was provided")
@@ -57,7 +60,10 @@
     tsv_version = paths.get_tsv_version(tsv_file_path)
     s3_key = f"{s3_prefix}/{tsv_file.name}"
     logger.info(f"Uploading {tsv_file_path} to {s3_bucket}:{s3_key}")
-    s3 = S3Hook(aws_conn_id=aws_conn_id)
+    s3 = S3Hook(
+        aws_conn_id=aws_conn_id,
+        extra_args=extra_args or {},
+    )
     s3.load_file(tsv_file_path, s3_key, bucket_name=s3_bucket)
     ti.xcom_push(key="tsv_version", value=tsv_version)
     ti.xcom_push(key="s3_key", value=s3_key)
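A hypothetical call site showing how the new `extra_args` parameter flows through to `S3Hook`; the bucket, prefix, connection id, and `ti` below are invented for illustration:

    copy_file_to_s3(
        tsv_file_path="/tmp/provider.tsv",
        s3_bucket="openverse-catalog",
        s3_prefix="provider",
        aws_conn_id="aws_default",
        ti=ti,  # Airflow task instance from the execution context
        extra_args={"StorageClass": "STANDARD_IA"},  # forwarded to S3Hook(extra_args=...)
    )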
3 changes: 3 additions & 0 deletions catalog/dags/providers/provider_dag_factory.py
@@ -230,6 +230,9 @@ def append_day_shift(id_str):
                 else None,
             ),
             "aws_conn_id": AWS_CONN_ID,
+            "extra_args": {
+                "StorageClass": conf.s3_tsv_storage_class,
+            },
         },
         trigger_rule=TriggerRule.NONE_SKIPPED,
     )
43 changes: 40 additions & 3 deletions catalog/dags/providers/provider_workflows.py
@@ -166,6 +166,25 @@ class ProviderWorkflow:
     tags: list[str] = field(default_factory=list)
     overrides: list[TaskOverride] = field(default_factory=list)
 
+    # Set when the object is uploaded, even though we access the object later in
+    # the DAG. IA incurs additional retrieval fees per request, unlike plain
+    # standard storage. However, as of writing, that costs 0.001 USD (1/10th of
+    # a US cent) per 1k requests. In other words, a minuscule amount, considering
+    # we will access the object once later in the DAG, to upsert it to the DB,
+    # and then in all likelihood never access it again.
+    # Even if we did, and had to pay the retrieval fee, we would still come out
+    # ahead on storage costs, because IA is so much less expensive than regular
+    # storage. We could set the storage class in a later task in the DAG, to
+    # avoid the one time retrieval fee. However, that adds complexity to the DAG
+    # that we can avoid by eagerly setting the storage class early, and the actual
+    # savings would probably be nil, factoring in the time spent in standard storage
+    # incurring standard storage costs. If it absolutely needs to be rationalised,
+    # consider the amount of energy spent on the extra request to S3 to update the
+    # storage cost to try to get around a retrieval fee (which, again, will not
+    # actually cost more, all things considered). Saving that energy could melt
+    # the glaciers all that much more slowly.
+    s3_tsv_storage_class: str = "STANDARD_IA"
 
     def _get_module_info(self):
         # Get the module the ProviderDataIngester was defined in
         provider_script = inspect.getmodule(self.ingester_class)
@@ -186,12 +205,30 @@ def __post_init__(self):
         if not self.doc_md:
             self.doc_md = provider_script.__doc__
 
-        # Check for custom configuration overrides, which will be applied when
-        # the DAG is generated.
+        self._process_configuration_overrides()
+
+    def _process_configuration_overrides(self):
+        """
+        Check for and apply custom configuration overrides.
+
+        These are only applied when the DAG is generated.
+        """
+
+        # Provider-specific configuration overrides
         self.overrides = Variable.get(
             "CONFIGURATION_OVERRIDES", default_var={}, deserialize_json=True
         ).get(self.dag_id, [])
+
+        # Allow forcing the default to something other than `STANDARD_IA`
+        # Primarily meant for use in local development where minio is used
+        # which does not support all AWS storage classes
+        # https://github.com/minio/minio/issues/5469
+        # This intentionally applies to all providers, rather than the provider-specific
+        # overrides above
+        self.s3_tsv_storage_class = Variable.get(
+            "DEFAULT_S3_TSV_STORAGE_CLASS", default_var=self.s3_tsv_storage_class
+        )
 
 
 PROVIDER_WORKFLOWS = [
     ProviderWorkflow(
@@ -218,7 +255,7 @@
         start_date=datetime(2022, 10, 27),
         schedule_string="@daily",
         dated=True,
-        pull_timeout=timedelta(weeks=1),
+        pull_timeout=timedelta(days=12),
     ),
     ProviderWorkflow(
         ingester_class=FinnishMuseumsDataIngester,
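A sketch of how the global default can be forced at runtime via the Airflow metastore (the value here is only an example; the env-var route shown in env.template below achieves the same thing locally):

    from airflow.models import Variable

    # After this, every ProviderWorkflow's __post_init__ resolves
    # s3_tsv_storage_class to this value instead of "STANDARD_IA".
    Variable.set("DEFAULT_S3_TSV_STORAGE_CLASS", "REDUCED_REDUNDANCY")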
6 changes: 6 additions & 0 deletions catalog/env.template
@@ -63,6 +63,12 @@ AIRFLOW_CONN_SLACK_NOTIFICATIONS=https://slack
 AIRFLOW_CONN_SLACK_ALERTS=https://slack
 
 S3_LOCAL_ENDPOINT=http://s3:5000
+# Set to a non-default value supported by minio in local development to workaround
+# Minio's lack of support for all AWS storage classes, while still using a non-default
+# value so that the expected behaviour can be verified (specifically, that the storage
+# class is not the default "STANDARD")
+# https://github.com/minio/minio/issues/5469
+AIRFLOW_VAR_DEFAULT_S3_TSV_STORAGE_CLASS=REDUCED_REDUNDANCY
 
 # Connection to the Ingestion Server, used for managing data refreshes. Default is used to
 # connect to your locally running ingestion server.
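For context, Airflow resolves `Variable.get` from environment variables prefixed `AIRFLOW_VAR_` before consulting the metastore, which is why this env entry is enough to change the storage class locally. A minimal sketch:

    import os

    # Simulates what catalog/env.template provides to the Airflow containers.
    os.environ["AIRFLOW_VAR_DEFAULT_S3_TSV_STORAGE_CLASS"] = "REDUCED_REDUNDANCY"

    from airflow.models import Variable

    assert Variable.get("DEFAULT_S3_TSV_STORAGE_CLASS") == "REDUCED_REDUNDANCY"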
1 change: 1 addition & 0 deletions catalog/tests/dags/providers/test_provider_workflows.py
@@ -91,6 +91,7 @@ def test_overrides(configuration_overrides, expected_overrides):
     with mock.patch("providers.provider_workflows.Variable") as MockVariable:
         MockVariable.get.side_effect = [
             configuration_overrides,
+            MockVariable.get_original()[0],
         ]
         test_workflow = ProviderWorkflow(
             dag_id="my_dag_id",
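The fixture gains a second entry because `_process_configuration_overrides` now calls `Variable.get` twice, and `side_effect` consumes one list entry per call, in order. A simplified sketch of the same idea using a plain literal instead of the committed `get_original()` indirection:

    MockVariable.get.side_effect = [
        configuration_overrides,  # 1st call: CONFIGURATION_OVERRIDES
        "STANDARD_IA",            # 2nd call: DEFAULT_S3_TSV_STORAGE_CLASS
    ]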
6 changes: 6 additions & 0 deletions documentation/changelogs/frontend/2024.02.26.18.58.35.md
@@ -0,0 +1,6 @@
+# 2024.02.26.18.58.35
+
+## Bug Fixes
+
+- Fix sketchfab loading state error
+  ([#3794](https://github.com/WordPress/openverse/pull/3794)) by @zackkrida
4 changes: 4 additions & 0 deletions justfile
@@ -200,6 +200,10 @@ init:
 down *flags:
     just dc down {{ flags }}
 
+# Take all services down then call the specified app's up recipe. ex.: `just dup catalog` is useful for restarting the catalog with new environment variables
+dup app:
+    just down && just {{ app }}/up
+
 # Recreate all volumes and containers from scratch
 recreate:
     just down -v
