Skip to content

Commit

Permalink
Add warning to search response when source parameter has mixed validity (#4031)
Browse files Browse the repository at this point in the history

* Add warning to search response when source parameter has mixed validity

* Use fixture source list

* Move warnings to top and use clearer message
  • Loading branch information
sarayourfriend authored Apr 8, 2024
1 parent cc475df commit e0e0e27
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 40 deletions.
47 changes: 39 additions & 8 deletions api/api/serializers/media_serializers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import logging
from collections import namedtuple
from typing import TypedDict

from django.conf import settings
from django.core.exceptions import ValidationError as DjangoValidationError
from django.core.validators import MaxValueValidator
from django.urls import reverse
from rest_framework import serializers
from rest_framework.exceptions import NotAuthenticated, ValidationError
from rest_framework.request import Request

from drf_spectacular.utils import extend_schema_serializer
from elasticsearch_dsl.response import Hit

from api.constants import sensitivity
from api.constants.licenses import LICENSE_GROUPS
from api.constants.media_types import MediaType
from api.constants.parameters import COLLECTION, TAG
from api.constants.sorting import DESCENDING, RELEVANCE, SORT_DIRECTIONS, SORT_FIELDS
from api.controllers import search_controller
Expand Down Expand Up @@ -294,8 +298,16 @@ class MediaSearchRequestSerializer(PaginatedRequestSerializer):
required=False,
)

class Context(TypedDict, total=True):
warnings: list[dict]
media_type: MediaType
request: Request

context: Context

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.context["warnings"] = []
self.media_type = self.context.get("media_type")
if not self.media_type:
raise ValueError(
Expand Down Expand Up @@ -398,15 +410,34 @@ def validate_source(self, value):
)
return value
else:
sources = value.lower().split(",")
valid_sources = set(
[source for source in sources if source in allowed_sources]
)
if len(sources) > len(valid_sources):
invalid_sources = set(sources).difference(valid_sources)
logger.warning(
f"Invalid sources in search query: {invalid_sources}; sources query: '{value}'"
sources = set(value.lower().split(","))
valid_sources = {source for source in sources if source in allowed_sources}
if not valid_sources:
# Raise only if there are _no_ valid sources selected
# If the requester passed only `mispelled_museum_name1,mispelled_musesum_name2`
# the request cannot move forward, as all the top responses will likely be from Flickr
# which provides radically different responses than most other providers.
# If even one source is valid, it won't be a problem, in which case we'll issue a warning
raise serializers.ValidationError(
f"Invalid source parameter '{value}'. No valid sources selected. "
f"Refer to the source list for valid options: {sources_list}."
)
elif invalid_sources := (sources - valid_sources):
available_sources_uri = self.context["request"].build_absolute_uri(
reverse(f"{self.media_type}-stats")
)
self.context["warnings"].append(
{
"code": "partially invalid source parameter",
"message": (
"The source parameter included non-existent sources. "
f"For a list of available sources, see {available_sources_uri}"
),
"invalid_sources": invalid_sources,
"valid_sources": valid_sources,
}
)

return ",".join(valid_sources)

def validate_excluded_source(self, input_sources):
Expand Down
77 changes: 61 additions & 16 deletions api/api/utils/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,36 @@ class StandardPagination(PageNumberPagination):
page_size_query_param = "page_size"
page_query_param = "page"

result_count: int | None
page_count: int | None
page: int
warnings: list[dict]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.result_count = None # populated later
self.page_count = None # populated later
self.page = 1 # default, gets updated when necessary
self.warnings = [] # populated later as needed

def get_paginated_response(self, data):
response = {
"result_count": self.result_count,
"page_count": min(settings.MAX_PAGINATION_DEPTH, self.page_count),
"page_size": self.page_size,
"page": self.page,
"results": data,
}
return Response(
{
"result_count": self.result_count,
"page_count": min(settings.MAX_PAGINATION_DEPTH, self.page_count),
"page_size": self.page_size,
"page": self.page,
"results": data,
}
(
{
# Put ``warnings`` first so it is as visible as possible.
"warnings": list(self.warnings),
}
if self.warnings
else {}
)
| response
)

def get_paginated_response_schema(self, schema):
Expand All @@ -39,15 +54,45 @@ def get_paginated_response_schema(self, schema):
"page_size": ("The number of items per page.", 20),
"page": ("The current page number returned in the response.", 1),
}

properties = {
field: {
"type": "integer",
"description": description,
"example": example,
}
for field, (description, example) in field_descriptions.items()
} | {
"results": schema,
"warnings": {
"type": "array",
"items": {
"type": "object",
},
"description": (
"Warnings pertinent to the request. "
"If there are no warnings, this property will not be present on the response. "
"Warnings are non-critical problems with the request. "
"Responses with warnings should be treated as unstable. "
"Warning descriptions must not be treated as machine readable "
"and their schema can change at any time."
),
"example": [
{
"code": "partially invalid request parameter",
"message": (
"Some of the request parameters were bad, "
"but we processed the request anywhere. "
"Here's some information that might help you "
"fix the problem for future requests."
),
}
],
},
}

