Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert oembed to async #3458

Merged
merged 2 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions api/api/serializers/image_serializers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Literal
from uuid import UUID

from django.core.exceptions import ValidationError
from rest_framework import serializers

from api.constants.field_order import field_position_map
Expand Down Expand Up @@ -139,14 +138,7 @@ def to_internal_value(self, data):
{"Could not parse identifier from URL.": data["url"]}
)

try:
image = Image.objects.get(identifier=uuid)
except (Image.DoesNotExist, ValidationError):
raise serializers.ValidationError(
{"Could not find image from the provided URL": data["url"]}
)

data["image"] = image
data["identifier"] = uuid
return data


Expand Down
7 changes: 7 additions & 0 deletions api/api/utils/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import logging
from collections.abc import Awaitable

from rest_framework.generics import get_object_or_404

from asgiref.sync import sync_to_async


parent_logger = logging.getLogger(__name__)

Expand All @@ -27,3 +31,6 @@ def do_not_wait_for(awaitable: Awaitable) -> None:
raise exc

loop.create_task(awaitable)


aget_object_or_404 = sync_to_async(get_object_or_404)
15 changes: 10 additions & 5 deletions api/api/views/image_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from rest_framework.exceptions import NotFound
from rest_framework.response import Response

import requests
from drf_spectacular.utils import extend_schema, extend_schema_view
from PIL import Image as PILImage

Expand Down Expand Up @@ -35,6 +34,8 @@
)
from api.serializers.media_serializers import PaginatedRequestSerializer
from api.utils import image_proxy
from api.utils.aiohttp import get_aiohttp_session
from api.utils.asyncio import aget_object_or_404
from api.utils.watermark import watermark
from api.views.media_views import MediaViewSet

Expand Down Expand Up @@ -99,7 +100,7 @@ def tag_collection(self, request, tag, *_, **__):
url_name="oembed",
serializer_class=OembedSerializer,
)
def oembed(self, request, *_, **__):
async def oembed(self, request, *_, **__):
"""
Retrieve the structured data for a specified image URL as per the
[oEmbed spec](https://oembed.com/).
Expand All @@ -110,12 +111,16 @@ def oembed(self, request, *_, **__):

params = OembedRequestSerializer(data=request.query_params)
params.is_valid(raise_exception=True)
image = params.validated_data["image"]
identifier = params.validated_data["identifier"]
context = self.get_serializer_context()

image = await aget_object_or_404(Image, identifier=identifier)

if not (image.height and image.width):
image_file = requests.get(image.url, headers=self.OEMBED_HEADERS)
width, height = PILImage.open(io.BytesIO(image_file.content)).size
session = await get_aiohttp_session()
image_file = await session.get(image.url, headers=self.OEMBED_HEADERS)
image_content = await image_file.content.read()
width, height = PILImage.open(io.BytesIO(image_content)).size
context |= {
"width": width,
"height": height,
Expand Down
24 changes: 20 additions & 4 deletions api/test/test_image_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,26 @@ def test_audio_report(image_fixture):
@pytest.mark.parametrize(
"url, expected_status_code",
[
(f"https://any.domain/any/path/{identifier}", 200), # no trailing slash
(f"https://any.domain/any/path/{identifier}/", 200), # trailing slash
("https://any.domain/any/path/00000000-0000-0000-0000-000000000000", 400),
("https://any.domain/any/path/not-a-valid-uuid", 400),
pytest.param(
f"https://any.domain/any/path/{identifier}",
200,
id="OK; no trailing slash",
),
pytest.param(
f"https://any.domain/any/path/{identifier}/",
200,
id="OK; with trailing slash",
), # trailing slash
pytest.param(
"https://any.domain/any/path/00000000-0000-0000-0000-000000000000",
404,
id="Valid UUID but no matching identifier",
),
pytest.param(
"https://any.domain/any/path/not-a-valid-uuid",
400,
id="not a valid UUID",
),
],
)
def test_oembed_endpoint(url, expected_status_code):
Expand Down
53 changes: 12 additions & 41 deletions api/test/unit/views/test_image_views.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import json
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from test.constants import API_URL
from test.factory.models.image import ImageFactory
from unittest.mock import patch

import pook
import pytest
from PIL import UnidentifiedImageError
from requests import Request, Response
from requests import Response

from api.views.image_views import ImageViewSet

Expand All @@ -19,48 +16,22 @@
_MOCK_IMAGE_INFO = json.loads((_MOCK_IMAGE_PATH / "sample-image-info.json").read_text())


@dataclass
class RequestsFixture:
requests: list[Request]
response_factory: Callable[ # noqa: E731
[Request], Response
] = lambda x: RequestsFixture._default_response_factory(x)

@staticmethod
def _default_response_factory(req: Request) -> Response:
res = Response()
res.url = req.url
res.status_code = 200
res._content = _MOCK_IMAGE_BYTES
return res


@pytest.fixture
def requests(monkeypatch) -> RequestsFixture:
fixture = RequestsFixture([])

def requests_get(url, **kwargs):
req = Request(method="GET", url=url, **kwargs)
fixture.requests.append(req)
response = fixture.response_factory(req)
return response

monkeypatch.setattr("requests.get", requests_get)

return fixture


@pytest.mark.django_db
def test_oembed_sends_ua_header(api_client, requests):
def test_oembed_sends_ua_header(api_client):
image = ImageFactory.create()
image.url = f"https://any.domain/any/path/{image.identifier}"
res = api_client.get(f"{API_URL}/v1/images/oembed/", data={"url": image.url})
image.save()

assert res.status_code == 200
with pook.use():
(
pook.get(image.url)
.header("User-Agent", ImageViewSet.OEMBED_HEADERS["User-Agent"])
.reply(200)
.body(_MOCK_IMAGE_BYTES, binary=True)
)
res = api_client.get("/v1/images/oembed/", data={"url": image.url})

assert len(requests.requests) > 0
for r in requests.requests:
assert r.headers == ImageViewSet.OEMBED_HEADERS
assert res.status_code == 200


@pytest.mark.django_db
Expand Down