diff --git a/rest_framework_extensions/mixins.py b/rest_framework_extensions/mixins.py index 8a56c5f..3f80e4d 100644 --- a/rest_framework_extensions/mixins.py +++ b/rest_framework_extensions/mixins.py @@ -1,8 +1,39 @@ -from rest_framework_extensions.cache.mixins import CacheResponseMixin -# from rest_framework_extensions.etag.mixins import ReadOnlyETAGMixin, ETAGMixin -from rest_framework_extensions.bulk_operations.mixins import ListUpdateModelMixin, ListDestroyModelMixin -from rest_framework_extensions.settings import extensions_api_settings from django.http import Http404 +from rest_framework_extensions.settings import extensions_api_settings +from rest_framework import status, exceptions +from rest_framework.generics import get_object_or_404 + +class BulkCreateModelMixin: + """ + Builk create model instance. + Just post data like: + [ + {"name": "xxx"}, + {"name": "xxx2"}, + ] + """ + + def get_serializer(self, *args, **kwargs): + if isinstance(kwargs.get('data', {}), list): + kwargs['many'] = True + s = super().get_serializer(*args, **kwargs) + return s + + +class MultiSerializerViewSetMixin: + """ + serializer_action_classes = { + list: ListSerializer, + : Serializer, + ... + } + """ + serializer_classes = {} + def get_serializer_class(self): + try: + return self.serializer_classes[self.action] + except (KeyError, AttributeError): + return super(MultiSerializerViewSetMixin, self).get_serializer_class() class DetailSerializerMixin: @@ -50,6 +81,113 @@ def get_page_size(self, request): class NestedViewSetMixin: + parent_viewset = None + + def check_ownership(self, serializer): + parent_query_dicts = self.get_parents_query_dict() + if not parent_query_dicts: + return + + parent_lookup, parent_value = list(parent_query_dicts.items())[-1] + if "__" in parent_lookup: + receive_key, _ = parent_lookup.split("__") + else: + receive_key = parent_lookup + + instance_datas = serializer.validated_data + if not isinstance(instance_datas, list): + instance_datas = [instance_datas] + received_parent_values = [ + i.get(receive_key) for i in instance_datas if i.get(receive_key)] + + # 1. check filled parent field + if len(received_parent_values) != len(instance_datas): + raise exceptions.PermissionDenied( + detail=f"You must specific '{parent_lookup}'", code=status.HTTP_403_FORBIDDEN) + + received_parent_values = [str(v) if isinstance(v, (str, int)) else + str(getattr(v, self.parent_viewset.lookup_field)) + for v in received_parent_values + ] + # 2. check direct FK parent + if not "__" in parent_lookup: + not_blong = [ + v for v in received_parent_values if v != str(parent_value) + ] + if not_blong: + raise exceptions.PermissionDenied( + detail=f"You don't have permission to operate item that belong to '{parent_lookup}:{not_blong}'", code=status.HTTP_403_FORBIDDEN) + else: + # 3. for multiple layer parent + direct_parent, direct_parent_look_field = parent_lookup.split( + '__', 1) + current_model = self.get_queryset().model + direct_parent_model = current_model._meta.get_field( + direct_parent + ).related_model + direct_parent_instances = direct_parent_model.objects.filter( + **{f"pk__in": received_parent_values} + ) + fields = direct_parent_look_field.split("__") + for instance in direct_parent_instances: + final_parent_obj = instance + for f in fields: + final_parent_obj = getattr(instance, f) + if (received_value := str(getattr(final_parent_obj, self.parent_viewset.lookup_field))) != str(parent_value): + raise exceptions.PermissionDenied( + detail=f"You don't have permission to operate item that belong to '{parent_lookup}:{received_value}'", code=status.HTTP_403_FORBIDDEN) + + def perform_create(self, serializer): + self.check_ownership(serializer) + super().perform_create(serializer) + + def perform_update(self, serializer): + self.check_ownership(serializer) + super().perform_update(serializer) + + def get_parent_model(self, current_model, parent_model_lookup_name): + parent_model = current_model + for lookup_name in parent_model_lookup_name.split("__"): + parent_model = parent_model._meta.get_field( + lookup_name).related_model + return parent_model + + def check_parent_object_permissions(self, request): + # if parent viewset haven't init yet, then will raise no "kwargs" attribute error, but it doesn't matter, just ignore + try: + if not (parents_query_dict := self.get_parents_query_dict()): + return + except: + return + # 2. for generic relations case. + current_model = self.get_queryset().model + current_viewset = self + + for parent_model_lookup_name, parent_model_lookup_value in sorted(parents_query_dict.items(), key=lambda item: len(item[0])): + parent_model = self.get_parent_model( + current_model, parent_model_lookup_name) + parent_viewset = current_viewset.parent_viewset() + + parent_obj = get_object_or_404( + parent_model.objects.all(), + **{parent_viewset.lookup_field: parent_model_lookup_value} + ) + parent_viewset.check_object_permissions( + request, parent_obj + ) + + current_viewset = parent_viewset + + def check_permissions(self, request): + super().check_permissions(request) + if self.parent_viewset: + self.check_parent_object_permissions(request) + + def check_object_permissions(self, request, obj): + super().check_object_permissions(request, obj) + if self.parent_viewset: + self.check_parent_object_permissions(request) + def get_queryset(self): return self.filter_queryset_by_parents_lookups( super().get_queryset() diff --git a/rest_framework_extensions/routers.py b/rest_framework_extensions/routers.py index 2f119f8..a3ce5b2 100644 --- a/rest_framework_extensions/routers.py +++ b/rest_framework_extensions/routers.py @@ -1,27 +1,42 @@ +from copy import deepcopy from rest_framework.routers import DefaultRouter, SimpleRouter from rest_framework_extensions.utils import compose_parent_pk_kwarg_name class NestedRegistryItem: - def __init__(self, router, parent_prefix, parent_item=None, parent_viewset=None): + def __init__(self, router, parent_prefix, parent_item=None, parent_viewset=None, parent_lookups=[]): self.router = router self.parent_prefix = parent_prefix self.parent_item = parent_item self.parent_viewset = parent_viewset + self.parent_lookups = parent_lookups + + def register(self, prefix, viewset, basename, parents_query_lookups=[], parent_query_lookup=""): + copied_viewset = type(viewset.__name__, (viewset,), { + k: v for k, v in viewset.__dict__.items()}) + if not parents_query_lookups: + parents_query_lookups = ["__".join( + [parent_query_lookup, pl]) for pl in self.parent_lookups] + [parent_query_lookup] - def register(self, prefix, viewset, basename, parents_query_lookups): self.router._register( prefix=self.get_prefix( current_prefix=prefix, - parents_query_lookups=parents_query_lookups), - viewset=viewset, + parents_query_lookups=parents_query_lookups + ), + viewset=copied_viewset, basename=basename, ) + copied_viewset.parent_viewset = self.parent_viewset + v = copied_viewset + while v.parent_viewset: + v = v.parent_viewset + return NestedRegistryItem( router=self.router, parent_prefix=prefix, parent_item=self, - parent_viewset=viewset + parent_viewset=copied_viewset, + parent_lookups=parents_query_lookups ) def get_prefix(self, current_prefix, parents_query_lookups): diff --git a/rest_framework_extensions/serializers.py b/rest_framework_extensions/serializers.py index 09ecf0d..4dc4336 100644 --- a/rest_framework_extensions/serializers.py +++ b/rest_framework_extensions/serializers.py @@ -29,6 +29,39 @@ def get_fields_for_partial_update(opts, init_data, fields, init_files=None): return sorted(set(update_fields)) +class BulkCreateModelMixin: + """ + Builk create model instance. + Just post data like: + [ + {"name": "xxx"}, + {"name": "xxx2"}, + ] + """ + + def get_serializer(self, *args, **kwargs): + if isinstance(kwargs.get('data', {}), list): + kwargs['many'] = True + s = super().get_serializer(*args, **kwargs) + return s + + +class MultiSerializerViewSetMixin: + """ + serializer_action_classes = { + list: ListSerializer, + : Serializer, + ... + } + """ + serializer_classes = {} + def get_serializer_class(self): + try: + return self.serializer_classes[self.action] + except (KeyError, AttributeError): + return super(MultiSerializerViewSetMixin, self).get_serializer_class() + + class PartialUpdateSerializerMixin: def save(self, **kwargs): self._update_fields = kwargs.get('update_fields', None)