diff --git a/api/Pipfile b/api/Pipfile index a7b60e79f7b..5b2d2cc9c04 100644 --- a/api/Pipfile +++ b/api/Pipfile @@ -44,6 +44,7 @@ requests-oauthlib = "~=1.3" sentry-sdk = "~=1.30" django-split-settings = "*" uvicorn = {extras = ["standard"], version = "*"} +adrf = "~=0.1.2" [requires] python_version = "3.11" diff --git a/api/Pipfile.lock b/api/Pipfile.lock index 5425bc9615a..ae41e51256e 100644 --- a/api/Pipfile.lock +++ b/api/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "cf0e53b4aa0e1d9d0a2b16d8cebff02cad6cbd6a16b8a059bded5eece3db701b" + "sha256": "fad35000b1e0ea01032877a6727a1da3937b36b516f4bbe36f3d38f75cb307dc" }, "pipfile-spec": 6, "requires": { @@ -16,6 +16,14 @@ ] }, "default": { + "adrf": { + "hashes": [ + "sha256:a33f8f51f0f80072ffb2af061df1fb119bc00adaa720a2972049d4aa33155337", + "sha256:ce7160878ba27999d333752941cde0687c1a205fc26fa0eda1bad3924958dc69" + ], + "index": "pypi", + "version": "==0.1.2" + }, "aiohttp": { "hashes": [ "sha256:00ad4b6f185ec67f3e6562e8a1d2b69660be43070bd0ef6fcec5211154c7df67", @@ -133,6 +141,13 @@ "markers": "python_version >= '3.7'", "version": "==3.7.2" }, + "async-property": { + "hashes": [ + "sha256:17d9bd6ca67e27915a75d92549df64b5c7174e9dc806b30a3934dc4ff0506380", + "sha256:8924d792b5843994537f8ed411165700b27b2bd966cefc4daeefc1253442a9d7" + ], + "version": "==0.2.2" + }, "async-timeout": { "hashes": [ "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f", @@ -159,19 +174,19 @@ }, "boto3": { "hashes": [ - "sha256:c53c92dfe22489ba31e918c2e7b59ff43e2e778bd3d3559e62351a739382bb5c", - "sha256:eea3b07e0f28c9f92bccab972af24a3b0dd951c69d93da75227b8ecd3e18f6c4" + "sha256:04445d70127c25fad69e2cab7e3f5cb219c8d6e60463af3657f20e29ac517957", + "sha256:2ca2852f7b7c1bc2e56f10f968d4c8483c8228b935ecd89a444ae8292ad0dc24" ], "index": "pypi", - "version": "==1.28.44" + "version": "==1.28.46" }, "botocore": { "hashes": [ - "sha256:83d61c1ca781e6ede19fcc4d5dd73004eee3825a2b220f0d7727e32069209d98", - "sha256:84f90919fecb4a4f417fd10145c8a87ff2c4b14d6381cd34d9babf02110b3315" + "sha256:6c30be3371624a80d6a881d9c7771a80e0eb82697ee374aaf522cd59b76e14dd", + "sha256:ac0c1258b1782cde42950bd00138fdce6bd7d04855296af8c326d5844a426473" ], "markers": "python_version >= '3.7'", - "version": "==1.31.44" + "version": "==1.31.46" }, "certifi": { "hashes": [ @@ -370,11 +385,11 @@ }, "deepdiff": { "hashes": [ - "sha256:065cdbbe62f66447cf507b32351579ffcc4a80bb28f567ac27e92a21ddca99f9", - "sha256:744c4e54ff83eaa77a995b3311dccdce6ee67773335a34a5ef269fa048005457" + "sha256:080b1359d6128f3f5f1738c6be3064f0ad9b0cc41994aa90a028065f6ad11f25", + "sha256:acdc1651a3e802415e0337b7e1192df5cd7c17b72fbab480466fdd799b9a72e7" ], "index": "pypi", - "version": "==6.4.1" + "version": "==6.5.0" }, "deprecated": { "hashes": [ @@ -1545,11 +1560,11 @@ }, "faker": { "hashes": [ - "sha256:7cf705758f6cc5dd31f628e323f306a6d881e9a8a103f1e32e5f30a4cad0974c", - "sha256:d79d5ea59f31e00fbb882546840a4adb2fd0bae99b103db1ba5869f176bc530b" + "sha256:5d6b7880b3bea708075ddf91938424453f07053a59f8fa0453c1870df6ff3292", + "sha256:64c8513c53c3a809075ee527b323a0ba61517814123f3137e4912f5d43350139" ], "markers": "python_version >= '3.8'", - "version": "==19.6.0" + "version": "==19.6.1" }, "fakeredis": { "hashes": [ diff --git a/api/api/utils/image_proxy/__init__.py b/api/api/utils/image_proxy/__init__.py index de1c9fd28cf..21c0eb65362 100644 --- a/api/api/utils/image_proxy/__init__.py +++ b/api/api/utils/image_proxy/__init__.py @@ -1,4 +1,5 @@ import logging +from dataclasses import asdict, dataclass from typing import Literal from urllib.parse import urlparse @@ -33,13 +34,25 @@ THUMBNAIL_STRATEGY = Literal["photon_proxy", "original"] +@dataclass +class ImageProxyMediaInfo: + media_identifier: str + image_url: str + + +@dataclass +class ImageProxyConfig: + accept_header: str = "image/*" + is_full_size: bool = False + is_compressed: bool = True + + def get_request_params_for_extension( ext: str, headers: dict[str, str], image_url: str, parsed_image_url: urlparse, - is_full_size: bool, - is_compressed: bool, + proxy_config: ImageProxyConfig, ) -> tuple[str, dict[str, str], dict[str, str]]: """ Get the request params (url, params, headers) for the thumbnail proxy. @@ -49,7 +62,10 @@ def get_request_params_for_extension( """ if ext in PHOTON_TYPES: return get_photon_request_params( - parsed_image_url, is_full_size, is_compressed, headers + parsed_image_url, + proxy_config.is_full_size, + proxy_config.is_compressed, + headers, ) elif ext in ORIGINAL_TYPES: return image_url, {}, headers @@ -59,23 +75,23 @@ def get_request_params_for_extension( def get( - image_url: str, - media_identifier: str, - accept_header: str = "image/*", - is_full_size: bool = False, - is_compressed: bool = True, + media_info: ImageProxyMediaInfo, + proxy_config: ImageProxyConfig = ImageProxyConfig(), ) -> HttpResponse: """ Proxy an image through Photon if its file type is supported, else return the original image if the file type is SVG. Otherwise, raise an exception. """ + image_url = media_info.image_url + media_identifier = media_info.media_identifier + logger = parent_logger.getChild("get") tallies = django_redis.get_redis_connection("tallies") month = get_monthly_timestamp() image_extension = get_image_extension(image_url, media_identifier) - headers = {"Accept": accept_header} | HEADERS + headers = {"Accept": proxy_config.accept_header} | HEADERS parsed_image_url = urlparse(image_url) domain = parsed_image_url.netloc @@ -85,8 +101,7 @@ def get( headers, image_url, parsed_image_url, - is_full_size, - is_compressed, + proxy_config, ) try: diff --git a/api/api/views/audio_views.py b/api/api/views/audio_views.py index c26fe8055b8..1535f9294c0 100644 --- a/api/api/views/audio_views.py +++ b/api/api/views/audio_views.py @@ -6,15 +6,9 @@ from drf_spectacular.utils import extend_schema, extend_schema_view from api.constants.media_types import AUDIO_TYPE -from api.docs.audio_docs import ( - detail, - related, - report, - search, - stats, - thumbnail, - waveform, -) +from api.docs.audio_docs import detail, related, report, search, stats +from api.docs.audio_docs import thumbnail as thumbnail_docs +from api.docs.audio_docs import waveform from api.models import Audio from api.serializers.audio_serializers import ( AudioReportRequestSerializer, @@ -22,7 +16,7 @@ AudioSerializer, AudioWaveformSerializer, ) -from api.serializers.media_serializers import MediaThumbnailRequestSerializer +from api.utils.image_proxy import ImageProxyMediaInfo from api.utils.throttle import AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle from api.views.media_views import MediaViewSet @@ -48,21 +42,8 @@ def get_queryset(self): # Extra actions - @thumbnail - @action( - detail=True, - url_path="thumb", - url_name="thumb", - serializer_class=MediaThumbnailRequestSerializer, - throttle_classes=[AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle], - ) - def thumbnail(self, request, *_, **__): - """ - Retrieve the scaled down and compressed thumbnail of the artwork of an - audio track or its audio set. - """ - - audio = self.get_object() + async def get_image_proxy_media_info(self) -> ImageProxyMediaInfo: + audio = await self.aget_object() image_url = None if audio_thumbnail := audio.thumbnail: @@ -72,7 +53,12 @@ def thumbnail(self, request, *_, **__): if not image_url: raise NotFound("Could not find artwork.") - return super().thumbnail(request, audio, image_url) + return ImageProxyMediaInfo( + media_identifier=audio.identifier, + image_url=image_url, + ) + + thumbnail = thumbnail_docs(MediaViewSet.thumbnail) @waveform @action( diff --git a/api/api/views/image_views.py b/api/api/views/image_views.py index 95844ffbf97..5007bb14765 100644 --- a/api/api/views/image_views.py +++ b/api/api/views/image_views.py @@ -12,15 +12,8 @@ from PIL import Image as PILImage from api.constants.media_types import IMAGE_TYPE -from api.docs.image_docs import ( - detail, - oembed, - related, - report, - search, - stats, - thumbnail, -) +from api.docs.image_docs import detail, oembed, related, report, search, stats +from api.docs.image_docs import thumbnail as thumbnail_docs from api.docs.image_docs import watermark as watermark_doc from api.models import Image from api.serializers.image_serializers import ( @@ -31,8 +24,7 @@ OembedSerializer, WatermarkRequestSerializer, ) -from api.serializers.media_serializers import MediaThumbnailRequestSerializer -from api.utils.throttle import AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle +from api.utils.image_proxy import ImageProxyMediaInfo from api.utils.watermark import watermark from api.views.media_views import MediaViewSet @@ -99,25 +91,19 @@ def oembed(self, request, *_, **__): serializer = self.get_serializer(image, context=context) return Response(data=serializer.data) - @thumbnail - @action( - detail=True, - url_path="thumb", - url_name="thumb", - serializer_class=MediaThumbnailRequestSerializer, - throttle_classes=[AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle], - ) - def thumbnail(self, request, *_, **__): - """Retrieve the scaled down and compressed thumbnail of the image.""" - - image = self.get_object() + async def get_image_proxy_media_info(self) -> ImageProxyMediaInfo: + image = await self.aget_object() image_url = image.url # Hotfix to use thumbnails for SMK images # TODO: Remove when small thumbnail issues are resolved if "iip.smk.dk" in image_url and image.thumbnail: image_url = image.thumbnail - return super().thumbnail(request, image, image_url) + return ImageProxyMediaInfo( + media_identifier=image.identifier, image_url=image_url + ) + + thumbnail = thumbnail_docs(MediaViewSet.thumbnail) @watermark_doc @action(detail=True, url_path="watermark", url_name="watermark") diff --git a/api/api/views/media_views.py b/api/api/views/media_views.py index 0ad9cd1e4f7..811fb497976 100644 --- a/api/api/views/media_views.py +++ b/api/api/views/media_views.py @@ -6,18 +6,29 @@ from rest_framework.response import Response from rest_framework.viewsets import ReadOnlyModelViewSet +from adrf.views import APIView as AsyncAPIView +from adrf.viewsets import ViewSetMixin as AsyncViewSetMixin +from asgiref.sync import sync_to_async + from api.controllers import search_controller from api.models import ContentProvider from api.models.media import AbstractMedia +from api.serializers.media_serializers import MediaThumbnailRequestSerializer from api.serializers.provider_serializers import ProviderSerializer from api.utils import image_proxy from api.utils.pagination import StandardPagination +from api.utils.throttle import AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle logger = logging.getLogger(__name__) -class MediaViewSet(ReadOnlyModelViewSet): +image_proxy_aget = sync_to_async(image_proxy.get, thread_sensitive=True) + + +class MediaViewSet(AsyncViewSetMixin, AsyncAPIView, ReadOnlyModelViewSet): + view_is_async = True + lookup_field = "identifier" # TODO: https://github.com/encode/django-rest-framework/pull/6789 lookup_value_regex = ( @@ -58,6 +69,8 @@ def get_queryset(self): ).values_list("provider_identifier") ) + aget_object = sync_to_async(ReadOnlyModelViewSet.get_object) + def get_serializer_context(self): context = super().get_serializer_context() req_serializer = self._get_request_serializer(self.request) @@ -174,15 +187,30 @@ def report(self, request, identifier): return Response(data=serializer.data, status=status.HTTP_201_CREATED) - def thumbnail(self, request, media_obj, image_url): + async def get_image_proxy_media_info(self) -> image_proxy.ImageProxyMediaInfo: + raise NotImplementedError( + "Subclasses must implement `get_image_proxy_media_info`" + ) + + @action( + detail=True, + url_path="thumb", + url_name="thumb", + serializer_class=MediaThumbnailRequestSerializer, + throttle_classes=[AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle], + ) + async def thumbnail(self, request, *_, **__): serializer = self.get_serializer(data=request.query_params) serializer.is_valid(raise_exception=True) - return image_proxy.get( - image_url, - media_obj.identifier, - accept_header=request.headers.get("Accept", "image/*"), - **serializer.validated_data, + media_info = await self.get_image_proxy_media_info() + + return await image_proxy_aget( + media_info, + proxy_config=image_proxy.ImageProxyConfig( + accept_header=request.headers.get("Accept", "image/*"), + **serializer.validated_data, + ), ) # Helper functions diff --git a/api/test/unit/utils/test_image_proxy.py b/api/test/unit/utils/test_image_proxy.py index d68624ffc03..c1a57b9916c 100644 --- a/api/test/unit/utils/test_image_proxy.py +++ b/api/test/unit/utils/test_image_proxy.py @@ -1,3 +1,4 @@ +from dataclasses import replace from test.factory.models.image import ImageFactory from unittest.mock import MagicMock from urllib.parse import urlencode @@ -9,7 +10,13 @@ import pytest import requests -from api.utils.image_proxy import HEADERS, UpstreamThumbnailException, extension +from api.utils.image_proxy import ( + HEADERS, + ImageProxyConfig, + ImageProxyMediaInfo, + UpstreamThumbnailException, + extension, +) from api.utils.image_proxy import get as photon_get from api.utils.tallies import get_monthly_timestamp @@ -18,6 +25,10 @@ TEST_IMAGE_URL = PHOTON_URL_FOR_TEST_IMAGE.replace(settings.PHOTON_ENDPOINT, "http://") TEST_MEDIA_IDENTIFIER = "123" +TEST_MEDIA_INFO = ImageProxyMediaInfo( + media_identifier=TEST_MEDIA_IDENTIFIER, image_url=TEST_IMAGE_URL +) + UA_HEADER = HEADERS["User-Agent"] # cannot use actual image response because I kept running into some issue with @@ -59,7 +70,7 @@ def test_get_successful_no_auth_key_default_args(mock_image_data): .mock ) - res = photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER) + res = photon_get(TEST_MEDIA_INFO) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -77,7 +88,11 @@ def test_get_successful_original_svg_no_auth_key_default_args(mock_image_data): .mock ) - res = photon_get(TEST_IMAGE_URL.replace(".jpg", ".svg"), TEST_MEDIA_IDENTIFIER) + media_info = replace( + TEST_MEDIA_INFO, image_url=TEST_MEDIA_INFO.image_url.replace(".jpg", ".svg") + ) + + res = photon_get(media_info) assert res.content == SVG_BODY.encode() assert res.status_code == 200 @@ -102,7 +117,7 @@ def test_get_successful_with_auth_key_default_args(mock_image_data, auth_key): .mock ) - res = photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER) + res = photon_get(TEST_MEDIA_INFO) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -125,7 +140,7 @@ def test_get_successful_no_auth_key_not_compressed(mock_image_data): .mock ) - res = photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, is_compressed=False) + res = photon_get(TEST_MEDIA_INFO, ImageProxyConfig(is_compressed=False)) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -148,7 +163,7 @@ def test_get_successful_no_auth_key_full_size(mock_image_data): .mock ) - res = photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, is_full_size=True) + res = photon_get(TEST_MEDIA_INFO, ImageProxyConfig(is_full_size=True)) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -167,7 +182,8 @@ def test_get_successful_no_auth_key_full_size_not_compressed(mock_image_data): ) res = photon_get( - TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, is_full_size=True, is_compressed=False + TEST_MEDIA_INFO, + ImageProxyConfig(is_full_size=True, is_compressed=False), ) assert res.content == MOCK_BODY.encode() @@ -192,7 +208,7 @@ def test_get_successful_no_auth_key_png_only(mock_image_data): .mock ) - res = photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, accept_header="image/png") + res = photon_get(TEST_MEDIA_INFO, ImageProxyConfig(accept_header="image/png")) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -218,9 +234,11 @@ def test_get_successful_forward_query_params(mock_image_data): .mock ) - url_with_params = f"{TEST_IMAGE_URL}?{params}" + media_info_with_url_params = replace( + TEST_MEDIA_INFO, image_url=f"{TEST_IMAGE_URL}?{params}" + ) - res = photon_get(url_with_params, TEST_MEDIA_IDENTIFIER) + res = photon_get(media_info_with_url_params) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -255,7 +273,7 @@ def test_get_successful_records_response_code(mock_image_data, redis): .mock ) - photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER) + photon_get(TEST_MEDIA_INFO) month = get_monthly_timestamp() assert redis.get(f"thumbnail_response_code:{month}:200") == b"1" assert ( @@ -310,7 +328,7 @@ def test_get_exception_handles_error( redis.set(key, count_start) with pytest.raises(UpstreamThumbnailException): - photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER) + photon_get(TEST_MEDIA_INFO) assert_func = ( capture_exception.assert_called_once @@ -351,7 +369,7 @@ def test_get_http_exception_handles_error( redis.set(key, count_start) with pytest.raises(UpstreamThumbnailException): - photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER) + photon_get(TEST_MEDIA_INFO) assert_func = ( capture_exception.assert_called_once @@ -389,7 +407,9 @@ def test_get_successful_https_image_url_sends_ssl_parameter(mock_image_data): .mock ) - res = photon_get(https_url, TEST_MEDIA_IDENTIFIER) + https_media_info = replace(TEST_MEDIA_INFO, image_url=https_url) + + res = photon_get(https_media_info) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -403,7 +423,7 @@ def test_get_unsuccessful_request_raises_custom_exception(): with pytest.raises( UpstreamThumbnailException, match=r"Failed to render thumbnail." ): - photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER) + photon_get(TEST_MEDIA_INFO) assert mock_get.matched @@ -432,9 +452,10 @@ def test__get_extension_from_url(image_url, expected_ext): def test_photon_get_raises_by_not_allowed_types(image_type): image_url = TEST_IMAGE_URL.replace(".jpg", f".{image_type}") image = ImageFactory.create(url=image_url) + media_info = ImageProxyMediaInfo(image.identifier, image_url) with pytest.raises(UnsupportedMediaType): - photon_get(image_url, image.identifier) + photon_get(media_info) @pytest.mark.django_db @@ -448,10 +469,12 @@ def test_photon_get_raises_by_not_allowed_types(image_type): def test_photon_get_saves_image_type_to_cache(redis, headers, expected_cache_val): image_url = TEST_IMAGE_URL.replace(".jpg", "") image = ImageFactory.create(url=image_url) + media_info = ImageProxyMediaInfo(image.identifier, image_url) + with pook.use(): pook.head(image_url, reply=200, response_headers=headers) with pytest.raises(UnsupportedMediaType): - photon_get(image_url, image.identifier) + photon_get(media_info) key = f"media:{image.identifier}:thumb_type" assert redis.get(key) == expected_cache_val diff --git a/api/test/unit/views/test_image_views.py b/api/test/unit/views/test_image_views.py index 7dbaa27453d..042870a4549 100644 --- a/api/test/unit/views/test_image_views.py +++ b/api/test/unit/views/test_image_views.py @@ -3,10 +3,8 @@ from dataclasses import dataclass from pathlib import Path from test.factory.models.image import ImageFactory -from unittest.mock import ANY, patch - -from django.http import HttpResponse +import pook import pytest from requests import Request, Response @@ -34,7 +32,7 @@ def _default_response_factory(req: Request) -> Response: return res -@pytest.fixture(autouse=True) +@pytest.fixture def requests(monkeypatch) -> RequestsFixture: fixture = RequestsFixture([]) @@ -67,15 +65,24 @@ def test_oembed_sends_ua_header(api_client, requests): [(True, "http://iip.smk.dk/thumb.jpg"), (False, "http://iip.smk.dk/image.jpg")], ) def test_thumbnail_uses_upstream_thumb_for_smk( - api_client, smk_has_thumb, expected_thumb_url + api_client, smk_has_thumb, expected_thumb_url, settings ): thumb_url = "http://iip.smk.dk/thumb.jpg" if smk_has_thumb else None image = ImageFactory.create( url="http://iip.smk.dk/image.jpg", thumbnail=thumb_url, ) - with patch("api.views.media_views.MediaViewSet.thumbnail") as thumb_call: - mock_response = HttpResponse("mock_response") - thumb_call.return_value = mock_response - api_client.get(f"/v1/images/{image.identifier}/thumb/") - thumb_call.assert_called_once_with(ANY, image, expected_thumb_url) + + with pook.use(): + mock_get = ( + # Pook interprets a trailing slash on the URL as the path, + # so strip that so the `path` matcher works + pook.get(settings.PHOTON_ENDPOINT[:-1]) + .path(expected_thumb_url.replace("http://", "/")) + .response(200) + ).mock + + response = api_client.get(f"/v1/images/{image.identifier}/thumb/") + + assert response.status_code == 200 + assert mock_get.matched is True