diff --git a/api/api/views/image_views.py b/api/api/views/image_views.py index 955466bcf57..3783169db94 100644 --- a/api/api/views/image_views.py +++ b/api/api/views/image_views.py @@ -2,12 +2,10 @@ from django.conf import settings from django.http.response import FileResponse, HttpResponse -from django.shortcuts import get_object_or_404 from rest_framework.decorators import action from rest_framework.exceptions import NotFound from rest_framework.response import Response -import requests from drf_spectacular.utils import extend_schema, extend_schema_view from PIL import Image as PILImage @@ -36,6 +34,8 @@ ) from api.serializers.media_serializers import PaginatedRequestSerializer from api.utils import image_proxy +from api.utils.aiohttp import get_aiohttp_session +from api.utils.asyncio import aget_object_or_404 from api.utils.watermark import watermark from api.views.media_views import MediaViewSet @@ -100,7 +100,7 @@ def tag_collection(self, request, tag, *_, **__): url_name="oembed", serializer_class=OembedSerializer, ) - def oembed(self, request, *_, **__): + async def oembed(self, request, *_, **__): """ Retrieve the structured data for a specified image URL as per the [oEmbed spec](https://oembed.com/). @@ -115,10 +115,12 @@ def oembed(self, request, *_, **__): context = self.get_serializer_context() identifier = params.validated_data["url"] - image = get_object_or_404(Image, identifier=identifier) + image = await aget_object_or_404(Image, identifier=identifier) if not (image.height and image.width): - image_file = requests.get(image.url, headers=self.OEMBED_HEADERS) - width, height = PILImage.open(io.BytesIO(image_file.content)).size + session = await get_aiohttp_session() + image_file = await session.get(image.url, headers=self.OEMBED_HEADERS) + image_content = await image_file.content.read() + width, height = PILImage.open(io.BytesIO(image_content)).size context |= { "width": width, "height": height, diff --git a/api/test/unit/views/test_image_views.py b/api/test/unit/views/test_image_views.py index dc258dfc246..247dabf8866 100644 --- a/api/test/unit/views/test_image_views.py +++ b/api/test/unit/views/test_image_views.py @@ -1,6 +1,4 @@ import json -from collections.abc import Callable -from dataclasses import dataclass from pathlib import Path from test.factory.models.image import ImageFactory from unittest.mock import patch @@ -8,7 +6,7 @@ import pook import pytest from PIL import UnidentifiedImageError -from requests import Request, Response +from requests import Response from api.views.image_views import ImageViewSet @@ -18,47 +16,20 @@ _MOCK_IMAGE_INFO = json.loads((_MOCK_IMAGE_PATH / "sample-image-info.json").read_text()) -@dataclass -class RequestsFixture: - requests: list[Request] - response_factory: Callable[ # noqa: E731 - [Request], Response - ] = lambda x: RequestsFixture._default_response_factory(x) - - @staticmethod - def _default_response_factory(req: Request) -> Response: - res = Response() - res.url = req.url - res.status_code = 200 - res._content = _MOCK_IMAGE_BYTES - return res - - -@pytest.fixture -def requests(monkeypatch) -> RequestsFixture: - fixture = RequestsFixture([]) - - def requests_get(url, **kwargs): - req = Request(method="GET", url=url, **kwargs) - fixture.requests.append(req) - response = fixture.response_factory(req) - return response - - monkeypatch.setattr("requests.get", requests_get) - - return fixture - - @pytest.mark.django_db -def test_oembed_sends_ua_header(api_client, requests): +def test_oembed_sends_ua_header(api_client): image = ImageFactory.create() - res = api_client.get("/v1/images/oembed/", data={"url": f"/{image.identifier}"}) - assert res.status_code == 200 + with pook.use(): + ( + pook.get(image.url) + .header("User-Agent", ImageViewSet.OEMBED_HEADERS["User-Agent"]) + .reply(200) + .body(_MOCK_IMAGE_BYTES, binary=True) + ) + res = api_client.get("/v1/images/oembed/", data={"url": f"/{image.identifier}"}) - assert len(requests.requests) > 0 - for r in requests.requests: - assert r.headers == ImageViewSet.OEMBED_HEADERS + assert res.status_code == 200 @pytest.mark.django_db