Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add parent permission check. #3

Merged
merged 3 commits into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 64 additions & 57 deletions rest_framework_extensions/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,44 @@
from django.core.exceptions import ValidationError
# from rest_framework_extensions.etag.mixins import ReadOnlyETAGMixin, ETAGMixin
from django.http import Http404
from django.db import models
from rest_framework_extensions.bulk_operations.mixins import ListUpdateModelMixin, ListDestroyModelMixin
from rest_framework_extensions.settings import extensions_api_settings
from rest_framework import status, exceptions
from rest_framework.generics import get_object_or_404

class BulkCreateModelMixin:
"""
Builk create model instance.
Just post data like:
[
{"name": "xxx"},
{"name": "xxx2"},
]
"""

def get_serializer(self, *args, **kwargs):
if isinstance(kwargs.get('data', {}), list):
kwargs['many'] = True
s = super().get_serializer(*args, **kwargs)
return s


class MultiSerializerViewSetMixin:
"""
serializer_action_classes = {
list: ListSerializer,
<action_name>: Serializer,
...
}
"""
serializer_classes = {}
def get_serializer_class(self):
try:
return self.serializer_classes[self.action]
except (KeyError, AttributeError):
return super(MultiSerializerViewSetMixin, self).get_serializer_class()


class DetailSerializerMixin:
"""
Expand Down Expand Up @@ -52,55 +85,26 @@ def get_page_size(self, request):
# pass


class BulkCreateModelMixin:
"""
Builk create model instance.
Just post data like:
[
{"name": "xxx"},
{"name": "xxx2"},
]
"""

def get_serializer(self, *args, **kwargs):
if isinstance(kwargs.get('data', {}), list):
kwargs['many'] = True
s = super().get_serializer(*args, **kwargs)
return s


class MultiSerializerViewSetMixin:
"""
serializer_action_classes = {
list: ListSerializer,
<action_name>: Serializer,
...
}
"""
serializer_action_classes = {}
def get_serializer_class(self):
try:
return self.serializer_action_classes[self.action]
except (KeyError, AttributeError):
return super(MultiSerializerViewSetMixin, self).get_serializer_class()



class NestedViewSetMixin:
parent_viewsets = set()

def check_ownership(self, serializer):
parent_query_dicts = self.get_parents_query_dict()
if parent_query_dicts:
parent_name, parent_value = list(parent_query_dicts.items())[-1]
items = serializer.validated_data
if not isinstance(items, list):
items = [items]
for item in items:
if item.get(parent_name, None) is None:
instance_datas = serializer.validated_data
if not isinstance(instance_datas, list):
instance_datas = [instance_datas]
for instance_data in instance_datas:
if instance_data.get(parent_name, None) is None:
raise exceptions.PermissionDenied(
detail=f"You must specific '{parent_name}'", code=status.HTTP_403_FORBIDDEN)
if item.get(parent_name, None) != parent_value:
received_parent_value = instance_data.get(parent_name, None)
print(received_parent_value)
if not isinstance(received_parent_value, (str, int)):
received_parent_value = getattr(
received_parent_value, self.parent_viewset.lookup_field)
if str(received_parent_value) != str(parent_value):
raise exceptions.PermissionDenied(
detail=f"You don't have permission to operate item that belone to '{parent_name}:{parent_value}'", code=status.HTTP_403_FORBIDDEN)

Expand All @@ -112,6 +116,13 @@ def perform_update(self, serializer):
self.check_ownership(serializer)
super().perform_update(serializer)

def get_parent_model(self, current_model, parent_model_lookup_name):
parent_model = current_model
for lookup_name in parent_model_lookup_name.split("__"):
parent_model = parent_model._meta.get_field(
lookup_name).related_model
return parent_model

def check_parent_object_permissions(self, request):
# if parent viewset haven't init yet, then will raise no "kwargs" attribute error, but it doesn't matter, just ignore
try:
Expand All @@ -122,25 +133,21 @@ def check_parent_object_permissions(self, request):
return
current_model = self.get_queryset().model
# TODO
# 1. for model__submodel case(Done).
# 1. for model__submodel case.
# 2. for generic relations case.
for parent_model_lookup_name, parent_model_lookup_value in reversed(parents_query_dict.items()):
parent_model = current_model
for lookup_name in parent_model_lookup_name.split("__"):
parent_model = parent_model._meta.get_field(
lookup_name).related_model
for parent_viewset_class in self.parent_viewsets:
parent_viewset = parent_viewset_class()
parent_viewset_model = getattr(
parent_viewset, "model", None) or parent_viewset.queryset.model
if parent_viewset_model == parent_model:
parent_obj = get_object_or_404(
parent_viewset_model.objects.all(),
**{parent_viewset.lookup_field: parent_model_lookup_value}
)
parent_viewset.check_object_permissions(
request, parent_obj
)
parent_model = get_parent_model(
current_model, parent_model_lookup_name)
parent_viewset = self.parent_viewset()
parent_viewset_model = getattr(
parent_viewset, "model", None) or parent_viewset.queryset.model
parent_obj = get_object_or_404(
parent_viewset_model.objects.all(),
**{parent_viewset.lookup_field: parent_model_lookup_value}
)
parent_viewset.check_object_permissions(
request, parent_obj
)
current_model = parent_model

def check_permissions(self, request):
Expand Down
9 changes: 6 additions & 3 deletions rest_framework_extensions/routers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from rest_framework.routers import DefaultRouter, SimpleRouter
from rest_framework_extensions.utils import compose_parent_pk_kwarg_name

Expand All @@ -10,19 +11,21 @@ def __init__(self, router, parent_prefix, parent_item=None, parent_viewset=None)
self.parent_viewset = parent_viewset

def register(self, prefix, viewset, basename, parents_query_lookups):
# deepcopy to make sure one viewset class only has one parent viewset
copied_viewset = deepcopy(viewset)
self.router._register(
prefix=self.get_prefix(
current_prefix=prefix,
parents_query_lookups=parents_query_lookups),
viewset=viewset,
viewset=copied_viewset,
basename=basename,
)
viewset.parent_viewsets.add(self.parent_viewset)
copied_viewset.parent_viewset = self.parent_viewset
return NestedRegistryItem(
router=self.router,
parent_prefix=prefix,
parent_item=self,
parent_viewset=viewset
parent_viewset=copied_viewset
)

def get_prefix(self, current_prefix, parents_query_lookups):
Expand Down