diff --git a/api/catalog/api/controllers/search_controller.py b/api/catalog/api/controllers/search_controller.py index 1dcb5e90d..d6db5072b 100644 --- a/api/catalog/api/controllers/search_controller.py +++ b/api/catalog/api/controllers/search_controller.py @@ -18,7 +18,6 @@ import catalog.api.models as models from catalog.api.utils.dead_link_mask import get_query_hash, get_query_mask -from catalog.api.utils.pagination import MAX_TOTAL_PAGE_COUNT from catalog.api.utils.validate_images import validate_images @@ -502,10 +501,9 @@ def _get_result_and_page_count( return 0, 1 result_count = response_obj.hits.total.value - natural_page_count = int(result_count / page_size) - if natural_page_count % page_size != 0: - natural_page_count += 1 - page_count = min(natural_page_count, MAX_TOTAL_PAGE_COUNT) + page_count = int(result_count / page_size) + if page_count % page_size != 0: + page_count += 1 if len(results) < page_size and page_count == 0: result_count = len(results) diff --git a/api/catalog/api/serializers/media_serializers.py b/api/catalog/api/serializers/media_serializers.py index 09d4b003c..c93c914c4 100644 --- a/api/catalog/api/serializers/media_serializers.py +++ b/api/catalog/api/serializers/media_serializers.py @@ -1,13 +1,16 @@ from collections import namedtuple +from django.conf import settings +from django.core.exceptions import ValidationError +from django.core.validators import MaxValueValidator from rest_framework import serializers +from rest_framework.exceptions import NotAuthenticated from catalog.api.constants.licenses import LICENSE_GROUPS from catalog.api.controllers import search_controller from catalog.api.models.media import AbstractMedia from catalog.api.serializers.base import BaseModelSerializer from catalog.api.serializers.fields import SchemableHyperlinkedIdentityField -from catalog.api.utils.exceptions import get_api_exception from catalog.api.utils.help_text import make_comma_separated_help_text from catalog.api.utils.licenses import get_license_url from catalog.api.utils.url import add_protocol @@ -42,6 +45,7 @@ class MediaSearchRequestSerializer(serializers.Serializer): "mature", "qa", "page_size", + "page", ] """ Keep the fields names in sync with the actual fields below as this list is @@ -111,6 +115,16 @@ class MediaSearchRequestSerializer(serializers.Serializer): label="page_size", help_text="Number of results to return per page.", required=False, + default=settings.MAX_ANONYMOUS_PAGE_SIZE, + min_value=1, + ) + page = serializers.IntegerField( + label="page", + help_text="The page of results to retrieve.", + required=False, + default=1, + max_value=settings.MAX_PAGINATION_DEPTH, + min_value=1, ) @staticmethod @@ -161,10 +175,30 @@ def validate_title(self, value): def validate_page_size(self, value): request = self.context.get("request") is_anonymous = bool(request and request.user and request.user.is_anonymous) - if is_anonymous and value > 20: - raise get_api_exception( - "Page size must be between 1 & 20 for unauthenticated requests.", 401 - ) + max_value = ( + settings.MAX_ANONYMOUS_PAGE_SIZE + if is_anonymous + else settings.MAX_AUTHED_PAGE_SIZE + ) + + validator = MaxValueValidator( + max_value, + message=serializers.IntegerField.default_error_messages["max_value"].format( + max_value=max_value + ), + ) + + if is_anonymous: + try: + validator(value) + except ValidationError as e: + raise NotAuthenticated( + detail=e.message, + code=e.code, + ) + else: + validator(value) + return value @staticmethod diff --git a/api/catalog/api/utils/pagination.py b/api/catalog/api/utils/pagination.py index 271317a82..e438eb2b3 100644 --- a/api/catalog/api/utils/pagination.py +++ b/api/catalog/api/utils/pagination.py @@ -1,11 +1,7 @@ +from django.conf import settings from rest_framework.pagination import PageNumberPagination from rest_framework.response import Response -from catalog.api.utils.exceptions import get_api_exception - - -MAX_TOTAL_PAGE_COUNT = 20 - class StandardPagination(PageNumberPagination): page_size_query_param = "page_size" @@ -15,45 +11,13 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.result_count = None # populated later self.page_count = None # populated later - - self._page_size = 20 - self._page = None - - @property - def page_size(self): - """the number of results to show in one page""" - return self._page_size - - @page_size.setter - def page_size(self, value): - if value is None or not str(value).isnumeric(): - return - value = int(value) # convert str params to int - if value <= 0 or value > 500: - raise get_api_exception("Page size must be between 0 & 500.", 400) - self._page_size = value - - @property - def page(self): - """the current page number being served""" - return self._page - - @page.setter - def page(self, value): - if value is None or not str(value).isnumeric(): - value = 1 - value = int(value) # convert str params to int - if value <= 0: - raise get_api_exception("Page must be greater than 0.", 400) - elif value > 20: - raise get_api_exception("Searches are limited to 20 pages.", 400) - self._page = value + self.page = 1 # default, get's updated when necessary def get_paginated_response(self, data): return Response( { "result_count": self.result_count, - "page_count": self.page_count, + "page_count": min(settings.MAX_PAGINATION_DEPTH, self.page_count), "page_size": self.page_size, "page": self.page, "results": data, diff --git a/api/catalog/api/views/media_views.py b/api/catalog/api/views/media_views.py index de2cc536b..cfcd15b04 100644 --- a/api/catalog/api/views/media_views.py +++ b/api/catalog/api/views/media_views.py @@ -76,13 +76,11 @@ def _get_request_serializer(self, request): # Standard actions def list(self, request, *_, **__): - self.paginator.page_size = request.query_params.get("page_size") - page_size = self.paginator.page_size - self.paginator.page = request.query_params.get("page") - page = self.paginator.page - params = self._get_request_serializer(request) + page_size = self.paginator.page_size = params.data["page_size"] + page = self.paginator.page = params.data["page"] + hashed_ip = hash(self._get_user_ip(request)) qa = params.validated_data["qa"] filter_dead = params.validated_data["filter_dead"] diff --git a/api/catalog/settings.py b/api/catalog/settings.py index 47e52b384..472e3d539 100644 --- a/api/catalog/settings.py +++ b/api/catalog/settings.py @@ -372,3 +372,7 @@ # E.g. LINK_VALIDATION_CACHE_EXPIRY__200='{"days": 1}' will set the expiration time # for links with HTTP status 200 to 1 day LINK_VALIDATION_CACHE_EXPIRY_CONFIGURATION = LinkValidationCacheExpiryConfiguration() + +MAX_ANONYMOUS_PAGE_SIZE = 20 +MAX_AUTHED_PAGE_SIZE = 500 +MAX_PAGINATION_DEPTH = 20 diff --git a/api/test/auth_test.py b/api/test/auth_test.py index f457f9994..f35b73f2b 100644 --- a/api/test/auth_test.py +++ b/api/test/auth_test.py @@ -97,7 +97,7 @@ def test_auth_rate_limit_reporting( @pytest.mark.django_db -def test_pase_size_limit_unauthed(client): +def test_page_size_limit_unauthed(client): query_params = {"filter_dead": False, "page_size": 20} res = client.get("/v1/images/", query_params) assert res.status_code == 200 diff --git a/api/test/dead_link_filter_test.py b/api/test/dead_link_filter_test.py index 0d295de2b..91331075d 100644 --- a/api/test/dead_link_filter_test.py +++ b/api/test/dead_link_filter_test.py @@ -2,12 +2,13 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 +from django.conf import settings + import pytest import requests from fakeredis import FakeRedis from catalog.api.controllers.search_controller import DEAD_LINK_RATIO -from catalog.api.utils.pagination import MAX_TOTAL_PAGE_COUNT @pytest.fixture(autouse=True) @@ -202,7 +203,7 @@ def test_page_consistency_removing_dead_links(search_without_dead_links): Test the results returned in consecutive pages are never repeated when filtering out dead links. """ - total_pages = MAX_TOTAL_PAGE_COUNT + total_pages = settings.MAX_PAGINATION_DEPTH page_size = 5 page_results = [] @@ -226,6 +227,8 @@ def no_duplicates(xs): @pytest.mark.django_db def test_max_page_count(): response = requests.get( - f"{API_URL}/v1/images", params={"page": MAX_TOTAL_PAGE_COUNT + 1}, verify=False + f"{API_URL}/v1/images", + params={"page": settings.MAX_PAGINATION_DEPTH + 1}, + verify=False, ) assert response.status_code == 400 diff --git a/api/test/unit/controllers/test_search_controller.py b/api/test/unit/controllers/test_search_controller.py index 5bfcf980b..f2c29d025 100644 --- a/api/test/unit/controllers/test_search_controller.py +++ b/api/test/unit/controllers/test_search_controller.py @@ -10,7 +10,6 @@ from catalog.api.controllers import search_controller from catalog.api.utils.dead_link_mask import get_query_hash, save_query_mask -from catalog.api.utils.pagination import MAX_TOTAL_PAGE_COUNT @pytest.mark.parametrize( @@ -41,8 +40,6 @@ (20, 5, 5, (20, 5)), # Fewer hits than page size, but result list somehow differs, use that for count (48, 20, 50, (20, 0)), - # Page count gets truncated always - (5000, 10, 10, (5000, MAX_TOTAL_PAGE_COUNT)), ], ) def test_get_result_and_page_count(total_hits, real_result_count, page_size, expected): diff --git a/api/test/unit/serializers/media_serializers_test.py b/api/test/unit/serializers/media_serializers_test.py index bf722246d..b48ecfbf8 100644 --- a/api/test/unit/serializers/media_serializers_test.py +++ b/api/test/unit/serializers/media_serializers_test.py @@ -1,21 +1,34 @@ import uuid +from test.factory.models.oauth2 import AccessTokenFactory from unittest.mock import MagicMock -from rest_framework.request import Request -from rest_framework.test import APIRequestFactory +from django.conf import settings +from rest_framework.exceptions import NotAuthenticated, ValidationError +from rest_framework.test import APIRequestFactory, force_authenticate +from rest_framework.views import APIView import pytest from catalog.api.serializers.audio_serializers import AudioSerializer from catalog.api.serializers.image_serializers import ImageSerializer +from catalog.api.serializers.media_serializers import MediaSearchRequestSerializer +# TODO: @sarayourfriend consolidate these with the other +# request factory fixtures into conftest.py @pytest.fixture -def req(): - factory = APIRequestFactory() - request = factory.get("/") - request = Request(request) - return request +def request_factory() -> APIRequestFactory(): + request_factory = APIRequestFactory(defaults={"REMOTE_ADDR": "192.0.2.1"}) + + return request_factory + + +@pytest.fixture +def access_token(): + token = AccessTokenFactory.create() + token.application.verified = True + token.application.save() + return token @pytest.fixture @@ -28,6 +41,58 @@ def hit(): return hit +@pytest.fixture +def authed_request(access_token, request_factory): + request = request_factory.get("/") + + force_authenticate(request, token=access_token.token) + + return APIView().initialize_request(request) + + +@pytest.fixture +def anon_request(request_factory): + return APIView().initialize_request(request_factory.get("/")) + + +@pytest.mark.django_db +@pytest.mark.parametrize( + ("page_size", "authenticated"), + ( + pytest.param(-1, False, marks=pytest.mark.raises(exception=ValidationError)), + pytest.param(0, False, marks=pytest.mark.raises(exception=ValidationError)), + (1, False), + (settings.MAX_ANONYMOUS_PAGE_SIZE, False), + pytest.param( + settings.MAX_ANONYMOUS_PAGE_SIZE + 1, + False, + marks=pytest.mark.raises(exception=NotAuthenticated), + ), + pytest.param( + settings.MAX_AUTHED_PAGE_SIZE, + False, + marks=pytest.mark.raises(exception=NotAuthenticated), + ), + pytest.param(-1, True, marks=pytest.mark.raises(exception=ValidationError)), + pytest.param(0, True, marks=pytest.mark.raises(exception=ValidationError)), + (1, True), + (settings.MAX_ANONYMOUS_PAGE_SIZE + 1, True), + (settings.MAX_AUTHED_PAGE_SIZE, True), + pytest.param( + settings.MAX_AUTHED_PAGE_SIZE + 1, + True, + marks=pytest.mark.raises(exception=ValidationError), + ), + ), +) +def test_page_size_validation(page_size, authenticated, anon_request, authed_request): + request = authed_request if authenticated else anon_request + serializer = MediaSearchRequestSerializer( + context={"request": request}, data={"page_size": page_size} + ) + assert serializer.is_valid(raise_exception=True) + + @pytest.mark.parametrize( "serializer_class", [ @@ -35,10 +100,12 @@ def hit(): ImageSerializer, ], ) -def test_media_serializer_adds_license_url_if_missing(req, hit, serializer_class): +def test_media_serializer_adds_license_url_if_missing( + anon_request, hit, serializer_class +): # Note that this behaviour is inherited from the parent `MediaSerializer` class, but # it cannot be tested without a concrete model to test with. del hit.license_url # without the ``del``, the property is dynamically generated - repr = serializer_class(hit, context={"request": req}).data + repr = serializer_class(hit, context={"request": anon_request}).data assert repr["license_url"] == "https://creativecommons.org/publicdomain/zero/1.0/"