Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] feat: add parent permission check. #328

Draft
wants to merge 21 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 142 additions & 4 deletions rest_framework_extensions/mixins.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,39 @@
from rest_framework_extensions.cache.mixins import CacheResponseMixin
# from rest_framework_extensions.etag.mixins import ReadOnlyETAGMixin, ETAGMixin
from rest_framework_extensions.bulk_operations.mixins import ListUpdateModelMixin, ListDestroyModelMixin
from rest_framework_extensions.settings import extensions_api_settings
from django.http import Http404
from rest_framework_extensions.settings import extensions_api_settings
from rest_framework import status, exceptions
from rest_framework.generics import get_object_or_404

class BulkCreateModelMixin:
"""
Builk create model instance.
Just post data like:
[
{"name": "xxx"},
{"name": "xxx2"},
]
"""

def get_serializer(self, *args, **kwargs):
if isinstance(kwargs.get('data', {}), list):
kwargs['many'] = True
s = super().get_serializer(*args, **kwargs)
return s


class MultiSerializerViewSetMixin:
"""
serializer_action_classes = {
list: ListSerializer,
<action_name>: Serializer,
...
}
"""
serializer_classes = {}
def get_serializer_class(self):
try:
return self.serializer_classes[self.action]
except (KeyError, AttributeError):
return super(MultiSerializerViewSetMixin, self).get_serializer_class()


class DetailSerializerMixin:
Expand Down Expand Up @@ -50,6 +81,113 @@ def get_page_size(self, request):


class NestedViewSetMixin:
parent_viewset = None

def check_ownership(self, serializer):
parent_query_dicts = self.get_parents_query_dict()
if not parent_query_dicts:
return

parent_lookup, parent_value = list(parent_query_dicts.items())[-1]
if "__" in parent_lookup:
receive_key, _ = parent_lookup.split("__")
else:
receive_key = parent_lookup

instance_datas = serializer.validated_data
if not isinstance(instance_datas, list):
instance_datas = [instance_datas]
received_parent_values = [
i.get(receive_key) for i in instance_datas if i.get(receive_key)]

# 1. check filled parent field
if len(received_parent_values) != len(instance_datas):
raise exceptions.PermissionDenied(
detail=f"You must specific '{parent_lookup}'", code=status.HTTP_403_FORBIDDEN)

received_parent_values = [str(v) if isinstance(v, (str, int)) else
str(getattr(v, self.parent_viewset.lookup_field))
for v in received_parent_values
]
# 2. check direct FK parent
if not "__" in parent_lookup:
not_blong = [
v for v in received_parent_values if v != str(parent_value)
]
if not_blong:
raise exceptions.PermissionDenied(
detail=f"You don't have permission to operate item that belong to '{parent_lookup}:{not_blong}'", code=status.HTTP_403_FORBIDDEN)
else:
# 3. for multiple layer parent
direct_parent, direct_parent_look_field = parent_lookup.split(
'__', 1)
current_model = self.get_queryset().model
direct_parent_model = current_model._meta.get_field(
direct_parent
).related_model
direct_parent_instances = direct_parent_model.objects.filter(
**{f"pk__in": received_parent_values}
)
fields = direct_parent_look_field.split("__")
for instance in direct_parent_instances:
final_parent_obj = instance
for f in fields:
final_parent_obj = getattr(instance, f)
if (received_value := str(getattr(final_parent_obj, self.parent_viewset.lookup_field))) != str(parent_value):
raise exceptions.PermissionDenied(
detail=f"You don't have permission to operate item that belong to '{parent_lookup}:{received_value}'", code=status.HTTP_403_FORBIDDEN)

def perform_create(self, serializer):
self.check_ownership(serializer)
super().perform_create(serializer)

def perform_update(self, serializer):
self.check_ownership(serializer)
super().perform_update(serializer)

def get_parent_model(self, current_model, parent_model_lookup_name):
parent_model = current_model
for lookup_name in parent_model_lookup_name.split("__"):
parent_model = parent_model._meta.get_field(
lookup_name).related_model
return parent_model

