Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use api_client, move fixtures around to share basic defaults #3560

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions api/test/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Fixtures usable by or necessary for both unit and integration tests."""

from test.fixtures.asynchronous import ensure_asgi_lifecycle, get_new_loop, session_loop
from test.fixtures.cache import (
django_cache,
redis,
unreachable_django_cache,
unreachable_redis,
)
from test.fixtures.rest_framework import api_client, request_factory


__all__ = [
"ensure_asgi_lifecycle",
"get_new_loop",
"session_loop",
"django_cache",
"redis",
"unreachable_django_cache",
"unreachable_redis",
"api_client",
"request_factory",
]
15 changes: 15 additions & 0 deletions api/test/fixtures/rest_framework.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from rest_framework.test import APIClient, APIRequestFactory

import pytest


@pytest.fixture
def api_client():
return APIClient()


@pytest.fixture
def request_factory() -> APIRequestFactory():
request_factory = APIRequestFactory(defaults={"REMOTE_ADDR": "192.0.2.1"})

return request_factory
18 changes: 0 additions & 18 deletions api/test/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
from test.fixtures.asynchronous import ensure_asgi_lifecycle, session_loop
from test.fixtures.cache import (
django_cache,
redis,
unreachable_django_cache,
unreachable_redis,
)

import pytest


Expand All @@ -19,13 +11,3 @@ def django_db_setup():
"""

pass


__all__ = [
"ensure_asgi_lifecycle",
"session_loop",
"django_cache",
"redis",
"unreachable_django_cache",
"unreachable_redis",
]
8 changes: 4 additions & 4 deletions api/test/integration/test_audio_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@
pytestmark = pytest.mark.django_db


def test_audio_detail_without_thumb(client):
resp = client.get("/v1/audio/44540200-91eb-483d-9e99-38ce86a52fb6/")
def test_audio_detail_without_thumb(api_client):
resp = api_client.get("/v1/audio/44540200-91eb-483d-9e99-38ce86a52fb6/")
assert resp.status_code == 200
parsed = resp.json()
assert parsed["thumbnail"] is None


def test_audio_search_without_thumb(client):
def test_audio_search_without_thumb(api_client):
"""The first audio of this search should not have a thumbnail."""
resp = client.get("/v1/audio/?q=zaus")
resp = api_client.get("/v1/audio/?q=zaus")
assert resp.status_code == 200
parsed = resp.json()
assert parsed["results"][0]["thumbnail"] is None
84 changes: 45 additions & 39 deletions api/test/integration/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from test.constants import API_URL

from django.urls import reverse
from django.utils.http import urlencode

import pytest
from oauth2_provider.models import AccessToken
Expand Down Expand Up @@ -36,13 +35,13 @@ def unreachable_oauth_cache(unreachable_django_cache, monkeypatch):

