Skip to content

Commit

Permalink
Use api_client, move fixtures around to share basic defaults (#3560)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarayourfriend authored Dec 20, 2023
1 parent d1d370e commit b5eb751
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 160 deletions.
23 changes: 23 additions & 0 deletions api/test/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Fixtures usable by or necessary for both unit and integration tests."""

from test.fixtures.asynchronous import ensure_asgi_lifecycle, get_new_loop, session_loop
from test.fixtures.cache import (
django_cache,
redis,
unreachable_django_cache,
unreachable_redis,
)
from test.fixtures.rest_framework import api_client, request_factory


__all__ = [
"ensure_asgi_lifecycle",
"get_new_loop",
"session_loop",
"django_cache",
"redis",
"unreachable_django_cache",
"unreachable_redis",
"api_client",
"request_factory",
]
15 changes: 15 additions & 0 deletions api/test/fixtures/rest_framework.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from rest_framework.test import APIClient, APIRequestFactory

import pytest


@pytest.fixture
def api_client():
return APIClient()


@pytest.fixture
def request_factory() -> APIRequestFactory():
request_factory = APIRequestFactory(defaults={"REMOTE_ADDR": "192.0.2.1"})

return request_factory
18 changes: 0 additions & 18 deletions api/test/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
from test.fixtures.asynchronous import ensure_asgi_lifecycle, session_loop
from test.fixtures.cache import (
django_cache,
redis,
unreachable_django_cache,
unreachable_redis,
)

import pytest


Expand All @@ -19,13 +11,3 @@ def django_db_setup():
"""

pass


__all__ = [
"ensure_asgi_lifecycle",
"session_loop",
"django_cache",
"redis",
"unreachable_django_cache",
"unreachable_redis",
]
8 changes: 4 additions & 4 deletions api/test/integration/test_audio_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@
pytestmark = pytest.mark.django_db


def test_audio_detail_without_thumb(client):
resp = client.get("/v1/audio/44540200-91eb-483d-9e99-38ce86a52fb6/")
def test_audio_detail_without_thumb(api_client):
resp = api_client.get("/v1/audio/44540200-91eb-483d-9e99-38ce86a52fb6/")
assert resp.status_code == 200
parsed = resp.json()
assert parsed["thumbnail"] is None


def test_audio_search_without_thumb(client):
def test_audio_search_without_thumb(api_client):
"""The first audio of this search should not have a thumbnail."""
resp = client.get("/v1/audio/?q=zaus")
resp = api_client.get("/v1/audio/?q=zaus")
assert resp.status_code == 200
parsed = resp.json()
assert parsed["results"][0]["thumbnail"] is None
84 changes: 45 additions & 39 deletions api/test/integration/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from test.constants import API_URL

from django.urls import reverse
from django.utils.http import urlencode

import pytest
from oauth2_provider.models import AccessToken
Expand Down Expand Up @@ -37,13 +36,13 @@ def unreachable_oauth_cache(unreachable_django_cache, monkeypatch):

@pytest.mark.django_db
@pytest.fixture
def test_auth_tokens_registration(client):
def test_auth_tokens_registration(api_client):
data = {
"name": f"INTEGRATION TEST APPLICATION {uuid.uuid4()}",
"description": "A key for testing the OAuth2 registration process.",
"email": "[email protected]",
}
res = client.post(
res = api_client.post(
"/v1/auth_tokens/register/",
data,
verify=False,
Expand All @@ -55,20 +54,19 @@ def test_auth_tokens_registration(client):

@pytest.mark.django_db
@pytest.fixture
def test_auth_token_exchange(client, test_auth_tokens_registration):
client_id = test_auth_tokens_registration["client_id"]
client_secret = test_auth_tokens_registration["client_secret"]
data = urlencode(
{
"client_id": client_id,
"client_secret": client_secret,
"grant_type": "client_credentials",
}
)
res = client.post(
def test_auth_token_exchange(api_client, test_auth_tokens_registration):
api_client_id = test_auth_tokens_registration["client_id"]
api_client_secret = test_auth_tokens_registration["client_secret"]
data = {
"client_id": api_client_id,
"client_secret": api_client_secret,
"grant_type": "client_credentials",
}

res = api_client.post(
"/v1/auth_tokens/token/",
data,
"application/x-www-form-urlencoded",
"multipart",
verify=False,
)
res_data = res.json()
Expand All @@ -77,20 +75,20 @@ def test_auth_token_exchange(client, test_auth_tokens_registration):


@pytest.mark.django_db
def test_auth_token_exchange_unsupported_method(client):
res = client.get(
def test_auth_token_exchange_unsupported_method(api_client):
res = api_client.get(
"/v1/auth_tokens/token/",
verify=False,
)
assert res.status_code == 405
assert res.json()["detail"] == 'Method "GET" not allowed.'


def _integration_verify_most_recent_token(client):
def _integration_verify_most_recent_token(api_client):
verify = OAuth2Verification.objects.last()
code = verify.code
path = reverse("verify-email", args=[code])
return client.get(path)
return api_client.get(path)


@pytest.mark.django_db
Expand All @@ -109,17 +107,17 @@ def _integration_verify_most_recent_token(client):
)
def test_auth_email_verification(
request,
client,
api_client,
is_cache_reachable,
cache_name,
rate_limit_model,
test_auth_token_exchange,
):
res = _integration_verify_most_recent_token(client)
res = _integration_verify_most_recent_token(api_client)
assert res.status_code == 200
test_auth_rate_limit_reporting(
request,
client,
api_client,
is_cache_reachable,
cache_name,
rate_limit_model,
Expand All @@ -136,7 +134,7 @@ def test_auth_email_verification(
@cache_availability_params
def test_auth_rate_limit_reporting(
request,
client,
api_client,
is_cache_reachable,
cache_name,
rate_limit_model,
Expand All @@ -152,7 +150,7 @@ def test_auth_rate_limit_reporting(
application = AccessToken.objects.get(token=token).application
application.rate_limit_model = rate_limit_model
application.save()
res = client.get("/v1/rate_limit/", HTTP_AUTHORIZATION=f"Bearer {token}")
res = api_client.get("/v1/rate_limit/", HTTP_AUTHORIZATION=f"Bearer {token}")
res_data = res.json()
if is_cache_reachable:
assert res.status_code == 200
Expand All @@ -175,14 +173,14 @@ def test_auth_rate_limit_reporting(
(True, False),
)
def test_auth_response_headers(
client, verified, test_auth_tokens_registration, test_auth_token_exchange
api_client, verified, test_auth_tokens_registration, test_auth_token_exchange
):
if verified:
_integration_verify_most_recent_token(client)
_integration_verify_most_recent_token(api_client)

token = test_auth_token_exchange["access_token"]

res = client.get("/v1/images/", HTTP_AUTHORIZATION=f"Bearer {token}")
res = api_client.get("/v1/images/", HTTP_AUTHORIZATION=f"Bearer {token}")

assert (
res.headers["x-ov-client-application-name"]
Expand All @@ -191,8 +189,8 @@ def test_auth_response_headers(
assert res.headers["x-ov-client-application-verified"] == str(verified)


def test_unauthed_response_headers(client):
res = client.get("/v1/images")
def test_unauthed_response_headers(api_client):
res = api_client.get("/v1/images")

assert "x-ov-client-application-name" not in res.headers
assert "x-ov-client-application-verified" not in res.headers
Expand All @@ -207,15 +205,17 @@ def test_unauthed_response_headers(client):
],
)
def test_sorting_authed(
client, monkeypatch, test_auth_token_exchange, sort_dir, exp_indexed_on
api_client, monkeypatch, test_auth_token_exchange, sort_dir, exp_indexed_on
):
# Prevent DB lookup for ES results because DB is empty.
monkeypatch.setattr("api.views.image_views.ImageSerializer.needs_db", False)

time.sleep(1)
token = test_auth_token_exchange["access_token"]
query_params = {"unstable__sort_by": "indexed_on", "unstable__sort_dir": sort_dir}
res = client.get("/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}")
res = api_client.get(
"/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}"
)
assert res.status_code == 200

res_data = res.json()
Expand All @@ -232,7 +232,7 @@ def test_sorting_authed(
],
)
def test_authority_authed(
client, monkeypatch, test_auth_token_exchange, authority_boost, exp_source
api_client, monkeypatch, test_auth_token_exchange, authority_boost, exp_source
):
# Prevent DB lookup for ES results because DB is empty.
monkeypatch.setattr("api.views.image_views.ImageSerializer.needs_db", False)
Expand All @@ -244,7 +244,9 @@ def test_authority_authed(
"unstable__authority": "true",
"unstable__authority_boost": authority_boost,
}
res = client.get("/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}")
res = api_client.get(
"/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}"
)
assert res.status_code == 200

res_data = res.json()
Expand All @@ -253,23 +255,27 @@ def test_authority_authed(


@pytest.mark.django_db
def test_page_size_limit_unauthed(client):
def test_page_size_limit_unauthed(api_client):
query_params = {"page_size": 20}
res = client.get("/v1/images/", query_params)
res = api_client.get("/v1/images/", query_params)
assert res.status_code == 200
query_params["page_size"] = 21
res = client.get("/v1/images/", query_params)
res = api_client.get("/v1/images/", query_params)
assert res.status_code == 401


@pytest.mark.django_db
def test_page_size_limit_authed(client, test_auth_token_exchange):
def test_page_size_limit_authed(api_client, test_auth_token_exchange):
time.sleep(1)
token = test_auth_token_exchange["access_token"]
query_params = {"page_size": 21}
res = client.get("/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}")
res = api_client.get(
"/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}"
)
assert res.status_code == 200

query_params = {"page_size": 500}
res = client.get("/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}")
res = api_client.get(
"/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}"
)
assert res.status_code == 200
20 changes: 11 additions & 9 deletions api/test/integration/test_dead_link_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,20 @@ def _make_head_requests(urls):

@pytest.mark.django_db
@_patch_make_head_requests()
def test_dead_link_filtering(mocked_map, client):
def test_dead_link_filtering(mocked_map, api_client):
path = "/v1/images/"
query_params = {"q": "*", "page_size": 20}

# Make a request that does not filter dead links...
res_with_dead_links = client.get(
res_with_dead_links = api_client.get(
path,
query_params | {"filter_dead": False},
)
# ...and ensure that our patched function was not called
mocked_map.assert_not_called()

# Make a request that filters dead links...
res_without_dead_links = client.get(
res_without_dead_links = api_client.get(
path,
query_params | {"filter_dead": True},
)
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_dead_link_filtering(mocked_map, client):
),
)
def test_dead_link_filtering_all_dead_links(
client,
api_client,
filter_dead,
page_size,
expected_result_count,
Expand All @@ -130,7 +130,7 @@ def test_dead_link_filtering_all_dead_links(
query_params = {"q": "*", "page_size": page_size}

with patch_link_validation_dead_for_count(page_size / DEAD_LINK_RATIO):
response = client.get(
response = api_client.get(
path,
query_params | {"filter_dead": filter_dead},
)
Expand All @@ -145,11 +145,11 @@ def test_dead_link_filtering_all_dead_links(


@pytest.fixture
def search_factory(client):
def search_factory(api_client):
"""Allow passing url parameters along with a search request."""

def _parameterized_search(**kwargs):
response = client.get("/v1/images/", kwargs)
response = api_client.get("/v1/images/", kwargs)
assert response.status_code == 200
parsed = response.json()
return parsed
Expand Down Expand Up @@ -207,6 +207,8 @@ def no_duplicates(xs):


@pytest.mark.django_db
def test_max_page_count(client):
response = client.get("/v1/images/", {"page": settings.MAX_PAGINATION_DEPTH + 1})
def test_max_page_count(api_client):
response = api_client.get(
"/v1/images/", {"page": settings.MAX_PAGINATION_DEPTH + 1}
)
assert response.status_code == 400
Loading

0 comments on commit b5eb751

Please sign in to comment.