Skip to content

Commit

Permalink
Convert oembed view to async
Browse files Browse the repository at this point in the history
  • Loading branch information
sarayourfriend committed Dec 5, 2023
1 parent bf80cf9 commit 7289ef4
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 46 deletions.
14 changes: 8 additions & 6 deletions api/api/views/image_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@

from django.conf import settings
from django.http.response import FileResponse, HttpResponse
from django.shortcuts import get_object_or_404
from rest_framework.decorators import action
from rest_framework.exceptions import NotFound
from rest_framework.response import Response

import requests
from drf_spectacular.utils import extend_schema, extend_schema_view
from PIL import Image as PILImage

Expand Down Expand Up @@ -36,6 +34,8 @@
)
from api.serializers.media_serializers import PaginatedRequestSerializer
from api.utils import image_proxy
from api.utils.aiohttp import get_aiohttp_session
from api.utils.asyncio import aget_object_or_404
from api.utils.watermark import watermark
from api.views.media_views import MediaViewSet

Expand Down Expand Up @@ -100,7 +100,7 @@ def tag_collection(self, request, tag, *_, **__):
url_name="oembed",
serializer_class=OembedSerializer,
)
def oembed(self, request, *_, **__):
async def oembed(self, request, *_, **__):
"""
Retrieve the structured data for a specified image URL as per the
[oEmbed spec](https://oembed.com/).
Expand All @@ -115,10 +115,12 @@ def oembed(self, request, *_, **__):
context = self.get_serializer_context()

identifier = params.validated_data["url"]
image = get_object_or_404(Image, identifier=identifier)
image = await aget_object_or_404(Image, identifier=identifier)
if not (image.height and image.width):
image_file = requests.get(image.url, headers=self.OEMBED_HEADERS)
width, height = PILImage.open(io.BytesIO(image_file.content)).size
session = await get_aiohttp_session()
image_file = await session.get(image.url, headers=self.OEMBED_HEADERS)
image_content = await image_file.content.read()
width, height = PILImage.open(io.BytesIO(image_content)).size
context |= {
"width": width,
"height": height,
Expand Down
51 changes: 11 additions & 40 deletions api/test/unit/views/test_image_views.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import json
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from test.factory.models.image import ImageFactory
from unittest.mock import patch

import pook
import pytest
from PIL import UnidentifiedImageError
from requests import Request, Response
from requests import Response

from api.views.image_views import ImageViewSet

Expand All @@ -18,47 +16,20 @@
_MOCK_IMAGE_INFO = json.loads((_MOCK_IMAGE_PATH / "sample-image-info.json").read_text())


@dataclass
class RequestsFixture:
requests: list[Request]
response_factory: Callable[ # noqa: E731
[Request], Response
] = lambda x: RequestsFixture._default_response_factory(x)

@staticmethod
def _default_response_factory(req: Request) -> Response:
res = Response()
res.url = req.url
res.status_code = 200
res._content = _MOCK_IMAGE_BYTES
return res


@pytest.fixture
def requests(monkeypatch) -> RequestsFixture:
fixture = RequestsFixture([])

def requests_get(url, **kwargs):
req = Request(method="GET", url=url, **kwargs)
fixture.requests.append(req)
response = fixture.response_factory(req)
return response

monkeypatch.setattr("requests.get", requests_get)

return fixture


@pytest.mark.django_db
def test_oembed_sends_ua_header(api_client, requests):
def test_oembed_sends_ua_header(api_client):
image = ImageFactory.create()
res = api_client.get("/v1/images/oembed/", data={"url": f"/{image.identifier}"})

assert res.status_code == 200
with pook.use():
(
pook.get(image.url)
.header("User-Agent", ImageViewSet.OEMBED_HEADERS["User-Agent"])
.reply(200)
.body(_MOCK_IMAGE_BYTES, binary=True)
)
res = api_client.get("/v1/images/oembed/", data={"url": f"/{image.identifier}"})

assert len(requests.requests) > 0
for r in requests.requests:
assert r.headers == ImageViewSet.OEMBED_HEADERS
assert res.status_code == 200


@pytest.mark.django_db
Expand Down

0 comments on commit 7289ef4

Please sign in to comment.