@pytest.mark.django_db
@pytest.fixture
def test_auth_tokens_registration(client):
def test_auth_tokens_registration(api_client):
data = {
"name": f"INTEGRATION TEST APPLICATION {uuid.uuid4()}",
"description": "A key for testing the OAuth2 registration process.",
"email": "[email protected]",
}
res = client.post(
res = api_client.post(
"/v1/auth_tokens/register/",
data,
verify=False,
Expand All @@ -54,20 +53,19 @@ def test_auth_tokens_registration(client):

@pytest.mark.django_db
@pytest.fixture
def test_auth_token_exchange(client, test_auth_tokens_registration):
client_id = test_auth_tokens_registration["client_id"]
client_secret = test_auth_tokens_registration["client_secret"]
data = urlencode(
{
"client_id": client_id,
"client_secret": client_secret,
"grant_type": "client_credentials",
}
)
res = client.post(
def test_auth_token_exchange(api_client, test_auth_tokens_registration):
api_client_id = test_auth_tokens_registration["client_id"]
api_client_secret = test_auth_tokens_registration["client_secret"]
data = {
"client_id": api_client_id,
"client_secret": api_client_secret,
"grant_type": "client_credentials",
}

res = api_client.post(
"/v1/auth_tokens/token/",
data,
"application/x-www-form-urlencoded",
"multipart",
Comment on lines -70 to +68
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reason for this change?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's part of DRF's api_client's testing request renderer formats. It takes a slug corresponding to the renderer rather than the content type string: https://www.django-rest-framework.org/api-guide/testing/#setting-the-available-formats

Here's an SO answer explaining it as well: https://stackoverflow.com/questions/68528408/difference-between-content-type-application-json-and-format-json-in-drf-apiclien

DRF's testing APIClient actually gave a pretty easy to understand error message about this, I was not aware of this until fixing that error when I changed this test to api_client.

verify=False,
)
res_data = res.json()
Expand All @@ -76,20 +74,20 @@ def test_auth_token_exchange(client, test_auth_tokens_registration):


@pytest.mark.django_db
def test_auth_token_exchange_unsupported_method(client):
res = client.get(
def test_auth_token_exchange_unsupported_method(api_client):
res = api_client.get(
"/v1/auth_tokens/token/",
verify=False,
)
assert res.status_code == 405
assert res.json()["detail"] == 'Method "GET" not allowed.'


def _integration_verify_most_recent_token(client):
def _integration_verify_most_recent_token(api_client):
verify = OAuth2Verification.objects.last()
code = verify.code
path = reverse("verify-email", args=[code])
return client.get(path)
return api_client.get(path)


@pytest.mark.django_db
Expand All @@ -108,17 +106,17 @@ def _integration_verify_most_recent_token(client):
)
def test_auth_email_verification(
request,
client,
api_client,
is_cache_reachable,
cache_name,
rate_limit_model,
test_auth_token_exchange,
):
res = _integration_verify_most_recent_token(client)
res = _integration_verify_most_recent_token(api_client)
assert res.status_code == 200
test_auth_rate_limit_reporting(
request,
client,
api_client,
is_cache_reachable,
cache_name,
rate_limit_model,
Expand All @@ -135,7 +133,7 @@ def test_auth_email_verification(
@cache_availability_params
def test_auth_rate_limit_reporting(
request,
client,
api_client,
is_cache_reachable,
cache_name,
rate_limit_model,
Expand All @@ -151,7 +149,7 @@ def test_auth_rate_limit_reporting(
application = AccessToken.objects.get(token=token).application
application.rate_limit_model = rate_limit_model
application.save()
res = client.get("/v1/rate_limit/", HTTP_AUTHORIZATION=f"Bearer {token}")
res = api_client.get("/v1/rate_limit/", HTTP_AUTHORIZATION=f"Bearer {token}")
res_data = res.json()
if is_cache_reachable:
assert res.status_code == 200
Expand All @@ -174,14 +172,14 @@ def test_auth_rate_limit_reporting(
(True, False),
)
def test_auth_response_headers(
client, verified, test_auth_tokens_registration, test_auth_token_exchange
api_client, verified, test_auth_tokens_registration, test_auth_token_exchange
):
if verified:
_integration_verify_most_recent_token(client)
_integration_verify_most_recent_token(api_client)

token = test_auth_token_exchange["access_token"]

res = client.get("/v1/images/", HTTP_AUTHORIZATION=f"Bearer {token}")
res = api_client.get("/v1/images/", HTTP_AUTHORIZATION=f"Bearer {token}")

assert (
res.headers["x-ov-client-application-name"]
Expand All @@ -190,8 +188,8 @@ def test_auth_response_headers(
assert res.headers["x-ov-client-application-verified"] == str(verified)


def test_unauthed_response_headers(client):
res = client.get("/v1/images")
def test_unauthed_response_headers(api_client):
res = api_client.get("/v1/images")

assert "x-ov-client-application-name" not in res.headers
assert "x-ov-client-application-verified" not in res.headers
Expand All @@ -206,15 +204,17 @@ def test_unauthed_response_headers(client):
],
)
def test_sorting_authed(
client, monkeypatch, test_auth_token_exchange, sort_dir, exp_indexed_on
api_client, monkeypatch, test_auth_token_exchange, sort_dir, exp_indexed_on
):
# Prevent DB lookup for ES results because DB is empty.
monkeypatch.setattr("api.views.image_views.ImageSerializer.needs_db", False)

time.sleep(1)
token = test_auth_token_exchange["access_token"]
query_params = {"unstable__sort_by": "indexed_on", "unstable__sort_dir": sort_dir}
res = client.get("/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}")
res = api_client.get(
"/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}"
)
assert res.status_code == 200

res_data = res.json()
Expand All @@ -231,7 +231,7 @@ def test_sorting_authed(
],
)
def test_authority_authed(
client, monkeypatch, test_auth_token_exchange, authority_boost, exp_source
api_client, monkeypatch, test_auth_token_exchange, authority_boost, exp_source
):
# Prevent DB lookup for ES results because DB is empty.
monkeypatch.setattr("api.views.image_views.ImageSerializer.needs_db", False)
Expand All @@ -243,7 +243,9 @@ def test_authority_authed(
"unstable__authority": "true",
"unstable__authority_boost": authority_boost,
}
res = client.get("/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}")
res = api_client.get(
"/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}"
)
assert res.status_code == 200

res_data = res.json()
Expand All @@ -252,23 +254,27 @@ def test_authority_authed(


@pytest.mark.django_db
def test_page_size_limit_unauthed(client):
def test_page_size_limit_unauthed(api_client):
query_params = {"page_size": 20}
res = client.get("/v1/images/", query_params)
res = api_client.get("/v1/images/", query_params)
assert res.status_code == 200
query_params["page_size"] = 21
res = client.get("/v1/images/", query_params)
res = api_client.get("/v1/images/", query_params)
assert res.status_code == 401


@pytest.mark.django_db
def test_page_size_limit_authed(client, test_auth_token_exchange):
def test_page_size_limit_authed(api_client, test_auth_token_exchange):
time.sleep(1)
token = test_auth_token_exchange["access_token"]
query_params = {"page_size": 21}
res = client.get("/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}")
res = api_client.get(
"/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}"
)
assert res.status_code == 200

query_params = {"page_size": 500}
res = client.get("/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}")
res = api_client.get(
"/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}"
)
assert res.status_code == 200
20 changes: 11 additions & 9 deletions api/test/integration/test_dead_link_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,20 @@ def _make_head_requests(urls):

@pytest.mark.django_db
@_patch_make_head_requests()
def test_dead_link_filtering(mocked_map, client):
def test_dead_link_filtering(mocked_map, api_client):
path = "/v1/images/"
query_params = {"q": "*", "page_size": 20}

# Make a request that does not filter dead links...
res_with_dead_links = client.get(
res_with_dead_links = api_client.get(
path,
query_params | {"filter_dead": False},
)
# ...and ensure that our patched function was not called
mocked_map.assert_not_called()

# Make a request that filters dead links...
res_without_dead_links = client.get(
res_without_dead_links = api_client.get(
path,
query_params | {"filter_dead": True},
)
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_dead_link_filtering(mocked_map, client):
),
)
def test_dead_link_filtering_all_dead_links(
client,
api_client,
filter_dead,
page_size,
expected_result_count,
Expand All @@ -130,7 +130,7 @@ def test_dead_link_filtering_all_dead_links(
query_params = {"q": "*", "page_size": page_size}

with patch_link_validation_dead_for_count(page_size / DEAD_LINK_RATIO):
response = client.get(
response = api_client.get(
path,
query_params | {"filter_dead": filter_dead},
)
Expand All @@ -145,11 +145,11 @@ def test_dead_link_filtering_all_dead_links(


@pytest.fixture
def search_factory(client):
def search_factory(api_client):
"""Allow passing url parameters along with a search request."""

def _parameterized_search(**kwargs):
response = client.get("/v1/images/", kwargs)
response = api_client.get("/v1/images/", kwargs)
assert response.status_code == 200
parsed = response.json()
return parsed
Expand Down Expand Up @@ -207,6 +207,8 @@ def no_duplicates(xs):


@pytest.mark.django_db
def test_max_page_count(client):
response = client.get("/v1/images/", {"page": settings.MAX_PAGINATION_DEPTH + 1})
def test_max_page_count(api_client):
response = api_client.get(
"/v1/images/", {"page": settings.MAX_PAGINATION_DEPTH + 1}
)
assert response.status_code == 400
Loading
Loading