diff --git a/api/catalog/api/views/health_views.py b/api/catalog/api/views/health_views.py index 84448f264..b744094a5 100644 --- a/api/catalog/api/views/health_views.py +++ b/api/catalog/api/views/health_views.py @@ -1,10 +1,19 @@ +from django.conf import settings +from rest_framework import status +from rest_framework.exceptions import APIException +from rest_framework.request import Request from rest_framework.response import Response from rest_framework.views import APIView +class ElasticsearchHealthcheckException(APIException): + status_code = status.HTTP_503_SERVICE_UNAVAILABLE + + class HealthCheck(APIView): """ - Returns a "200 OK" response if the server is running normally. + Returns a "200 OK" response if the server is running normally. Returns 503 + otherwise. This endpoint is used in production to ensure that the server should receive traffic. If no response is provided, the server is deregistered from the @@ -13,5 +22,20 @@ class HealthCheck(APIView): swagger_schema = None - def get(self, request, format=None): + def _check_es(self) -> Response | None: + """ + Checks Elasticsearch cluster health. Raises an exception if ES is not healthy. + """ + es_health = settings.ES.cluster.health(timeout="5s") + + if es_health["timed_out"]: + raise ElasticsearchHealthcheckException("es_timed_out") + + if (status := es_health["status"]) != "green": + raise ElasticsearchHealthcheckException(f"es_status_{status}") + + def get(self, request: Request): + if "check_es" in request.query_params: + self._check_es() + return Response({"status": "200 OK"}, status=200) diff --git a/api/test/unit/conftest.py b/api/test/unit/conftest.py new file mode 100644 index 000000000..fbadf3d12 --- /dev/null +++ b/api/test/unit/conftest.py @@ -0,0 +1,8 @@ +from rest_framework.test import APIClient + +import pytest + + +@pytest.fixture +def api_client(): + return APIClient() diff --git a/api/test/unit/views/health_views_test.py b/api/test/unit/views/health_views_test.py new file mode 100644 index 000000000..2ddf3b3d5 --- /dev/null +++ b/api/test/unit/views/health_views_test.py @@ -0,0 +1,51 @@ +import pook +import pytest + + +def mock_health_response(status="green", timed_out=False): + return ( + pook.get(pook.regex(r"_cluster\/health")) + .times(1) + .reply(200) + .json( + { + "status": status if not timed_out else None, + "timed_out": timed_out, + } + ) + ) + + +def test_health_check_plain(api_client): + res = api_client.get("/healthcheck/") + assert res.status_code == 200 + + +def test_health_check_es_timed_out(api_client): + mock_health_response(timed_out=True) + pook.on() + res = api_client.get("/healthcheck/", data={"check_es": True}) + pook.off() + + assert res.status_code == 503 + assert res.json()["detail"] == "es_timed_out" + + +@pytest.mark.parametrize("status", ("yellow", "red")) +def test_health_check_es_status_bad(status, api_client): + mock_health_response(status=status) + pook.on() + res = api_client.get("/healthcheck/", data={"check_es": True}) + pook.off() + + assert res.status_code == 503 + assert res.json()["detail"] == f"es_status_{status}" + + +def test_health_check_es_all_good(api_client): + mock_health_response(status="green") + pook.on() + res = api_client.get("/healthcheck/", data={"check_es": True}) + pook.off() + + assert res.status_code == 200 diff --git a/api/test/unit/views/image_views_test.py b/api/test/unit/views/image_views_test.py index a74339dbb..a9580f01c 100644 --- a/api/test/unit/views/image_views_test.py +++ b/api/test/unit/views/image_views_test.py @@ -4,8 +4,6 @@ from pathlib import Path from test.factory.models.image import ImageFactory -from rest_framework.test import APIClient - import pytest from requests import Request, Response @@ -17,11 +15,6 @@ _MOCK_IMAGE_INFO = json.loads((_MOCK_IMAGE_PATH / "sample-image-info.json").read_text()) -@pytest.fixture -def api_client(): - return APIClient() - - @dataclass class RequestsFixture: requests: list[Request] diff --git a/api/test/unit/views/media_views_test.py b/api/test/unit/views/media_views_test.py index 5ca567b7d..c297fda40 100644 --- a/api/test/unit/views/media_views_test.py +++ b/api/test/unit/views/media_views_test.py @@ -7,8 +7,6 @@ from unittest import mock from unittest.mock import MagicMock, patch -from rest_framework.test import APIClient - import pytest import pytest_django.asserts import requests as requests_lib @@ -23,11 +21,6 @@ _MOCK_IMAGE_INFO = json.loads((_MOCK_IMAGE_PATH / "sample-image-info.json").read_text()) -@pytest.fixture -def api_client(): - return APIClient() - - @dataclass class SentRequest: request: PreparedRequest