def check_parent_object_permissions(self, request):
# if parent viewset haven't init yet, then will raise no "kwargs" attribute error, but it doesn't matter, just ignore
try:
if not (parents_query_dict := self.get_parents_query_dict()):
return
except:
return
# 2. for generic relations case.
current_model = self.get_queryset().model
current_viewset = self

for parent_model_lookup_name, parent_model_lookup_value in sorted(parents_query_dict.items(), key=lambda item: len(item[0])):
parent_model = self.get_parent_model(
current_model, parent_model_lookup_name)
parent_viewset = current_viewset.parent_viewset()

parent_obj = get_object_or_404(
parent_model.objects.all(),
**{parent_viewset.lookup_field: parent_model_lookup_value}
)
parent_viewset.check_object_permissions(
request, parent_obj
)

current_viewset = parent_viewset

def check_permissions(self, request):
super().check_permissions(request)
if self.parent_viewset:
self.check_parent_object_permissions(request)

def check_object_permissions(self, request, obj):
super().check_object_permissions(request, obj)
if self.parent_viewset:
self.check_parent_object_permissions(request)

def get_queryset(self):
return self.filter_queryset_by_parents_lookups(
super().get_queryset()
Expand Down
25 changes: 20 additions & 5 deletions rest_framework_extensions/routers.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,42 @@
from copy import deepcopy
from rest_framework.routers import DefaultRouter, SimpleRouter
from rest_framework_extensions.utils import compose_parent_pk_kwarg_name


class NestedRegistryItem:
def __init__(self, router, parent_prefix, parent_item=None, parent_viewset=None):
def __init__(self, router, parent_prefix, parent_item=None, parent_viewset=None, parent_lookups=[]):
self.router = router
self.parent_prefix = parent_prefix
self.parent_item = parent_item
self.parent_viewset = parent_viewset
self.parent_lookups = parent_lookups

def register(self, prefix, viewset, basename, parents_query_lookups=[], parent_query_lookup=""):
copied_viewset = type(viewset.__name__, (viewset,), {
k: v for k, v in viewset.__dict__.items()})
if not parents_query_lookups:
parents_query_lookups = ["__".join(
[parent_query_lookup, pl]) for pl in self.parent_lookups] + [parent_query_lookup]

def register(self, prefix, viewset, basename, parents_query_lookups):
self.router._register(
prefix=self.get_prefix(
current_prefix=prefix,
parents_query_lookups=parents_query_lookups),
viewset=viewset,
parents_query_lookups=parents_query_lookups
),
viewset=copied_viewset,
basename=basename,
)
copied_viewset.parent_viewset = self.parent_viewset
v = copied_viewset
while v.parent_viewset:
v = v.parent_viewset

return NestedRegistryItem(
router=self.router,
parent_prefix=prefix,
parent_item=self,
parent_viewset=viewset
parent_viewset=copied_viewset,
parent_lookups=parents_query_lookups
)

def get_prefix(self, current_prefix, parents_query_lookups):
Expand Down
33 changes: 33 additions & 0 deletions rest_framework_extensions/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,39 @@ def get_fields_for_partial_update(opts, init_data, fields, init_files=None):
return sorted(set(update_fields))


class BulkCreateModelMixin:
"""
Builk create model instance.
Just post data like:
[
{"name": "xxx"},
{"name": "xxx2"},
]
"""

def get_serializer(self, *args, **kwargs):
if isinstance(kwargs.get('data', {}), list):
kwargs['many'] = True
s = super().get_serializer(*args, **kwargs)
return s


class MultiSerializerViewSetMixin:
"""
serializer_action_classes = {
list: ListSerializer,
<action_name>: Serializer,
...
}
"""
serializer_classes = {}
def get_serializer_class(self):
try:
return self.serializer_classes[self.action]
except (KeyError, AttributeError):
return super(MultiSerializerViewSetMixin, self).get_serializer_class()


class PartialUpdateSerializerMixin:
def save(self, **kwargs):
self._update_fields = kwargs.get('update_fields', None)
Expand Down