Skip to content

Commit

Permalink
core: only prefetch related objects when required (cherry-pick #9476) (
Browse files Browse the repository at this point in the history
…#9510)

core: only prefetch related objects when required (#9476)

* core: only prefetch related objects when required



* add tests



* add tests to assert query count



* "optimize" another query away



* prefetch parent and roles



* whops that needs to be pre-fetched



---------

Signed-off-by: Jens Langhammer <[email protected]>
Co-authored-by: Jens L <[email protected]>
  • Loading branch information
gcp-cherry-pick-bot[bot] and BeryJu authored Apr 29, 2024
1 parent e034f5e commit d0b3cc5
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 7 deletions.
8 changes: 7 additions & 1 deletion authentik/core/api/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,18 @@ class UserAccountSerializer(PassiveSerializer):

pk = IntegerField(required=True)

queryset = Group.objects.all().select_related("parent").prefetch_related("users")
queryset = Group.objects.none()
serializer_class = GroupSerializer
search_fields = ["name", "is_superuser"]
filterset_class = GroupFilter
ordering = ["name"]

def get_queryset(self):
base_qs = Group.objects.all().select_related("parent").prefetch_related("roles")
if self.serializer_class(context={"request": self.request})._should_include_users:
base_qs = base_qs.prefetch_related("users")
return base_qs

@extend_schema(
parameters=[
OpenApiParameter("include_users", bool, default=True),
Expand Down
7 changes: 5 additions & 2 deletions authentik/core/api/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,11 @@ class UserViewSet(UsedByMixin, ModelViewSet):
search_fields = ["username", "name", "is_active", "email", "uuid"]
filterset_class = UsersFilter

def get_queryset(self): # pragma: no cover
return User.objects.all().exclude_anonymous().prefetch_related("ak_groups")
def get_queryset(self):
base_qs = User.objects.all().exclude_anonymous()
if self.serializer_class(context={"request": self.request})._should_include_groups:
base_qs = base_qs.prefetch_related("ak_groups")
return base_qs

@extend_schema(
parameters=[
Expand Down
9 changes: 8 additions & 1 deletion authentik/core/tests/test_groups_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from rest_framework.test import APITestCase

from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_user
from authentik.core.tests.utils import create_test_admin_user, create_test_user
from authentik.lib.generators import generate_id


Expand All @@ -16,6 +16,13 @@ def setUp(self) -> None:
self.login_user = create_test_user()
self.user = User.objects.create(username="test-user")

def test_list_with_users(self):
"""Test listing with users"""
admin = create_test_admin_user()
self.client.force_login(admin)
response = self.client.get(reverse("authentik_api:group-list"), {"include_users": "true"})
self.assertEqual(response.status_code, 200)

def test_add_user(self):
"""Test add_user"""
group = Group.objects.create(name=generate_id())
Expand Down
6 changes: 6 additions & 0 deletions authentik/core/tests/test_users_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ def test_filter_type(self):
)
self.assertEqual(response.status_code, 200)

def test_list_with_groups(self):
"""Test listing with groups"""
self.client.force_login(self.admin)
response = self.client.get(reverse("authentik_api:user-list"), {"include_groups": "true"})
self.assertEqual(response.status_code, 200)

def test_metrics(self):
"""Test user's metrics"""
self.client.force_login(self.admin)
Expand Down
2 changes: 0 additions & 2 deletions authentik/core/tests/test_users_avatars.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from authentik.core.models import User
from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.config import CONFIG
from authentik.tenants.utils import get_current_tenant


Expand All @@ -25,7 +24,6 @@ def set_avatar_mode(self, mode: str):
tenant.avatars = mode
tenant.save()

@CONFIG.patch("avatars", "none")
def test_avatars_none(self):
"""Test avatars none"""
self.set_avatar_mode("none")
Expand Down
2 changes: 1 addition & 1 deletion authentik/events/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,12 @@ def get_user(self, request: HttpRequest) -> User:
return user
user = getattr(request, "user", self.anonymous_user)
if not user.is_authenticated:
self._ensure_fallback_user()
return self.anonymous_user
return user

def connect(self, request: HttpRequest):
"""Connect signal for automatic logging"""
self._ensure_fallback_user()
if not hasattr(request, "request_id"):
return
post_save.connect(
Expand Down

0 comments on commit d0b3cc5

Please sign in to comment.