diff --git a/api/test/conftest.py b/api/test/conftest.py new file mode 100644 index 00000000000..1aecc9b86c7 --- /dev/null +++ b/api/test/conftest.py @@ -0,0 +1,23 @@ +"""Fixtures usable by or necessary for both unit and integration tests.""" + +from test.fixtures.asynchronous import ensure_asgi_lifecycle, get_new_loop, session_loop +from test.fixtures.cache import ( + django_cache, + redis, + unreachable_django_cache, + unreachable_redis, +) +from test.fixtures.rest_framework import api_client, request_factory + + +__all__ = [ + "ensure_asgi_lifecycle", + "get_new_loop", + "session_loop", + "django_cache", + "redis", + "unreachable_django_cache", + "unreachable_redis", + "api_client", + "request_factory", +] diff --git a/api/test/fixtures/rest_framework.py b/api/test/fixtures/rest_framework.py new file mode 100644 index 00000000000..3359b0a81df --- /dev/null +++ b/api/test/fixtures/rest_framework.py @@ -0,0 +1,15 @@ +from rest_framework.test import APIClient, APIRequestFactory + +import pytest + + +@pytest.fixture +def api_client(): + return APIClient() + + +@pytest.fixture +def request_factory() -> APIRequestFactory(): + request_factory = APIRequestFactory(defaults={"REMOTE_ADDR": "192.0.2.1"}) + + return request_factory diff --git a/api/test/integration/conftest.py b/api/test/integration/conftest.py index 5374364d0e1..bd9e998175c 100644 --- a/api/test/integration/conftest.py +++ b/api/test/integration/conftest.py @@ -1,11 +1,3 @@ -from test.fixtures.asynchronous import ensure_asgi_lifecycle, session_loop -from test.fixtures.cache import ( - django_cache, - redis, - unreachable_django_cache, - unreachable_redis, -) - import pytest @@ -19,13 +11,3 @@ def django_db_setup(): """ pass - - -__all__ = [ - "ensure_asgi_lifecycle", - "session_loop", - "django_cache", - "redis", - "unreachable_django_cache", - "unreachable_redis", -] diff --git a/api/test/integration/test_audio_integration.py b/api/test/integration/test_audio_integration.py index ef1fbc57eda..6f6e5481d1c 100644 --- a/api/test/integration/test_audio_integration.py +++ b/api/test/integration/test_audio_integration.py @@ -11,16 +11,16 @@ pytestmark = pytest.mark.django_db -def test_audio_detail_without_thumb(client): - resp = client.get("/v1/audio/44540200-91eb-483d-9e99-38ce86a52fb6/") +def test_audio_detail_without_thumb(api_client): + resp = api_client.get("/v1/audio/44540200-91eb-483d-9e99-38ce86a52fb6/") assert resp.status_code == 200 parsed = resp.json() assert parsed["thumbnail"] is None -def test_audio_search_without_thumb(client): +def test_audio_search_without_thumb(api_client): """The first audio of this search should not have a thumbnail.""" - resp = client.get("/v1/audio/?q=zaus") + resp = api_client.get("/v1/audio/?q=zaus") assert resp.status_code == 200 parsed = resp.json() assert parsed["results"][0]["thumbnail"] is None diff --git a/api/test/integration/test_auth.py b/api/test/integration/test_auth.py index 6387b1818ed..85aeadafcdf 100644 --- a/api/test/integration/test_auth.py +++ b/api/test/integration/test_auth.py @@ -3,7 +3,6 @@ from test.constants import API_URL from django.urls import reverse -from django.utils.http import urlencode import pytest from oauth2_provider.models import AccessToken @@ -37,13 +36,13 @@ def unreachable_oauth_cache(unreachable_django_cache, monkeypatch): @pytest.mark.django_db @pytest.fixture -def test_auth_tokens_registration(client): +def test_auth_tokens_registration(api_client): data = { "name": f"INTEGRATION TEST APPLICATION {uuid.uuid4()}", "description": "A key for testing the OAuth2 registration process.", "email": "example@example.org", } - res = client.post( + res = api_client.post( "/v1/auth_tokens/register/", data, verify=False, @@ -55,20 +54,19 @@ def test_auth_tokens_registration(client): @pytest.mark.django_db @pytest.fixture -def test_auth_token_exchange(client, test_auth_tokens_registration): - client_id = test_auth_tokens_registration["client_id"] - client_secret = test_auth_tokens_registration["client_secret"] - data = urlencode( - { - "client_id": client_id, - "client_secret": client_secret, - "grant_type": "client_credentials", - } - ) - res = client.post( +def test_auth_token_exchange(api_client, test_auth_tokens_registration): + api_client_id = test_auth_tokens_registration["client_id"] + api_client_secret = test_auth_tokens_registration["client_secret"] + data = { + "client_id": api_client_id, + "client_secret": api_client_secret, + "grant_type": "client_credentials", + } + + res = api_client.post( "/v1/auth_tokens/token/", data, - "application/x-www-form-urlencoded", + "multipart", verify=False, ) res_data = res.json() @@ -77,8 +75,8 @@ def test_auth_token_exchange(client, test_auth_tokens_registration): @pytest.mark.django_db -def test_auth_token_exchange_unsupported_method(client): - res = client.get( +def test_auth_token_exchange_unsupported_method(api_client): + res = api_client.get( "/v1/auth_tokens/token/", verify=False, ) @@ -86,11 +84,11 @@ def test_auth_token_exchange_unsupported_method(client): assert res.json()["detail"] == 'Method "GET" not allowed.' -def _integration_verify_most_recent_token(client): +def _integration_verify_most_recent_token(api_client): verify = OAuth2Verification.objects.last() code = verify.code path = reverse("verify-email", args=[code]) - return client.get(path) + return api_client.get(path) @pytest.mark.django_db @@ -109,17 +107,17 @@ def _integration_verify_most_recent_token(client): ) def test_auth_email_verification( request, - client, + api_client, is_cache_reachable, cache_name, rate_limit_model, test_auth_token_exchange, ): - res = _integration_verify_most_recent_token(client) + res = _integration_verify_most_recent_token(api_client) assert res.status_code == 200 test_auth_rate_limit_reporting( request, - client, + api_client, is_cache_reachable, cache_name, rate_limit_model, @@ -136,7 +134,7 @@ def test_auth_email_verification( @cache_availability_params def test_auth_rate_limit_reporting( request, - client, + api_client, is_cache_reachable, cache_name, rate_limit_model, @@ -152,7 +150,7 @@ def test_auth_rate_limit_reporting( application = AccessToken.objects.get(token=token).application application.rate_limit_model = rate_limit_model application.save() - res = client.get("/v1/rate_limit/", HTTP_AUTHORIZATION=f"Bearer {token}") + res = api_client.get("/v1/rate_limit/", HTTP_AUTHORIZATION=f"Bearer {token}") res_data = res.json() if is_cache_reachable: assert res.status_code == 200 @@ -175,14 +173,14 @@ def test_auth_rate_limit_reporting( (True, False), ) def test_auth_response_headers( - client, verified, test_auth_tokens_registration, test_auth_token_exchange + api_client, verified, test_auth_tokens_registration, test_auth_token_exchange ): if verified: - _integration_verify_most_recent_token(client) + _integration_verify_most_recent_token(api_client) token = test_auth_token_exchange["access_token"] - res = client.get("/v1/images/", HTTP_AUTHORIZATION=f"Bearer {token}") + res = api_client.get("/v1/images/", HTTP_AUTHORIZATION=f"Bearer {token}") assert ( res.headers["x-ov-client-application-name"] @@ -191,8 +189,8 @@ def test_auth_response_headers( assert res.headers["x-ov-client-application-verified"] == str(verified) -def test_unauthed_response_headers(client): - res = client.get("/v1/images") +def test_unauthed_response_headers(api_client): + res = api_client.get("/v1/images") assert "x-ov-client-application-name" not in res.headers assert "x-ov-client-application-verified" not in res.headers @@ -207,7 +205,7 @@ def test_unauthed_response_headers(client): ], ) def test_sorting_authed( - client, monkeypatch, test_auth_token_exchange, sort_dir, exp_indexed_on + api_client, monkeypatch, test_auth_token_exchange, sort_dir, exp_indexed_on ): # Prevent DB lookup for ES results because DB is empty. monkeypatch.setattr("api.views.image_views.ImageSerializer.needs_db", False) @@ -215,7 +213,9 @@ def test_sorting_authed( time.sleep(1) token = test_auth_token_exchange["access_token"] query_params = {"unstable__sort_by": "indexed_on", "unstable__sort_dir": sort_dir} - res = client.get("/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}") + res = api_client.get( + "/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}" + ) assert res.status_code == 200 res_data = res.json() @@ -232,7 +232,7 @@ def test_sorting_authed( ], ) def test_authority_authed( - client, monkeypatch, test_auth_token_exchange, authority_boost, exp_source + api_client, monkeypatch, test_auth_token_exchange, authority_boost, exp_source ): # Prevent DB lookup for ES results because DB is empty. monkeypatch.setattr("api.views.image_views.ImageSerializer.needs_db", False) @@ -244,7 +244,9 @@ def test_authority_authed( "unstable__authority": "true", "unstable__authority_boost": authority_boost, } - res = client.get("/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}") + res = api_client.get( + "/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}" + ) assert res.status_code == 200 res_data = res.json() @@ -253,23 +255,27 @@ def test_authority_authed( @pytest.mark.django_db -def test_page_size_limit_unauthed(client): +def test_page_size_limit_unauthed(api_client): query_params = {"page_size": 20} - res = client.get("/v1/images/", query_params) + res = api_client.get("/v1/images/", query_params) assert res.status_code == 200 query_params["page_size"] = 21 - res = client.get("/v1/images/", query_params) + res = api_client.get("/v1/images/", query_params) assert res.status_code == 401 @pytest.mark.django_db -def test_page_size_limit_authed(client, test_auth_token_exchange): +def test_page_size_limit_authed(api_client, test_auth_token_exchange): time.sleep(1) token = test_auth_token_exchange["access_token"] query_params = {"page_size": 21} - res = client.get("/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}") + res = api_client.get( + "/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}" + ) assert res.status_code == 200 query_params = {"page_size": 500} - res = client.get("/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}") + res = api_client.get( + "/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}" + ) assert res.status_code == 200 diff --git a/api/test/integration/test_dead_link_filter.py b/api/test/integration/test_dead_link_filter.py index 4b3bf347f7a..55690338727 100644 --- a/api/test/integration/test_dead_link_filter.py +++ b/api/test/integration/test_dead_link_filter.py @@ -74,12 +74,12 @@ def _make_head_requests(urls): @pytest.mark.django_db @_patch_make_head_requests() -def test_dead_link_filtering(mocked_map, client): +def test_dead_link_filtering(mocked_map, api_client): path = "/v1/images/" query_params = {"q": "*", "page_size": 20} # Make a request that does not filter dead links... - res_with_dead_links = client.get( + res_with_dead_links = api_client.get( path, query_params | {"filter_dead": False}, ) @@ -87,7 +87,7 @@ def test_dead_link_filtering(mocked_map, client): mocked_map.assert_not_called() # Make a request that filters dead links... - res_without_dead_links = client.get( + res_without_dead_links = api_client.get( path, query_params | {"filter_dead": True}, ) @@ -119,7 +119,7 @@ def test_dead_link_filtering(mocked_map, client): ), ) def test_dead_link_filtering_all_dead_links( - client, + api_client, filter_dead, page_size, expected_result_count, @@ -130,7 +130,7 @@ def test_dead_link_filtering_all_dead_links( query_params = {"q": "*", "page_size": page_size} with patch_link_validation_dead_for_count(page_size / DEAD_LINK_RATIO): - response = client.get( + response = api_client.get( path, query_params | {"filter_dead": filter_dead}, ) @@ -145,11 +145,11 @@ def test_dead_link_filtering_all_dead_links( @pytest.fixture -def search_factory(client): +def search_factory(api_client): """Allow passing url parameters along with a search request.""" def _parameterized_search(**kwargs): - response = client.get("/v1/images/", kwargs) + response = api_client.get("/v1/images/", kwargs) assert response.status_code == 200 parsed = response.json() return parsed @@ -207,6 +207,8 @@ def no_duplicates(xs): @pytest.mark.django_db -def test_max_page_count(client): - response = client.get("/v1/images/", {"page": settings.MAX_PAGINATION_DEPTH + 1}) +def test_max_page_count(api_client): + response = api_client.get( + "/v1/images/", {"page": settings.MAX_PAGINATION_DEPTH + 1} + ) assert response.status_code == 400 diff --git a/api/test/integration/test_deprecations.py b/api/test/integration/test_deprecations.py index ce675587b6a..a6b3d30bebf 100644 --- a/api/test/integration/test_deprecations.py +++ b/api/test/integration/test_deprecations.py @@ -12,12 +12,12 @@ ("/v1/thumbs/{idx}", "/v1/images/{idx}/thumb/"), ], ) -def test_deprecated_endpoints_redirect_to_new(old, new, client): +def test_deprecated_endpoints_redirect_to_new(old, new, api_client): idx = uuid.uuid4() old = old.format(idx=str(idx)) new = new.format(idx=str(idx)) - res = client.get(old) + res = api_client.get(old) assert res.status_code == 301 assert res.headers.get("Location") == new @@ -33,6 +33,6 @@ def test_deprecated_endpoints_redirect_to_new(old, new, client): ), ], ) -def test_deleted_endpoints_are_gone(method, path, kwargs, client): - res = getattr(client, method)(path, **kwargs) +def test_deleted_endpoints_are_gone(method, path, kwargs, api_client): + res = getattr(api_client, method)(path, **kwargs) assert res.status_code == 410 diff --git a/api/test/integration/test_image_integration.py b/api/test/integration/test_image_integration.py index 528ba62888e..4f6dac5afe1 100644 --- a/api/test/integration/test_image_integration.py +++ b/api/test/integration/test_image_integration.py @@ -12,8 +12,8 @@ @pytest.fixture -def image_fixture(client): - response = client.get("/v1/images/", {"q": "dog"}) +def image_fixture(api_client): + response = api_client.get("/v1/images/", {"q": "dog"}) assert response.status_code == 200 parsed = response.json() return parsed @@ -44,21 +44,23 @@ def image_fixture(client): ), ], ) -def test_oembed_endpoint(image_fixture, url: str, expected_status_code: int, client): +def test_oembed_endpoint( + image_fixture, url: str, expected_status_code: int, api_client +): if "{identifier}" in url: url = url.format(identifier=image_fixture["results"][0]["id"]) params = {"url": url} - response = client.get("/v1/images/oembed/", params) + response = api_client.get("/v1/images/oembed/", params) assert response.status_code == expected_status_code -def test_oembed_endpoint_for_json(image_fixture, client): +def test_oembed_endpoint_for_json(image_fixture, api_client): identifier = image_fixture["results"][0]["id"] params = { "url": f"https://any.domain/any/path/{identifier}", # 'format': 'json' is the default } - response = client.get("/v1/images/oembed/", params) + response = api_client.get("/v1/images/oembed/", params) assert response.status_code == 200 assert response.headers["Content-Type"] == "application/json" diff --git a/api/test/integration/test_media_integration.py b/api/test/integration/test_media_integration.py index 02baa9aff96..99094f71d4e 100644 --- a/api/test/integration/test_media_integration.py +++ b/api/test/integration/test_media_integration.py @@ -63,8 +63,8 @@ def media_type(request): @pytest.fixture -def search_results(media_type: MediaType, client) -> tuple[MediaType, dict]: - res = client.get(f"/v1/{media_type.path}/", {"q": media_type.q}) +def search_results(media_type: MediaType, api_client) -> tuple[MediaType, dict]: + res = api_client.get(f"/v1/{media_type.path}/", {"q": media_type.q}) assert res.status_code == 200 data = res.json() @@ -80,9 +80,9 @@ def single_result(search_results) -> tuple[MediaType, dict]: @pytest.fixture -def related_results(single_result, client) -> tuple[MediaType, dict, dict]: +def related_results(single_result, api_client) -> tuple[MediaType, dict, dict]: media_type, item = single_result - res = client.get(f"/v1/{media_type.path}/{item['id']}/related/") + res = api_client.get(f"/v1/{media_type.path}/{item['id']}/related/") assert res.status_code == 200 data = res.json() @@ -90,9 +90,9 @@ def related_results(single_result, client) -> tuple[MediaType, dict, dict]: @pytest.fixture -def sensitive_result(media_type: MediaType, client) -> tuple[MediaType, dict]: +def sensitive_result(media_type: MediaType, api_client) -> tuple[MediaType, dict]: q = "bird" # Not using the default ``q`` from ``media_type``. - res = client.get( + res = api_client.get( f"/v1/{media_type.path}/", {"q": q, "unstable__include_sensitive_results": True}, ) @@ -112,8 +112,8 @@ def sensitive_result(media_type: MediaType, client) -> tuple[MediaType, dict]: ############## -def test_stats(media_type: MediaType, client): - res = client.get(f"/v1/{media_type.path}/stats/") +def test_stats(media_type: MediaType, api_client): + res = api_client.get(f"/v1/{media_type.path}/stats/") data = res.json() num_media = 0 provider_count = 0 @@ -134,16 +134,16 @@ def test_search_returns_non_zero_results(search_results): assert data["result_count"] > 0 -def test_search_handles_unbalanced_quotes_with_ok(media_type: MediaType, client): - res = client.get(f"/v1/{media_type.path}/", {"q": f'"{media_type.q}'}) +def test_search_handles_unbalanced_quotes_with_ok(media_type: MediaType, api_client): + res = api_client.get(f"/v1/{media_type.path}/", {"q": f'"{media_type.q}'}) assert res.status_code == 200 data = res.json() assert data["result_count"] > 0 -def test_search_handles_special_chars_with_ok(media_type: MediaType, client): - res = client.get(f"/v1/{media_type.path}/", {"q": f"{media_type.q}!"}) +def test_search_handles_special_chars_with_ok(media_type: MediaType, api_client): + res = api_client.get(f"/v1/{media_type.path}/", {"q": f"{media_type.q}!"}) assert res.status_code == 200 data = res.json() @@ -155,7 +155,7 @@ def test_search_results_have_non_es_fields(search_results): _check_non_es_fields_are_present(data["results"]) -def test_search_removes_dupes_from_initial_pages(media_type: MediaType, client): +def test_search_removes_dupes_from_initial_pages(media_type: MediaType, api_client): """ Return consistent, non-duplicate results in the first n pages. @@ -168,7 +168,7 @@ def test_search_removes_dupes_from_initial_pages(media_type: MediaType, client): num_pages = 5 searches = { - client.get(f"/v1/{media_type.path}/", {"page": page}) + api_client.get(f"/v1/{media_type.path}/", {"page": page}) for page in range(1, num_pages) } @@ -185,7 +185,7 @@ def test_search_removes_dupes_from_initial_pages(media_type: MediaType, client): "search_field, match_field", [("q", "title"), ("creator", "creator")] ) def test_search_quotes_matches_only_exact( - media_type: MediaType, search_field, match_field, client + media_type: MediaType, search_field, match_field, api_client ): # We want a query containing more than one word. if match_field == "title": @@ -196,7 +196,7 @@ def test_search_quotes_matches_only_exact( base_params = {"unstable__include_sensitive_results": True} path = f"/v1/{media_type.path}/" - unquoted_res = client.get(path, base_params | {search_field: q}) + unquoted_res = api_client.get(path, base_params | {search_field: q}) assert unquoted_res.status_code == 200 unquoted_data = unquoted_res.json() @@ -207,7 +207,7 @@ def test_search_quotes_matches_only_exact( exact_matches = [q in item[match_field] for item in unquoted_results].count(True) assert 0 < exact_matches < unquoted_result_count - quoted_res = client.get(path, base_params | {search_field: f'"{q}"'}) + quoted_res = api_client.get(path, base_params | {search_field: f'"{q}"'}) assert quoted_res.status_code == 200 quoted_data = quoted_res.json() @@ -223,9 +223,9 @@ def test_search_quotes_matches_only_exact( assert quoted_result_count < unquoted_result_count -def test_search_filters_by_source(media_type: MediaType, client): +def test_search_filters_by_source(media_type: MediaType, api_client): provider = media_type.providers[0] - res = client.get( + res = api_client.get( f"/v1/{media_type.path}/", {"q": media_type.q, "source": provider}, ) @@ -236,8 +236,10 @@ def test_search_filters_by_source(media_type: MediaType, client): assert all(result["source"] == provider for result in data["results"]) -def test_search_returns_zero_results_when_all_excluded(media_type: MediaType, client): - res = client.get( +def test_search_returns_zero_results_when_all_excluded( + media_type: MediaType, api_client +): + res = api_client.get( f"/v1/{media_type.path}/", {"q": media_type.q, "excluded_source": ",".join(media_type.providers)}, ) @@ -247,8 +249,8 @@ def test_search_returns_zero_results_when_all_excluded(media_type: MediaType, cl assert data["result_count"] == 0 -def test_search_refuses_both_sources_and_excluded(media_type: MediaType, client): - res = client.get( +def test_search_refuses_both_sources_and_excluded(media_type: MediaType, api_client): + res = api_client.get( f"/v1/{media_type.path}/", {"q": media_type.q, "source": "x", "excluded_source": "y"}, ) @@ -269,9 +271,9 @@ def test_search_refuses_both_sources_and_excluded(media_type: MediaType, client) ], ) def test_search_filters_by_license( - media_type: MediaType, filter_rule, exp_licenses, client + media_type: MediaType, filter_rule, exp_licenses, api_client ): - res = client.get(f"/v1/{media_type.path}/", filter_rule) + res = api_client.get(f"/v1/{media_type.path}/", filter_rule) assert res.status_code == 200 data = res.json() @@ -279,9 +281,9 @@ def test_search_filters_by_license( assert all(result["license"] in exp_licenses for result in data["results"]) -def test_search_filters_by_extension(media_type: MediaType, client): +def test_search_filters_by_extension(media_type: MediaType, api_client): ext = "mp3" if media_type.name == "audio" else "jpg" - res = client.get(f"/v1/{media_type.path}/", {"extension": ext}) + res = api_client.get(f"/v1/{media_type.path}/", {"extension": ext}) assert res.status_code == 200 data = res.json() @@ -289,9 +291,9 @@ def test_search_filters_by_extension(media_type: MediaType, client): assert all(result["filetype"] == ext for result in data["results"]) -def test_search_filters_by_category(media_type: MediaType, client): +def test_search_filters_by_category(media_type: MediaType, api_client): for category in media_type.categories: - res = client.get(f"/v1/{media_type.path}/", {"category": category}) + res = api_client.get(f"/v1/{media_type.path}/", {"category": category}) assert res.status_code == 200 data = res.json() @@ -299,8 +301,8 @@ def test_search_filters_by_category(media_type: MediaType, client): assert all(result["category"] == category for result in data["results"]) -def test_search_refuses_invalid_categories(media_type: MediaType, client): - res = client.get(f"/v1/{media_type.path}/", {"category": "invalid_category"}) +def test_search_refuses_invalid_categories(media_type: MediaType, api_client): + res = api_client.get(f"/v1/{media_type.path}/", {"category": "invalid_category"}) assert res.status_code == 400 @@ -318,21 +320,21 @@ def test_search_refuses_invalid_categories(media_type: MediaType, client): ], ) def test_detail_view_for_invalid_uuids_returns_not_found( - media_type: MediaType, bad_uuid: str, client + media_type: MediaType, bad_uuid: str, api_client ): - res = client.get(f"/v1/{media_type.path}/{bad_uuid}/") + res = api_client.get(f"/v1/{media_type.path}/{bad_uuid}/") assert res.status_code == 404 -def test_detail_view_returns_ok(single_result, client): +def test_detail_view_returns_ok(single_result, api_client): media_type, item = single_result - res = client.get(f"/v1/{media_type.path}/{item['id']}/") + res = api_client.get(f"/v1/{media_type.path}/{item['id']}/") assert res.status_code == 200 -def test_detail_view_contains_sensitivity_info(sensitive_result, client): +def test_detail_view_contains_sensitivity_info(sensitive_result, api_client): media_type, item = sensitive_result - res = client.get(f"/v1/{media_type.path}/{item['id']}/") + res = api_client.get(f"/v1/{media_type.path}/{item['id']}/") assert res.status_code == 200 data = res.json() @@ -380,15 +382,15 @@ def test_related_results_have_non_es_fields(related_results): ############### -def test_report_is_created(single_result, client): +def test_report_is_created(single_result, api_client): media_type, item = single_result - res = client.post( + res = api_client.post( f"/v1/{media_type.path}/{item['id']}/report/", - data={ + { "reason": "mature", "description": "This item contains sensitive content", }, - content_type="application/json", + "json", ) assert res.status_code == 201 @@ -401,8 +403,8 @@ def test_report_is_created(single_result, client): #################### -def test_collection_by_tag(media_type: MediaType, client): - res = client.get(f"/v1/{media_type.path}/tag/cat/") +def test_collection_by_tag(media_type: MediaType, api_client): + res = api_client.get(f"/v1/{media_type.path}/tag/cat/") assert res.status_code == 200 data = res.json() @@ -412,10 +414,10 @@ def test_collection_by_tag(media_type: MediaType, client): assert "cat" in tag_names -def test_collection_by_source(media_type: MediaType, client): - source = client.get(f"/v1/{media_type.path}/stats/").json()[0]["source_name"] +def test_collection_by_source(media_type: MediaType, api_client): + source = api_client.get(f"/v1/{media_type.path}/stats/").json()[0]["source_name"] - res = client.get(f"/v1/{media_type.path}/source/{source}/") + res = api_client.get(f"/v1/{media_type.path}/source/{source}/") assert res.status_code == 200 data = res.json() @@ -423,15 +425,15 @@ def test_collection_by_source(media_type: MediaType, client): assert all(result["source"] == source for result in data["results"]) -def test_collection_by_creator(media_type: MediaType, client): - source_res = client.get(f"/v1/{media_type.path}/stats/") +def test_collection_by_creator(media_type: MediaType, api_client): + source_res = api_client.get(f"/v1/{media_type.path}/stats/") source = source_res.json()[0]["source_name"] - first_res = client.get(f"/v1/{media_type.path}/source/{source}/") + first_res = api_client.get(f"/v1/{media_type.path}/source/{source}/") first = first_res.json()["results"][0] assert (creator := first.get("creator")) - res = client.get(f"/v1/{media_type.path}/source/{source}/creator/{creator}/") + res = api_client.get(f"/v1/{media_type.path}/source/{source}/creator/{creator}/") assert res.status_code == 200 data = res.json() diff --git a/api/test/unit/conftest.py b/api/test/unit/conftest.py index db4ccccbbbd..d894a0e4b24 100644 --- a/api/test/unit/conftest.py +++ b/api/test/unit/conftest.py @@ -5,17 +5,8 @@ MediaFactory, MediaReportFactory, ) -from test.fixtures.asynchronous import ensure_asgi_lifecycle, get_new_loop, session_loop -from test.fixtures.cache import ( - django_cache, - redis, - unreachable_django_cache, - unreachable_redis, -) from unittest.mock import MagicMock -from rest_framework.test import APIClient, APIRequestFactory - import pook import pytest from elasticsearch import Elasticsearch @@ -46,11 +37,6 @@ ) -@pytest.fixture -def api_client(): - return APIClient() - - @pytest.fixture(autouse=True) def sentry_capture_exception(monkeypatch): mock = MagicMock() @@ -59,13 +45,6 @@ def sentry_capture_exception(monkeypatch): yield mock -@pytest.fixture -def request_factory() -> APIRequestFactory(): - request_factory = APIRequestFactory(defaults={"REMOTE_ADDR": "192.0.2.1"}) - - return request_factory - - @dataclass class MediaTypeConfig: media_type: str @@ -168,16 +147,7 @@ def cleanup_elasticsearch_test_documents(request, settings): __all__ = [ - "ensure_asgi_lifecycle", - "get_new_loop", - "session_loop", - "django_cache", - "redis", - "unreachable_django_cache", - "unreachable_redis", - "api_client", "sentry_capture_exception", - "request_factory", "image_media_type_config", "audio_media_type_config", "media_type_config",