From d0b3cc5916cb5b8f682dddc6fd0d9b5c43c97ef5 Mon Sep 17 00:00:00 2001 From: "gcp-cherry-pick-bot[bot]" <98988430+gcp-cherry-pick-bot[bot]@users.noreply.github.com> Date: Mon, 29 Apr 2024 13:10:28 +0200 Subject: [PATCH] core: only prefetch related objects when required (cherry-pick #9476) (#9510) core: only prefetch related objects when required (#9476) * core: only prefetch related objects when required * add tests * add tests to assert query count * "optimize" another query away * prefetch parent and roles * whops that needs to be pre-fetched --------- Signed-off-by: Jens Langhammer Co-authored-by: Jens L --- authentik/core/api/groups.py | 8 +++++++- authentik/core/api/users.py | 7 +++++-- authentik/core/tests/test_groups_api.py | 9 ++++++++- authentik/core/tests/test_users_api.py | 6 ++++++ authentik/core/tests/test_users_avatars.py | 2 -- authentik/events/middleware.py | 2 +- 6 files changed, 27 insertions(+), 7 deletions(-) diff --git a/authentik/core/api/groups.py b/authentik/core/api/groups.py index 6d0af25decc1..c865eedbbf70 100644 --- a/authentik/core/api/groups.py +++ b/authentik/core/api/groups.py @@ -154,12 +154,18 @@ class UserAccountSerializer(PassiveSerializer): pk = IntegerField(required=True) - queryset = Group.objects.all().select_related("parent").prefetch_related("users") + queryset = Group.objects.none() serializer_class = GroupSerializer search_fields = ["name", "is_superuser"] filterset_class = GroupFilter ordering = ["name"] + def get_queryset(self): + base_qs = Group.objects.all().select_related("parent").prefetch_related("roles") + if self.serializer_class(context={"request": self.request})._should_include_users: + base_qs = base_qs.prefetch_related("users") + return base_qs + @extend_schema( parameters=[ OpenApiParameter("include_users", bool, default=True), diff --git a/authentik/core/api/users.py b/authentik/core/api/users.py index c2b2c86dbea2..a617c1ce2e25 100644 --- a/authentik/core/api/users.py +++ b/authentik/core/api/users.py @@ -407,8 +407,11 @@ class UserViewSet(UsedByMixin, ModelViewSet): search_fields = ["username", "name", "is_active", "email", "uuid"] filterset_class = UsersFilter - def get_queryset(self): # pragma: no cover - return User.objects.all().exclude_anonymous().prefetch_related("ak_groups") + def get_queryset(self): + base_qs = User.objects.all().exclude_anonymous() + if self.serializer_class(context={"request": self.request})._should_include_groups: + base_qs = base_qs.prefetch_related("ak_groups") + return base_qs @extend_schema( parameters=[ diff --git a/authentik/core/tests/test_groups_api.py b/authentik/core/tests/test_groups_api.py index 50303d52150a..d20ef4df7ebc 100644 --- a/authentik/core/tests/test_groups_api.py +++ b/authentik/core/tests/test_groups_api.py @@ -5,7 +5,7 @@ from rest_framework.test import APITestCase from authentik.core.models import Group, User -from authentik.core.tests.utils import create_test_user +from authentik.core.tests.utils import create_test_admin_user, create_test_user from authentik.lib.generators import generate_id @@ -16,6 +16,13 @@ def setUp(self) -> None: self.login_user = create_test_user() self.user = User.objects.create(username="test-user") + def test_list_with_users(self): + """Test listing with users""" + admin = create_test_admin_user() + self.client.force_login(admin) + response = self.client.get(reverse("authentik_api:group-list"), {"include_users": "true"}) + self.assertEqual(response.status_code, 200) + def test_add_user(self): """Test add_user""" group = Group.objects.create(name=generate_id()) diff --git a/authentik/core/tests/test_users_api.py b/authentik/core/tests/test_users_api.py index a2e792369343..140746f7c769 100644 --- a/authentik/core/tests/test_users_api.py +++ b/authentik/core/tests/test_users_api.py @@ -41,6 +41,12 @@ def test_filter_type(self): ) self.assertEqual(response.status_code, 200) + def test_list_with_groups(self): + """Test listing with groups""" + self.client.force_login(self.admin) + response = self.client.get(reverse("authentik_api:user-list"), {"include_groups": "true"}) + self.assertEqual(response.status_code, 200) + def test_metrics(self): """Test user's metrics""" self.client.force_login(self.admin) diff --git a/authentik/core/tests/test_users_avatars.py b/authentik/core/tests/test_users_avatars.py index dc31e9f61401..2dcfaed921e5 100644 --- a/authentik/core/tests/test_users_avatars.py +++ b/authentik/core/tests/test_users_avatars.py @@ -8,7 +8,6 @@ from authentik.core.models import User from authentik.core.tests.utils import create_test_admin_user -from authentik.lib.config import CONFIG from authentik.tenants.utils import get_current_tenant @@ -25,7 +24,6 @@ def set_avatar_mode(self, mode: str): tenant.avatars = mode tenant.save() - @CONFIG.patch("avatars", "none") def test_avatars_none(self): """Test avatars none""" self.set_avatar_mode("none") diff --git a/authentik/events/middleware.py b/authentik/events/middleware.py index 4f7ebbf37e49..20c93761a67d 100644 --- a/authentik/events/middleware.py +++ b/authentik/events/middleware.py @@ -116,12 +116,12 @@ def get_user(self, request: HttpRequest) -> User: return user user = getattr(request, "user", self.anonymous_user) if not user.is_authenticated: + self._ensure_fallback_user() return self.anonymous_user return user def connect(self, request: HttpRequest): """Connect signal for automatic logging""" - self._ensure_fallback_user() if not hasattr(request, "request_id"): return post_save.connect(