return {
"type": "object",
"properties": {
field: {
"type": "integer",
"description": description,
"example": example,
}
for field, (description, example) in field_descriptions.items()
}
| {"results": schema},
"properties": properties,
"required": list(set(properties.keys()) - {"warnings"}),
}
1 change: 1 addition & 0 deletions api/api/views/media_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def get_media_results(
):
page_size = self.paginator.page_size = params.data["page_size"]
page = self.paginator.page = params.data["page"]
self.paginator.warnings = params.context["warnings"]

hashed_ip = hash(self._get_user_ip(request))
filter_dead = params.validated_data.get("filter_dead", True)
Expand Down
2 changes: 1 addition & 1 deletion api/test/fixtures/rest_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


@pytest.fixture
def api_client():
def api_client() -> APIClient:
return APIClient()


Expand Down
44 changes: 44 additions & 0 deletions api/test/integration/test_media_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,50 @@ def test_detail_view_for_invalid_uuids_returns_not_found(
assert res.status_code == 404


def test_search_with_only_valid_sources_produces_no_warning(media_type, api_client):
    """A search restricted to exclusively valid sources must omit ``warnings``."""
    source_param = ",".join(media_type.providers)
    response = api_client.get(f"/v1/{media_type.path}/", {"source": source_param})
    assert response.status_code == 200
    # The response only carries a ``warnings`` key when something was wrong.
    assert "warnings" not in response.json()


def test_search_with_partially_invalid_sources_produces_warning_but_still_succeeds(
    media_type: MediaType, api_client
):
    """A mix of one valid source and bogus ones succeeds but carries a warning.

    The warning must identify the invalid and valid sources and point the
    caller at the per-media-type stats endpoint that lists available sources.
    """
    bogus_sources = [
        "surely_neither_this_one",
        "this_is_sure_not_to_ever_be_a_real_source_name",
    ]
    valid_source = media_type.providers[0]

    response = api_client.get(
        f"/v1/{media_type.path}/",
        {"source": ",".join([valid_source] + bogus_sources)},
    )
    assert response.status_code == 200
    payload = response.json()

    codes = {warning["code"] for warning in payload["warnings"]}
    assert codes == {"partially invalid source parameter"}

    warning = payload["warnings"][0]
    assert set(warning["invalid_sources"]) == set(bogus_sources)
    assert warning["valid_sources"] == [valid_source]
    # The message should reference the stats endpoint listing valid sources.
    assert f"v1/{media_type.path}/stats/" in warning["message"]


def test_search_with_all_invalid_sources_fails(media_type, api_client):
    """When every requested source is unknown, the API rejects the search (400)."""
    bogus_sources = [
        "this_is_sure_not_to_ever_be_a_real_source_name",
        "surely_neither_this_one",
    ]
    response = api_client.get(
        f"/v1/{media_type.path}/", {"source": ",".join(bogus_sources)}
    )
    assert response.status_code == 400


def test_detail_view_returns_ok(single_result, api_client):
media_type, item = single_result
res = api_client.get(f"/v1/{media_type.path}/{item['id']}/")
Expand Down
28 changes: 13 additions & 15 deletions api/test/unit/serializers/test_media_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,25 +99,23 @@ def test_media_serializer_adds_license_url_if_missing(
assert repr["license_url"] == "https://creativecommons.org/publicdomain/zero/1.0/"


def test_media_serializer_logs_when_invalid_or_duplicate_source(media_type_config):
def test_media_serializer_recovers_invalid_or_duplicate_source(
media_type_config, request_factory
):
sources = {
"image": ("flickr,flickr,invalid", "flickr"),
"audio": ("freesound,freesound,invalid", "freesound"),
}
with patch("api.serializers.media_serializers.logger.warning") as mock_logger:
serializer_class = media_type_config.search_request_serializer(
context={"media_type": media_type_config.media_type},
data={"source": sources[media_type_config.media_type][0]},
)
assert serializer_class.is_valid()
assert (
serializer_class.validated_data["source"]
== sources[media_type_config.media_type][1]
)
mock_logger.assert_called_with(
f"Invalid sources in search query: {{'invalid'}}; "
f"sources query: '{sources[media_type_config.media_type][0]}'"
)
request = request_factory.get("/v1/images/")
serializer_class = media_type_config.search_request_serializer(
context={"media_type": media_type_config.media_type, "request": request},
data={"source": sources[media_type_config.media_type][0]},
)
assert serializer_class.is_valid()
assert (
serializer_class.validated_data["source"]
== sources[media_type_config.media_type][1]
)


@pytest.mark.parametrize(
Expand Down

0 comments on commit e0e0e27

Please sign in to comment.