Skip to content
This repository has been archived by the owner on Feb 22, 2023. It is now read-only.

Move page and page_size query param validation into serializer #868

Merged
merged 9 commits into from
Sep 16, 2022
8 changes: 3 additions & 5 deletions api/catalog/api/controllers/search_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import catalog.api.models as models
from catalog.api.utils.dead_link_mask import get_query_hash, get_query_mask
from catalog.api.utils.pagination import MAX_TOTAL_PAGE_COUNT
from catalog.api.utils.validate_images import validate_images


Expand Down Expand Up @@ -425,10 +424,9 @@ def _get_result_and_page_count(
:return: Result and page count.
"""
result_count = response_obj.hits.total.value
natural_page_count = int(result_count / page_size)
if natural_page_count % page_size != 0:
natural_page_count += 1
page_count = min(natural_page_count, MAX_TOTAL_PAGE_COUNT)
zackkrida marked this conversation as resolved.
Show resolved Hide resolved
page_count = int(result_count / page_size)
if page_count % page_size != 0:
page_count += 1
if len(results) < page_size and page_count == 0:
result_count = len(results)

Expand Down
44 changes: 39 additions & 5 deletions api/catalog/api/serializers/media_serializers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from collections import namedtuple

from django.conf import settings
from django.core.exceptions import ValidationError
from django.core.validators import MaxValueValidator
from rest_framework import serializers
from rest_framework.exceptions import NotAuthenticated

from catalog.api.constants.licenses import LICENSE_GROUPS
from catalog.api.controllers import search_controller
from catalog.api.models.media import AbstractMedia
from catalog.api.serializers.base import BaseModelSerializer
from catalog.api.serializers.fields import SchemableHyperlinkedIdentityField
from catalog.api.utils.exceptions import get_api_exception
from catalog.api.utils.help_text import make_comma_separated_help_text
from catalog.api.utils.url import add_protocol

Expand Down Expand Up @@ -41,6 +44,7 @@ class MediaSearchRequestSerializer(serializers.Serializer):
"mature",
"qa",
"page_size",
"page",
]
"""
Keep the fields names in sync with the actual fields below as this list is
Expand Down Expand Up @@ -110,6 +114,16 @@ class MediaSearchRequestSerializer(serializers.Serializer):
label="page_size",
help_text="Number of results to return per page.",
required=False,
default=settings.MAX_ANONYMOUS_PAGE_SIZE,
min_value=1,
)
page = serializers.IntegerField(
    label="page",
    help_text="The page of results to retrieve.",
    required=False,
    default=1,
    # Bound the page number with the shared pagination-depth setting rather
    # than a hard-coded literal, so this limit cannot drift from the value
    # used to cap ``page_count`` in paginated responses.
    max_value=settings.MAX_PAGINATION_DEPTH,
    min_value=1,
)

@staticmethod
Expand Down Expand Up @@ -160,10 +174,30 @@ def validate_title(self, value):
def validate_page_size(self, value):
request = self.context.get("request")
is_anonymous = bool(request and request.user and request.user.is_anonymous)
if is_anonymous and value > 20:
raise get_api_exception(
"Page size must be between 1 & 20 for unauthenticated requests.", 401
)
max_value = (
settings.MAX_ANONYMOUS_PAGE_SIZE
if is_anonymous
else settings.MAX_AUTHED_PAGE_SIZE
)

validator = MaxValueValidator(
max_value,
message=serializers.IntegerField.default_error_messages["max_value"].format(
max_value=max_value
),
)
zackkrida marked this conversation as resolved.
Show resolved Hide resolved

if is_anonymous:
try:
validator(value)
except ValidationError as e:
raise NotAuthenticated(
detail=e.message,
code=e.code,
)
else:
validator(value)

return value

@staticmethod
Expand Down
42 changes: 3 additions & 39 deletions api/catalog/api/utils/pagination.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from django.conf import settings
from rest_framework.pagination import PageNumberPagination
from rest_framework.response import Response

from catalog.api.utils.exceptions import get_api_exception


MAX_TOTAL_PAGE_COUNT = 20


class StandardPagination(PageNumberPagination):
page_size_query_param = "page_size"
Expand All @@ -15,45 +11,13 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.result_count = None # populated later
self.page_count = None # populated later

self._page_size = 20
self._page = None

@property
def page_size(self):
"""the number of results to show in one page"""
return self._page_size

