diff --git a/api/Pipfile b/api/Pipfile index d824cdb4cd5..e2d8e9a0430 100644 --- a/api/Pipfile +++ b/api/Pipfile @@ -18,6 +18,7 @@ pytest-sugar = "~=0.9" pook = {ref = "master", git = "git+https://github.com/h2non/pook.git"} [packages] +adrf = "~=0.1.2" aiohttp = "~=3.8" aws-requests-auth = "~=0.4" deepdiff = "~=6.4" @@ -26,6 +27,7 @@ django-cors-headers = "~=4.2" django-log-request-id = "~=2.0" django-oauth-toolkit = "~=2.3" django-redis = "~=5.4" +django-split-settings = "*" django-tqdm = "~=1.3" django-uuslug = "~=2.0" djangorestframework = "~=3.14" @@ -35,11 +37,10 @@ elasticsearch-dsl = "~=8.9" future = "~=0.18" limit = "~=0.2" Pillow = "~=10.1.0" +psycopg = "~=3.1" python-decouple = "~=3.8" python-xmp-toolkit = "~=2.0" sentry-sdk = "~=1.30" -django-split-settings = "*" -psycopg = "~=3.1" uvicorn = {extras = ["standard"], version = "~=0.23"} [requires] diff --git a/api/Pipfile.lock b/api/Pipfile.lock index 0c13736001c..7a2ade8cfef 100644 --- a/api/Pipfile.lock +++ b/api/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "7e9adb0878e2d7523c7457b3712adde0aee93b6586d54eef2663d09e145a1b3e" + "sha256": "54293a6311c5ebb7d16bf6bb13d9a63f420b9275855cba28076d08f125ae95c2" }, "pipfile-spec": 6, "requires": { @@ -16,6 +16,14 @@ ] }, "default": { + "adrf": { + "hashes": [ + "sha256:a33f8f51f0f80072ffb2af061df1fb119bc00adaa720a2972049d4aa33155337", + "sha256:ce7160878ba27999d333752941cde0687c1a205fc26fa0eda1bad3924958dc69" + ], + "index": "pypi", + "version": "==0.1.2" + }, "aiohttp": { "hashes": [ "sha256:002f23e6ea8d3dd8d149e569fd580c999232b5fbc601c48d55398fbc2e582e8c", @@ -133,6 +141,13 @@ "markers": "python_version >= '3.7'", "version": "==3.7.2" }, + "async-property": { + "hashes": [ + "sha256:17d9bd6ca67e27915a75d92549df64b5c7174e9dc806b30a3934dc4ff0506380", + "sha256:8924d792b5843994537f8ed411165700b27b2bd966cefc4daeefc1253442a9d7" + ], + "version": "==0.2.2" + }, "async-timeout": { "hashes": [ "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f", diff --git a/api/api/utils/image_proxy/__init__.py b/api/api/utils/image_proxy/__init__.py index 6a591a0c61f..bcfb44db6aa 100644 --- a/api/api/utils/image_proxy/__init__.py +++ b/api/api/utils/image_proxy/__init__.py @@ -1,4 +1,5 @@ import logging +from dataclasses import dataclass from typing import Literal from urllib.parse import urlparse @@ -33,13 +34,26 @@ THUMBNAIL_STRATEGY = Literal["photon_proxy", "original"] +@dataclass +class MediaInfo: + media_provider: str + media_identifier: str + image_url: str + + +@dataclass +class RequestConfig: + accept_header: str = "image/*" + is_full_size: bool = False + is_compressed: bool = True + + def get_request_params_for_extension( ext: str, headers: dict[str, str], image_url: str, parsed_image_url: urlparse, - is_full_size: bool, - is_compressed: bool, + request_config: RequestConfig, ) -> tuple[str, dict[str, str], dict[str, str]]: """ Get the request params (url, params, headers) for the thumbnail proxy. @@ -49,7 +63,10 @@ def get_request_params_for_extension( """ if ext in PHOTON_TYPES: return get_photon_request_params( - parsed_image_url, is_full_size, is_compressed, headers + parsed_image_url, + request_config.is_full_size, + request_config.is_compressed, + headers, ) elif ext in ORIGINAL_TYPES: return image_url, {}, headers @@ -59,24 +76,23 @@ def get_request_params_for_extension( def get( - image_url: str, - media_identifier: str, - media_provider: str, - accept_header: str = "image/*", - is_full_size: bool = False, - is_compressed: bool = True, + media_info: MediaInfo, + request_config: RequestConfig = RequestConfig(), ) -> HttpResponse: """ Proxy an image through Photon if its file type is supported, else return the original image if the file type is SVG. Otherwise, raise an exception. """ + image_url = media_info.image_url + media_identifier = media_info.media_identifier + logger = parent_logger.getChild("get") tallies = django_redis.get_redis_connection("tallies") month = get_monthly_timestamp() image_extension = get_image_extension(image_url, media_identifier) - headers = {"Accept": accept_header} | HEADERS + headers = {"Accept": request_config.accept_header} | HEADERS parsed_image_url = urlparse(image_url) domain = parsed_image_url.netloc @@ -86,8 +102,7 @@ def get( headers, image_url, parsed_image_url, - is_full_size, - is_compressed, + request_config, ) try: @@ -103,7 +118,7 @@ def get( f"{month}:{upstream_response.status_code}" ) tallies.incr( - f"thumbnail_response_code_by_provider:{media_provider}:" + f"thumbnail_response_code_by_provider:{media_info.media_provider}:" f"{month}:{upstream_response.status_code}" ) upstream_response.raise_for_status() @@ -133,7 +148,9 @@ def get( f"thumbnail_http_error:{domain}:{month}:{code}:{exc.response.text}" ) logger.warning( - f"Failed to render thumbnail {upstream_url=} {code=} {media_provider=}" + f"Failed to render thumbnail " + f"{upstream_url=} {code=} " + f"{media_info.media_provider=}" ) raise UpstreamThumbnailException(f"Failed to render thumbnail. {exc}") diff --git a/api/api/views/audio_views.py b/api/api/views/audio_views.py index 6457b233a09..dcde7266b92 100644 --- a/api/api/views/audio_views.py +++ b/api/api/views/audio_views.py @@ -15,9 +15,9 @@ source_collection, stats, tag_collection, - thumbnail, - waveform, ) +from api.docs.audio_docs import thumbnail as thumbnail_docs +from api.docs.audio_docs import waveform from api.models import Audio from api.serializers.audio_serializers import ( AudioCollectionRequestSerializer, @@ -26,7 +26,7 @@ AudioSerializer, AudioWaveformSerializer, ) -from api.serializers.media_serializers import MediaThumbnailRequestSerializer +from api.utils import image_proxy from api.utils.throttle import AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle from api.views.media_views import MediaViewSet @@ -80,21 +80,8 @@ def source_collection(self, request, source): def tag_collection(self, request, tag, *_, **__): return super().tag_collection(request, tag, *_, **__) - @thumbnail - @action( - detail=True, - url_path="thumb", - url_name="thumb", - serializer_class=MediaThumbnailRequestSerializer, - throttle_classes=[AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle], - ) - def thumbnail(self, request, *_, **__): - """ - Retrieve the scaled down and compressed thumbnail of the artwork of an - audio track or its audio set. - """ - - audio = self.get_object() + async def get_image_proxy_media_info(self) -> image_proxy.MediaInfo: + audio = await self.aget_object() image_url = None if audio_thumbnail := audio.thumbnail: @@ -104,7 +91,20 @@ def thumbnail(self, request, *_, **__): if not image_url: raise NotFound("Could not find artwork.") - return super().thumbnail(request, audio, image_url) + return image_proxy.MediaInfo( + media_identifier=audio.identifier, + media_provider=audio.provider, + image_url=image_url, + ) + + @thumbnail_docs + @MediaViewSet.thumbnail_action + async def thumbnail(self, *args, **kwargs): + """ + Retrieve the scaled down and compressed thumbnail of the artwork of an + audio track or its audio set. + """ + return await super().thumbnail(*args, **kwargs) @waveform @action( diff --git a/api/api/views/image_views.py b/api/api/views/image_views.py index 641a230f384..955466bcf57 100644 --- a/api/api/views/image_views.py +++ b/api/api/views/image_views.py @@ -22,8 +22,8 @@ source_collection, stats, tag_collection, - thumbnail, ) +from api.docs.image_docs import thumbnail as thumbnail_docs from api.docs.image_docs import watermark as watermark_doc from api.models import Image from api.serializers.image_serializers import ( @@ -34,11 +34,8 @@ OembedSerializer, WatermarkRequestSerializer, ) -from api.serializers.media_serializers import ( - MediaThumbnailRequestSerializer, - PaginatedRequestSerializer, -) -from api.utils.throttle import AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle +from api.serializers.media_serializers import PaginatedRequestSerializer +from api.utils import image_proxy from api.utils.watermark import watermark from api.views.media_views import MediaViewSet @@ -130,25 +127,25 @@ def oembed(self, request, *_, **__): serializer = self.get_serializer(image, context=context) return Response(data=serializer.data) - @thumbnail - @action( - detail=True, - url_path="thumb", - url_name="thumb", - serializer_class=MediaThumbnailRequestSerializer, - throttle_classes=[AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle], - ) - def thumbnail(self, request, *_, **__): - """Retrieve the scaled down and compressed thumbnail of the image.""" - - image = self.get_object() + async def get_image_proxy_media_info(self) -> image_proxy.MediaInfo: + image = await self.aget_object() image_url = image.url # Hotfix to use thumbnails for SMK images # TODO: Remove when small thumbnail issues are resolved if "iip.smk.dk" in image_url and image.thumbnail: image_url = image.thumbnail - return super().thumbnail(request, image, image_url) + return image_proxy.MediaInfo( + media_identifier=image.identifier, + media_provider=image.provider, + image_url=image_url, + ) + + @thumbnail_docs + @MediaViewSet.thumbnail_action + async def thumbnail(self, *args, **kwargs): + """Retrieve the scaled down and compressed thumbnail of the image.""" + return await super().thumbnail(*args, **kwargs) @watermark_doc @action(detail=True, url_path="watermark", url_name="watermark") diff --git a/api/api/views/media_views.py b/api/api/views/media_views.py index ecf6005f410..d003feec9fd 100644 --- a/api/api/views/media_views.py +++ b/api/api/views/media_views.py @@ -7,6 +7,10 @@ from rest_framework.response import Response from rest_framework.viewsets import ReadOnlyModelViewSet +from adrf.views import APIView as AsyncAPIView +from adrf.viewsets import ViewSetMixin as AsyncViewSetMixin +from asgiref.sync import sync_to_async + from api.constants.media_types import MediaType from api.constants.search import SearchStrategy from api.controllers import search_controller @@ -18,6 +22,7 @@ from api.utils import image_proxy from api.utils.pagination import StandardPagination from api.utils.search_context import SearchContext +from api.utils.throttle import AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle logger = logging.getLogger(__name__) @@ -35,7 +40,12 @@ class InvalidSource(APIException): default_code = "invalid_source" -class MediaViewSet(ReadOnlyModelViewSet): +image_proxy_aget = sync_to_async(image_proxy.get) + + +class MediaViewSet(AsyncViewSetMixin, AsyncAPIView, ReadOnlyModelViewSet): + view_is_async = True + lookup_field = "identifier" # TODO: https://github.com/encode/django-rest-framework/pull/6789 lookup_value_regex = ( @@ -79,6 +89,8 @@ def get_queryset(self): ).values_list("provider_identifier") ) + aget_object = sync_to_async(ReadOnlyModelViewSet.get_object) + def get_serializer_context(self): context = super().get_serializer_context() req_serializer = self._get_request_serializer(self.request) @@ -265,16 +277,31 @@ def report(self, request, identifier): return Response(data=serializer.data, status=status.HTTP_201_CREATED) - def thumbnail(self, request, media_obj, image_url): + async def get_image_proxy_media_info(self) -> image_proxy.MediaInfo: + raise NotImplementedError( + "Subclasses must implement `get_image_proxy_media_info`" + ) + + thumbnail_action = action( + detail=True, + url_path="thumb", + url_name="thumb", + serializer_class=media_serializers.MediaThumbnailRequestSerializer, + throttle_classes=[AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle], + ) + + async def thumbnail(self, request, *_, **__): serializer = self.get_serializer(data=request.query_params) serializer.is_valid(raise_exception=True) - return image_proxy.get( - image_url, - media_obj.identifier, - media_obj.provider, - accept_header=request.headers.get("Accept", "image/*"), - **serializer.validated_data, + media_info = await self.get_image_proxy_media_info() + + return await image_proxy_aget( + media_info, + request_config=image_proxy.RequestConfig( + accept_header=request.headers.get("Accept", "image/*"), + **serializer.validated_data, + ), ) # Helper functions diff --git a/api/test/unit/utils/test_image_proxy.py b/api/test/unit/utils/test_image_proxy.py index 35b047f5b7a..5f2aa58f15a 100644 --- a/api/test/unit/utils/test_image_proxy.py +++ b/api/test/unit/utils/test_image_proxy.py @@ -1,3 +1,4 @@ +from dataclasses import replace from test.factory.models.image import ImageFactory from unittest.mock import MagicMock from urllib.parse import urlencode @@ -9,7 +10,13 @@ import pytest import requests -from api.utils.image_proxy import HEADERS, UpstreamThumbnailException, extension +from api.utils.image_proxy import ( + HEADERS, + MediaInfo, + RequestConfig, + UpstreamThumbnailException, + extension, +) from api.utils.image_proxy import get as photon_get from api.utils.tallies import get_monthly_timestamp @@ -19,6 +26,12 @@ TEST_MEDIA_IDENTIFIER = "123" TEST_MEDIA_PROVIDER = "foo" +TEST_MEDIA_INFO = MediaInfo( + media_identifier=TEST_MEDIA_IDENTIFIER, + media_provider=TEST_MEDIA_PROVIDER, + image_url=TEST_IMAGE_URL, +) + UA_HEADER = HEADERS["User-Agent"] # cannot use actual image response because I kept running into some issue with @@ -60,7 +73,7 @@ def test_get_successful_no_auth_key_default_args(mock_image_data): .mock ) - res = photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER) + res = photon_get(TEST_MEDIA_INFO) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -78,12 +91,12 @@ def test_get_successful_original_svg_no_auth_key_default_args(mock_image_data): .mock ) - res = photon_get( - TEST_IMAGE_URL.replace(".jpg", ".svg"), - TEST_MEDIA_IDENTIFIER, - TEST_MEDIA_PROVIDER, + media_info = replace( + TEST_MEDIA_INFO, image_url=TEST_MEDIA_INFO.image_url.replace(".jpg", ".svg") ) + res = photon_get(media_info) + assert res.content == SVG_BODY.encode() assert res.status_code == 200 assert mock_get.matched @@ -107,7 +120,7 @@ def test_get_successful_with_auth_key_default_args(mock_image_data, auth_key): .mock ) - res = photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER) + res = photon_get(TEST_MEDIA_INFO) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -130,9 +143,7 @@ def test_get_successful_no_auth_key_not_compressed(mock_image_data): .mock ) - res = photon_get( - TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER, is_compressed=False - ) + res = photon_get(TEST_MEDIA_INFO, RequestConfig(is_compressed=False)) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -155,9 +166,7 @@ def test_get_successful_no_auth_key_full_size(mock_image_data): .mock ) - res = photon_get( - TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER, is_full_size=True - ) + res = photon_get(TEST_MEDIA_INFO, RequestConfig(is_full_size=True)) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -176,11 +185,8 @@ def test_get_successful_no_auth_key_full_size_not_compressed(mock_image_data): ) res = photon_get( - TEST_IMAGE_URL, - TEST_MEDIA_IDENTIFIER, - TEST_MEDIA_PROVIDER, - is_full_size=True, - is_compressed=False, + TEST_MEDIA_INFO, + RequestConfig(is_full_size=True, is_compressed=False), ) assert res.content == MOCK_BODY.encode() @@ -205,12 +211,7 @@ def test_get_successful_no_auth_key_png_only(mock_image_data): .mock ) - res = photon_get( - TEST_IMAGE_URL, - TEST_MEDIA_IDENTIFIER, - TEST_MEDIA_PROVIDER, - accept_header="image/png", - ) + res = photon_get(TEST_MEDIA_INFO, RequestConfig(accept_header="image/png")) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -236,9 +237,11 @@ def test_get_successful_forward_query_params(mock_image_data): .mock ) - url_with_params = f"{TEST_IMAGE_URL}?{params}" + media_info_with_url_params = replace( + TEST_MEDIA_INFO, image_url=f"{TEST_IMAGE_URL}?{params}" + ) - res = photon_get(url_with_params, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER) + res = photon_get(media_info_with_url_params) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -273,7 +276,7 @@ def test_get_successful_records_response_code(mock_image_data, redis): .mock ) - photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER) + photon_get(TEST_MEDIA_INFO) month = get_monthly_timestamp() assert redis.get(f"thumbnail_response_code:{month}:200") == b"1" assert ( @@ -328,7 +331,7 @@ def test_get_exception_handles_error( redis.set(key, count_start) with pytest.raises(UpstreamThumbnailException): - photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER) + photon_get(TEST_MEDIA_INFO) assert_func = ( capture_exception.assert_called_once @@ -369,7 +372,7 @@ def test_get_http_exception_handles_error( redis.set(key, count_start) with pytest.raises(UpstreamThumbnailException): - photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER) + photon_get(TEST_MEDIA_INFO) assert_func = ( capture_exception.assert_called_once @@ -407,7 +410,9 @@ def test_get_successful_https_image_url_sends_ssl_parameter(mock_image_data): .mock ) - res = photon_get(https_url, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER) + https_media_info = replace(TEST_MEDIA_INFO, image_url=https_url) + + res = photon_get(https_media_info) assert res.content == MOCK_BODY.encode() assert res.status_code == 200 @@ -421,7 +426,7 @@ def test_get_unsuccessful_request_raises_custom_exception(): with pytest.raises( UpstreamThumbnailException, match=r"Failed to render thumbnail." ): - photon_get(TEST_IMAGE_URL, TEST_MEDIA_IDENTIFIER, TEST_MEDIA_PROVIDER) + photon_get(TEST_MEDIA_INFO) assert mock_get.matched @@ -450,9 +455,14 @@ def test__get_extension_from_url(image_url, expected_ext): def test_photon_get_raises_by_not_allowed_types(image_type): image_url = TEST_IMAGE_URL.replace(".jpg", f".{image_type}") image = ImageFactory.create(url=image_url) + media_info = MediaInfo( + media_identifier=image.identifier, + media_provider=image.provider, + image_url=image_url, + ) with pytest.raises(UnsupportedMediaType): - photon_get(image_url, image.identifier, image.provider) + photon_get(media_info) @pytest.mark.django_db @@ -466,10 +476,15 @@ def test_photon_get_raises_by_not_allowed_types(image_type): def test_photon_get_saves_image_type_to_cache(redis, headers, expected_cache_val): image_url = TEST_IMAGE_URL.replace(".jpg", "") image = ImageFactory.create(url=image_url) + media_info = MediaInfo( + media_identifier=image.identifier, + media_provider=image.provider, + image_url=image_url, + ) with pook.use(): pook.head(image_url, reply=200, response_headers=headers) with pytest.raises(UnsupportedMediaType): - photon_get(image_url, image.identifier, image.provider) + photon_get(media_info) key = f"media:{image.identifier}:thumb_type" assert redis.get(key) == expected_cache_val diff --git a/api/test/unit/views/test_image_views.py b/api/test/unit/views/test_image_views.py index ffe7c0352ed..dc258dfc246 100644 --- a/api/test/unit/views/test_image_views.py +++ b/api/test/unit/views/test_image_views.py @@ -3,10 +3,9 @@ from dataclasses import dataclass from pathlib import Path from test.factory.models.image import ImageFactory -from unittest.mock import ANY, patch - -from django.http import HttpResponse +from unittest.mock import patch +import pook import pytest from PIL import UnidentifiedImageError from requests import Request, Response @@ -35,7 +34,7 @@ def _default_response_factory(req: Request) -> Response: return res -@pytest.fixture(autouse=True) +@pytest.fixture def requests(monkeypatch) -> RequestsFixture: fixture = RequestsFixture([]) @@ -68,18 +67,27 @@ def test_oembed_sends_ua_header(api_client, requests): [(True, "http://iip.smk.dk/thumb.jpg"), (False, "http://iip.smk.dk/image.jpg")], ) def test_thumbnail_uses_upstream_thumb_for_smk( - api_client, smk_has_thumb, expected_thumb_url + api_client, smk_has_thumb, expected_thumb_url, settings ): thumb_url = "http://iip.smk.dk/thumb.jpg" if smk_has_thumb else None image = ImageFactory.create( url="http://iip.smk.dk/image.jpg", thumbnail=thumb_url, ) - with patch("api.views.media_views.MediaViewSet.thumbnail") as thumb_call: - mock_response = HttpResponse("mock_response") - thumb_call.return_value = mock_response - api_client.get(f"/v1/images/{image.identifier}/thumb/") - thumb_call.assert_called_once_with(ANY, image, expected_thumb_url) + + with pook.use(): + mock_get = ( + # Pook interprets a trailing slash on the URL as the path, + # so strip that so the `path` matcher works + pook.get(settings.PHOTON_ENDPOINT[:-1]) + .path(expected_thumb_url.replace("http://", "/")) + .response(200) + ).mock + + response = api_client.get(f"/v1/images/{image.identifier}/thumb/") + + assert response.status_code == 200 + assert mock_get.matched is True @pytest.mark.django_db @@ -89,9 +97,13 @@ def test_watermark_raises_424_for_invalid_image(api_client): "cannot identify image file <_io.BytesIO object at 0xffff86d8fec0>" ) - with patch("PIL.Image.open") as mock_open: - mock_open.side_effect = UnidentifiedImageError(expected_error_message) - res = api_client.get(f"/v1/images/{image.identifier}/watermark/") + with pook.use(): + pook.get(image.url).reply(200) + + with patch("PIL.Image.open") as mock_open: + mock_open.side_effect = UnidentifiedImageError(expected_error_message) + res = api_client.get(f"/v1/images/{image.identifier}/watermark/") + assert res.status_code == 424 assert res.data["detail"] == expected_error_message