Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move Oembed endpoint validation onto the serializer #3069

Merged
merged 5 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions api/api/serializers/image_serializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Literal
from uuid import UUID

from django.core.exceptions import ValidationError
from rest_framework import serializers

from api.constants.field_order import field_position_map
Expand All @@ -16,7 +17,6 @@
get_hyperlinks_serializer,
get_search_request_source_serializer,
)
from api.utils.url import add_protocol


#######################
Expand Down Expand Up @@ -119,21 +119,35 @@ class Meta:
class OembedRequestSerializer(serializers.Serializer):
"""Parse and validate oEmbed parameters."""

url = serializers.CharField(
url = serializers.URLField(
allow_blank=False,
help_text="The link to an image present in Openverse.",
)

@staticmethod
def validate_url(value):
url = add_protocol(value)
def to_internal_value(self, data):
data = super().to_internal_value(data)

url = data["url"]
if url.endswith("/"):
url = url[:-1]
identifier = url.rsplit("/", 1)[1]

try:
uuid = UUID(identifier)
except ValueError:
raise serializers.ValidationError("Could not parse identifier from URL.")
return uuid
raise serializers.ValidationError(
{"Could not parse identifier from URL.": data["url"]}
)

try:
image = Image.objects.get(identifier=uuid)
except (Image.DoesNotExist, ValidationError):
raise serializers.ValidationError(
{"Could not find image from the provided URL": data["url"]}
)

data["image"] = image
return data


class OembedSerializer(BaseModelSerializer):
Expand Down
5 changes: 1 addition & 4 deletions api/api/views/image_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from django.conf import settings
from django.http.response import FileResponse, HttpResponse
from django.shortcuts import get_object_or_404
from rest_framework.decorators import action
from rest_framework.exceptions import NotFound
from rest_framework.response import Response
Expand Down Expand Up @@ -111,11 +110,9 @@ def oembed(self, request, *_, **__):

params = OembedRequestSerializer(data=request.query_params)
params.is_valid(raise_exception=True)

image = params.validated_data["image"]
context = self.get_serializer_context()

identifier = params.validated_data["url"]
image = get_object_or_404(Image, identifier=identifier)
if not (image.height and image.width):
image_file = requests.get(image.url, headers=self.OEMBED_HEADERS)
width, height = PILImage.open(io.BytesIO(image_file.content)).size
Expand Down
33 changes: 7 additions & 26 deletions api/test/test_image_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,40 +85,21 @@ def test_audio_report(image_fixture):
report("images", image_fixture)


def test_oembed_endpoint_with_non_existent_image():
params = {
"url": "https://any.domain/any/path/00000000-0000-0000-0000-000000000000",
}
response = requests.get(
f"{API_URL}/v1/images/oembed?{urlencode(params)}", verify=False
)
assert response.status_code == 404


def test_oembed_endpoint_with_bad_identifier():
params = {
"url": "https://any.domain/any/path/not-a-valid-uuid",
}
response = requests.get(
f"{API_URL}/v1/images/oembed?{urlencode(params)}", verify=False
)
assert response.status_code == 400


@pytest.mark.parametrize(
"url",
"url, expected_status_code",
[
f"https://any.domain/any/path/{identifier}", # no trailing slash
f"https://any.domain/any/path/{identifier}/", # trailing slash
identifier, # just identifier instead of URL
Copy link
Member

@dhruvkb dhruvkb Dec 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this specific test case (just the identifier) worth keeping, as an example of input that previously responded with 200 but now responds with 400?

(f"https://any.domain/any/path/{identifier}", 200), # no trailing slash
(f"https://any.domain/any/path/{identifier}/", 200), # trailing slash
("https://any.domain/any/path/00000000-0000-0000-0000-000000000000", 400),
("https://any.domain/any/path/not-a-valid-uuid", 400),
],
)
def test_oembed_endpoint_with_fuzzy_input(url):
def test_oembed_endpoint(url, expected_status_code):
params = {"url": url}
response = requests.get(
f"{API_URL}/v1/images/oembed?{urlencode(params)}", verify=False
)
assert response.status_code == 200
assert response.status_code == expected_status_code


def test_oembed_endpoint_for_json():
Expand Down
4 changes: 3 additions & 1 deletion api/test/unit/views/test_image_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from test.constants import API_URL
from test.factory.models.image import ImageFactory
from unittest.mock import patch

Expand Down Expand Up @@ -52,7 +53,8 @@ def requests_get(url, **kwargs):
@pytest.mark.django_db
def test_oembed_sends_ua_header(api_client, requests):
image = ImageFactory.create()
res = api_client.get("/v1/images/oembed/", data={"url": f"/{image.identifier}"})
image.url = f"https://any.domain/any/path/{image.identifier}"
res = api_client.get(f"{API_URL}/v1/images/oembed/", data={"url": image.url})

assert res.status_code == 200

Expand Down