@page_size.setter
def page_size(self, value):
if value is None or not str(value).isnumeric():
return
value = int(value) # convert str params to int
if value <= 0 or value > 500:
raise get_api_exception("Page size must be between 0 & 500.", 400)
self._page_size = value

@property
def page(self):
"""the current page number being served"""
return self._page

@page.setter
def page(self, value):
if value is None or not str(value).isnumeric():
value = 1
value = int(value) # convert str params to int
if value <= 0:
raise get_api_exception("Page must be greater than 0.", 400)
elif value > 20:
raise get_api_exception("Searches are limited to 20 pages.", 400)
self._page = value
self.page = 1 # default, gets updated when necessary

def get_paginated_response(self, data):
return Response(
{
"result_count": self.result_count,
"page_count": self.page_count,
"page_count": min(settings.MAX_PAGINATION_DEPTH, self.page_count),
"page_size": self.page_size,
"page": self.page,
"results": data,
Expand Down
8 changes: 3 additions & 5 deletions api/catalog/api/views/media_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,14 @@ def get_queryset(self):
# Standard actions

def list(self, request, *_, **__):
self.paginator.page_size = request.query_params.get("page_size")
page_size = self.paginator.page_size
self.paginator.page = request.query_params.get("page")
page = self.paginator.page

params = self.query_serializer_class(
data=request.query_params, context={"request": request}
)
params.is_valid(raise_exception=True)

page_size = self.paginator.page_size = params.data["page_size"]
page = self.paginator.page = params.data["page"]

hashed_ip = hash(self._get_user_ip(request))
qa = params.validated_data["qa"]
filter_dead = params.validated_data["filter_dead"]
Expand Down
4 changes: 4 additions & 0 deletions api/catalog/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,7 @@
# from pushing un-actionable alerts to Sentry like
# https://sentry.io/share/issue/9af3cdf8ef74420aa7bbb6697760a82c/
ignore_logger("django.security.DisallowedHost")

# Largest ``page_size`` an anonymous (unauthenticated) request may ask for;
# the media search request serializer also uses it as the default ``page_size``.
MAX_ANONYMOUS_PAGE_SIZE = 20
# Largest ``page_size`` an authenticated request may ask for.
MAX_AUTHED_PAGE_SIZE = 500
# Upper bound applied to the ``page_count`` reported in paginated responses.
MAX_PAGINATION_DEPTH = 20
4 changes: 2 additions & 2 deletions api/test/auth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ def test_auth_rate_limit_reporting(


@pytest.mark.django_db
def test_pase_size_limit_unauthed(client):
def test_page_size_limit_unauthed(client):
query_params = {"filter_dead": False, "page_size": 20}
res = client.get("/v1/images/", query_params)
assert res.status_code == 200
query_params["page_size"] = 21
res = client.get("/v1/images/", query_params)
assert res.status_code == 401
assert res.status_code == 400


@pytest.mark.django_db
Expand Down
10 changes: 6 additions & 4 deletions api/test/dead_link_filter_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from test.constants import API_URL
from unittest.mock import MagicMock, patch

from django.conf import settings

import pytest
import requests

from catalog.api.utils.pagination import MAX_TOTAL_PAGE_COUNT


def _patch_redis():
def redis_mget(keys, *_, **__):
Expand Down Expand Up @@ -117,7 +117,7 @@ def test_page_consistency_removing_dead_links(search_without_dead_links):
Test the results returned in consecutive pages are never repeated when
filtering out dead links.
"""
total_pages = MAX_TOTAL_PAGE_COUNT
total_pages = settings.MAX_PAGINATION_DEPTH
page_size = 5

page_results = []
Expand All @@ -141,6 +141,8 @@ def no_duplicates(xs):
@pytest.mark.django_db
def test_max_page_count():
response = requests.get(
f"{API_URL}/v1/images", params={"page": MAX_TOTAL_PAGE_COUNT + 1}, verify=False
f"{API_URL}/v1/images",
params={"page": settings.MAX_PAGINATION_DEPTH + 1},
verify=False,
)
assert response.status_code == 400
4 changes: 1 addition & 3 deletions api/test/unit/controllers/test_search_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pytest

from catalog.api.controllers import search_controller
from catalog.api.utils.pagination import MAX_TOTAL_PAGE_COUNT


@pytest.mark.parametrize(
Expand Down Expand Up @@ -34,8 +33,7 @@
(20, 5, 5, (20, 5)),
# Fewer hits than page size, but result list somehow differs, use that for count
(48, 20, 50, (20, 0)),
# Page count gets truncated always
(5000, 10, 10, (5000, MAX_TOTAL_PAGE_COUNT)),
(5000, 10, 10, (5000, 5000 / 10)),
zackkrida marked this conversation as resolved.
Show resolved Hide resolved
krysal marked this conversation as resolved.
Show resolved Hide resolved
],
)
def test_get_result_and_page_count(total_hits, real_result_count, page_size, expected):
Expand Down
79 changes: 79 additions & 0 deletions api/test/unit/serializers/media_serializers_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from test.factory.models.oauth2 import AccessTokenFactory

from django.conf import settings
from rest_framework.exceptions import NotAuthenticated, ValidationError
from rest_framework.test import APIRequestFactory, force_authenticate
from rest_framework.views import APIView

import pytest

from catalog.api.serializers.media_serializers import MediaSearchRequestSerializer


# TODO: @sarayourfriend consolidate these with the other
# request factory fixtures into conftest.py
@pytest.fixture
def request_factory() -> APIRequestFactory:
    """Provide an ``APIRequestFactory`` with a fixed default client address.

    Every request built by the returned factory carries
    ``REMOTE_ADDR = "192.0.2.1"`` unless overridden per request.

    :return: the configured request factory
    """
    # The return annotation names the class itself; annotating with
    # ``APIRequestFactory()`` would build a throwaway instance at import time.
    return APIRequestFactory(defaults={"REMOTE_ADDR": "192.0.2.1"})


@pytest.fixture
def access_token():
    """Create an access token and save its application as verified."""
    created = AccessTokenFactory.create()
    app = created.application
    app.verified = True
    app.save()
    return created


@pytest.fixture
def authed_request(access_token, request_factory):
    """Build a DRF request for ``/`` force-authenticated with the access token."""
    raw_request = request_factory.get("/")
    force_authenticate(raw_request, token=access_token.token)
    # Wrap the plain request the same way an ``APIView`` would before dispatch.
    return APIView().initialize_request(raw_request)


@pytest.fixture
def anon_request(request_factory):
    """Build a DRF request for ``/`` with no forced authentication."""
    raw_request = request_factory.get("/")
    return APIView().initialize_request(raw_request)


@pytest.mark.django_db
@pytest.mark.parametrize(
    "page_size, authenticated",
    [
        # Anonymous requests: non-positive sizes fail validation, sizes above
        # the anonymous maximum raise ``NotAuthenticated``.
        pytest.param(-1, False, marks=pytest.mark.raises(exception=ValidationError)),
        pytest.param(0, False, marks=pytest.mark.raises(exception=ValidationError)),
        pytest.param(1, False),
        pytest.param(settings.MAX_ANONYMOUS_PAGE_SIZE, False),
        pytest.param(
            settings.MAX_ANONYMOUS_PAGE_SIZE + 1,
            False,
            marks=pytest.mark.raises(exception=NotAuthenticated),
        ),
        pytest.param(
            settings.MAX_AUTHED_PAGE_SIZE,
            False,
            marks=pytest.mark.raises(exception=NotAuthenticated),
        ),
        # Authenticated requests: anything from 1 up to the authed maximum is
        # accepted; everything else fails validation.
        pytest.param(-1, True, marks=pytest.mark.raises(exception=ValidationError)),
        pytest.param(0, True, marks=pytest.mark.raises(exception=ValidationError)),
        pytest.param(1, True),
        pytest.param(settings.MAX_ANONYMOUS_PAGE_SIZE + 1, True),
        pytest.param(settings.MAX_AUTHED_PAGE_SIZE, True),
        pytest.param(
            settings.MAX_AUTHED_PAGE_SIZE + 1,
            True,
            marks=pytest.mark.raises(exception=ValidationError),
        ),
    ],
)
def test_page_size_validation(page_size, authenticated, anon_request, authed_request):
    """Check ``page_size`` validation for anonymous and authenticated requests.

    Cases without a ``raises`` mark must validate cleanly; marked cases must
    raise the given exception from ``is_valid(raise_exception=True)``.
    """
    if authenticated:
        request = authed_request
    else:
        request = anon_request

    serializer = MediaSearchRequestSerializer(
        data={"page_size": page_size},
        context={"request": request},
    )
    serializer.is_valid(raise_exception=True)