diff --git a/api/api/serializers/image_serializers.py b/api/api/serializers/image_serializers.py index 4f0af390d26..5824198936c 100644 --- a/api/api/serializers/image_serializers.py +++ b/api/api/serializers/image_serializers.py @@ -1,4 +1,5 @@ from typing import Literal +from uuid import UUID from rest_framework import serializers @@ -124,7 +125,15 @@ class OembedRequestSerializer(serializers.Serializer): @staticmethod def validate_url(value): - return add_protocol(value) + url = add_protocol(value) + if url.endswith("/"): + url = url[:-1] + identifier = url.rsplit("/", 1)[1] + try: + uuid = UUID(identifier) + except ValueError: + raise serializers.ValidationError("Could not parse identifier from URL.") + return uuid class OembedSerializer(BaseModelSerializer): diff --git a/api/api/views/image_views.py b/api/api/views/image_views.py index 95844ffbf97..6f138908db8 100644 --- a/api/api/views/image_views.py +++ b/api/api/views/image_views.py @@ -83,10 +83,7 @@ def oembed(self, request, *_, **__): context = self.get_serializer_context() - url = params.validated_data["url"] - if url.endswith("/"): - url = url[:-1] - identifier = url.rsplit("/", 1)[1] + identifier = params.validated_data["url"] image = get_object_or_404(Image, identifier=identifier) if not (image.height and image.width): image_file = requests.get(image.url, headers=self.OEMBED_HEADERS) diff --git a/api/test/test_image_integration.py b/api/test/test_image_integration.py index 2ed9d9b0bc3..e8b7d479913 100644 --- a/api/test/test_image_integration.py +++ b/api/test/test_image_integration.py @@ -92,6 +92,16 @@ def test_oembed_endpoint_with_non_existent_image(): assert response.status_code == 404 +def test_oembed_endpoint_with_bad_identifier(): + params = { + "url": "https://any.domain/any/path/not-a-valid-uuid", + } + response = requests.get( + f"{API_URL}/v1/images/oembed?{urlencode(params)}", verify=False + ) + assert response.status_code == 400 + + @pytest.mark.parametrize( "url", [