diff --git a/dynamic_rest/filters.py b/dynamic_rest/filters.py index 15410e92..cde2db52 100644 --- a/dynamic_rest/filters.py +++ b/dynamic_rest/filters.py @@ -1,27 +1,21 @@ """This module contains custom filter backends.""" -from django.core.exceptions import ValidationError as InternalValidationError from django.core.exceptions import ImproperlyConfigured -from django.db.models import Q, Prefetch, Manager +from django.core.exceptions import ValidationError as InternalValidationError +from django.db.models import Manager, Prefetch, Q from django.utils import six -from rest_framework import serializers -from rest_framework.exceptions import ValidationError -from rest_framework.fields import BooleanField, NullBooleanField -from rest_framework.filters import BaseFilterBackend, OrderingFilter - -from dynamic_rest.utils import is_truthy from dynamic_rest.conf import settings from dynamic_rest.datastructures import TreeMap from dynamic_rest.fields import DynamicRelationField -from dynamic_rest.meta import ( - get_model_field, - is_field_remote, - is_model_field, - get_related_model -) +from dynamic_rest.meta import get_model_field, get_related_model, is_field_remote, is_model_field from dynamic_rest.patches import patch_prefetch_one_level -from dynamic_rest.prefetch import FastQuery, FastPrefetch +from dynamic_rest.prefetch import FastPrefetch, FastQuery from dynamic_rest.related import RelatedObject +from dynamic_rest.utils import is_truthy +from rest_framework import serializers +from rest_framework.exceptions import ValidationError +from rest_framework.fields import BooleanField, JSONField, NullBooleanField +from rest_framework.filters import BaseFilterBackend, OrderingFilter patch_prefetch_one_level() @@ -39,7 +33,6 @@ def has_joins(queryset): class FilterNode(object): - def __init__(self, field, operator, value): """Create an object representing a filter, to be stored in a TreeMap. @@ -65,10 +58,7 @@ def __init__(self, field, operator, value): @property def key(self): - return '%s%s' % ( - '__'.join(self.field), - '__' + self.operator if self.operator else '' - ) + return "%s%s" % ("__".join(self.field), "__" + self.operator if self.operator else "") def generate_query_key(self, serializer): """Get the key that can be passed to Django's filter method. @@ -90,7 +80,13 @@ def generate_query_key(self, serializer): last = len(self.field) - 1 s = serializer field = None + jsonfield_recurse = False for i, field_name in enumerate(self.field): + # Note: this is to handle jsonfield for recursive filtering + if jsonfield_recurse: + rewritten.append(field_name) + if i == last: + break # Note: .fields can be empty for related serializers that aren't # sideloaded. Fields that are deferred also won't be present. # If field name isn't in serializer.fields, get full list from @@ -98,16 +94,14 @@ def generate_query_key(self, serializer): # this if we have to. fields = s.fields if field_name not in fields: - fields = getattr(s, 'get_all_fields', lambda: {})() + fields = getattr(s, "get_all_fields", lambda: {})() - if field_name == 'pk': - rewritten.append('pk') + if field_name == "pk": + rewritten.append("pk") continue if field_name not in fields: - raise ValidationError( - "Invalid filter field: %s" % field_name - ) + raise ValidationError("Invalid filter field: %s" % field_name) field = fields[field_name] @@ -126,18 +120,20 @@ def generate_query_key(self, serializer): break # Recurse into nested field - s = getattr(field, 'serializer', None) + s = getattr(field, "serializer", None) if isinstance(s, serializers.ListSerializer): s = s.child + # Handle the field when it's a JSONField + elif isinstance(field, JSONField): + s = serializer + jsonfield_recurse = True if not s: - raise ValidationError( - "Invalid nested filter field: %s" % field_name - ) + raise ValidationError("Invalid nested filter field: %s" % field_name) if self.operator: rewritten.append(self.operator) - return ('__'.join(rewritten), field) + return ("__".join(rewritten), field) class DynamicFilterBackend(BaseFilterBackend): @@ -152,28 +148,28 @@ class DynamicFilterBackend(BaseFilterBackend): """ VALID_FILTER_OPERATORS = ( - 'in', - 'any', - 'all', - 'icontains', - 'contains', - 'startswith', - 'istartswith', - 'endswith', - 'iendswith', - 'year', - 'month', - 'day', - 'week_day', - 'regex', - 'range', - 'gt', - 'lt', - 'gte', - 'lte', - 'isnull', - 'eq', - 'iexact', + "in", + "any", + "all", + "icontains", + "contains", + "startswith", + "istartswith", + "endswith", + "iendswith", + "year", + "month", + "day", + "week_day", + "regex", + "range", + "gt", + "lt", + "gte", + "lte", + "isnull", + "eq", + "iexact", None, ) @@ -198,15 +194,14 @@ def filter_queryset(self, request, queryset, view): self.DEBUG = settings.DEBUG return self._build_queryset( - queryset=queryset, - extra_filters=extra_filters, - disable_prefetches=disable_prefetches, + queryset=queryset, extra_filters=extra_filters, disable_prefetches=disable_prefetches, ) """ This function was renamed and broke downstream dependencies that haven't been updated to use the new naming convention. """ + def _extract_filters(self, **kwargs): return self._get_requested_filters(**kwargs) @@ -221,30 +216,27 @@ def _get_requested_filters(self, **kwargs): """ - filters_map = ( - kwargs.get('filters_map') or - self.view.get_request_feature(self.view.FILTER) - ) + filters_map = kwargs.get("filters_map") or self.view.get_request_feature(self.view.FILTER) out = TreeMap() for spec, value in six.iteritems(filters_map): # Inclusion or exclusion? - if spec[0] == '-': + if spec[0] == "-": spec = spec[1:] - inex = '_exclude' + inex = "_exclude" else: - inex = '_include' + inex = "_include" # for relational filters, separate out relation path part - if '|' in spec: - rel, spec = spec.split('|') - rel = rel.split('.') + if "|" in spec: + rel, spec = spec.split("|") + rel = rel.split(".") else: rel = None - parts = spec.split('.') + parts = spec.split(".") # Last part could be operator, e.g. "events.capacity.gte" if len(parts) > 1 and parts[-1] in self.VALID_FILTER_OPERATORS: @@ -253,19 +245,16 @@ def _get_requested_filters(self, **kwargs): operator = None # All operators except 'range' and 'in' should have one value - if operator == 'range': + if operator == "range": value = value[:2] - elif operator == 'in': + elif operator == "in": # no-op: i.e. accept `value` as an arbitrarily long list pass elif operator in self.VALID_FILTER_OPERATORS: value = value[0] - if ( - operator == 'isnull' and - isinstance(value, six.string_types) - ): + if operator == "isnull" and isinstance(value, six.string_types): value = is_truthy(value) - elif operator == 'eq': + elif operator == "eq": operator = None node = FilterNode(parts, operator, value) @@ -325,12 +314,7 @@ def rewrite_filters(filters, serializer): def _create_prefetch(self, source, queryset): return Prefetch(source, queryset=queryset) - def _build_implicit_prefetches( - self, - model, - prefetches, - requirements - ): + def _build_implicit_prefetches(self, model, prefetches, requirements): """Build a prefetch dictionary based on internal requirements.""" for source, remainder in six.iteritems(requirements): @@ -341,16 +325,12 @@ def _build_implicit_prefetches( related_field = get_model_field(model, source) related_model = get_related_model(related_field) - queryset = self._build_implicit_queryset( - related_model, - remainder - ) if related_model else None - - prefetches[source] = self._create_prefetch( - source, - queryset + queryset = ( + self._build_implicit_queryset(related_model, remainder) if related_model else None ) + prefetches[source] = self._create_prefetch(source, queryset) + return prefetches def _make_model_queryset(self, model): @@ -361,25 +341,14 @@ def _build_implicit_queryset(self, model, requirements): queryset = self._make_model_queryset(model) prefetches = {} - self._build_implicit_prefetches( - model, - prefetches, - requirements - ) + self._build_implicit_prefetches(model, prefetches, requirements) prefetch = prefetches.values() queryset = queryset.prefetch_related(*prefetch).distinct() if self.DEBUG: queryset._using_prefetches = prefetches return queryset - def _build_requested_prefetches( - self, - prefetches, - requirements, - model, - fields, - filters - ): + def _build_requested_prefetches(self, prefetches, requirements, model, fields, filters): """Build a prefetch dictionary based on request requirements.""" for name, field in six.iteritems(fields): @@ -392,22 +361,19 @@ def _build_requested_prefetches( continue source = field.source or name - if '.' in source: - raise ValidationError( - 'nested relationship values ' - 'are not supported' - ) + if "." in source: + raise ValidationError("nested relationship values " "are not supported") if source in prefetches: # ignore duplicated sources continue is_remote = is_field_remote(model, source) - is_id_only = getattr(field, 'id_only', lambda: False)() + is_id_only = getattr(field, "id_only", lambda: False)() if is_id_only and not is_remote: continue - related_queryset = getattr(original_field, 'queryset', None) + related_queryset = getattr(original_field, "queryset", None) if callable(related_queryset): related_queryset = related_queryset(field) @@ -422,7 +388,7 @@ def _build_requested_prefetches( serializer=field, filters=filters.get(name, {}), queryset=related_queryset, - requirements=required + requirements=required, ) # Note: There can only be one prefetch per source, even @@ -430,34 +396,27 @@ def _build_requested_prefetches( # the same source. This could break in some cases, # but is mostly an issue on writes when we use all # fields by default. - prefetches[source] = self._create_prefetch( - source, - prefetch_queryset - ) + prefetches[source] = self._create_prefetch(source, prefetch_queryset) return prefetches - def _get_implicit_requirements( - self, - fields, - requirements - ): + def _get_implicit_requirements(self, fields, requirements): """Extract internal prefetch requirements from serializer fields.""" for name, field in six.iteritems(fields): source = field.source # Requires may be manually set on the field -- if not, # assume the field requires only its source. - requires = getattr(field, 'requires', None) or [source] + requires = getattr(field, "requires", None) or [source] for require in requires: if not require: # ignore fields with empty source continue - requirement = require.split('.') - if requirement[-1] == '': + requirement = require.split(".") + if requirement[-1] == "": # Change 'a.b.' -> 'a.b.*', # supporting 'a.b.' for backwards compatibility. - requirement[-1] = '*' + requirement[-1] = "*" requirements.insert(requirement, TreeMap(), update=True) def _get_queryset(self, queryset=None, serializer=None): @@ -499,7 +458,7 @@ def _build_queryset( queryset = self._get_queryset(queryset=queryset, serializer=serializer) - model = getattr(serializer.Meta, 'model', None) + model = getattr(serializer.Meta, "model", None) if not model: return queryset @@ -513,10 +472,7 @@ def _build_queryset( if requirements is None: requirements = TreeMap() - self._get_implicit_requirements( - fields, - requirements - ) + self._get_implicit_requirements(fields, requirements) # Implicit requirements (i.e. via `requires`) can potentially # include fields that haven't been explicitly included. @@ -524,55 +480,38 @@ def _build_queryset( implicitly_included = set(requirements.keys()) - set(fields.keys()) if implicitly_included: all_fields = serializer.get_all_fields() - fields.update({ - field: all_fields[field] - for field in implicitly_included - if field in all_fields - }) + fields.update( + {field: all_fields[field] for field in implicitly_included if field in all_fields} + ) if filters is None: filters = self._get_requested_filters() # build nested Prefetch queryset - self._build_requested_prefetches( - prefetches, - requirements, - model, - fields, - filters - ) + self._build_requested_prefetches(prefetches, requirements, model, fields, filters) # build remaining prefetches out of internal requirements # that are not already covered by request requirements - self._build_implicit_prefetches( - model, - prefetches, - requirements - ) + self._build_implicit_prefetches(model, prefetches, requirements) # use requirements at this level to limit fields selected # only do this for GET requests where we are not requesting the # entire fieldset - if ( - '*' not in requirements and - not self.view.is_update() and - not self.view.is_delete() - ): - id_fields = getattr(serializer, 'get_id_fields', lambda: [])() + if "*" not in requirements and not self.view.is_update() and not self.view.is_delete(): + id_fields = getattr(serializer, "get_id_fields", lambda: [])() # only include local model fields only = [ - field for field in set( - id_fields + list(requirements.keys()) - ) if is_model_field(model, field) and - not is_field_remote(model, field) + field + for field in set(id_fields + list(requirements.keys())) + if is_model_field(model, field) and not is_field_remote(model, field) ] queryset = queryset.only(*only) # add request filters query = self._filters_to_query( - includes=filters.get('_include'), - excludes=filters.get('_exclude'), - serializer=serializer + includes=filters.get("_include"), + excludes=filters.get("_exclude"), + serializer=serializer, ) # add additional filters specified by calling view @@ -586,13 +525,11 @@ def _build_queryset( try: queryset = queryset.filter(query) except InternalValidationError as e: - raise ValidationError( - dict(e) if hasattr(e, 'error_dict') else list(e) - ) + raise ValidationError(dict(e) if hasattr(e, "error_dict") else list(e)) except Exception as e: # Some other Django error in parsing the filter. # Very likely a bad query, so throw a ValidationError. - err_msg = getattr(e, 'message', '') + err_msg = getattr(e, "message", "") raise ValidationError(err_msg) # A serializer can have this optional function @@ -601,11 +538,8 @@ def _build_queryset( # You could use this to have (for example) different # serializers for different subsets of a model or to # implement permissions which work even in sideloads - if hasattr(serializer, 'filter_queryset'): - queryset = self._serializer_filter( - serializer=serializer, - queryset=queryset - ) + if hasattr(serializer, "filter_queryset"): + queryset = self._serializer_filter(serializer=serializer, queryset=queryset) # add prefetches and remove duplicates if necessary prefetch = prefetches.values() @@ -627,8 +561,7 @@ def _create_prefetch(self, source, queryset): def _get_queryset(self, queryset=None, serializer=None): queryset = super(FastDynamicFilterBackend, self)._get_queryset( - queryset=queryset, - serializer=serializer + queryset=queryset, serializer=serializer ) if not isinstance(queryset, FastQuery): @@ -637,15 +570,11 @@ def _get_queryset(self, queryset=None, serializer=None): return queryset def _make_model_queryset(self, model): - queryset = super(FastDynamicFilterBackend, self)._make_model_queryset( - model - ) + queryset = super(FastDynamicFilterBackend, self)._make_model_queryset(model) return FastQuery(queryset) def _serializer_filter(self, serializer=None, queryset=None): - queryset.queryset = serializer.filter_queryset( - queryset.queryset - ) + queryset.queryset = serializer.filter_queryset(queryset.queryset) return queryset @@ -668,7 +597,7 @@ def filter_queryset(self, request, queryset, view): ordering = self.get_ordering(request, queryset, view) if ordering: queryset = queryset.order_by(*ordering) - if any(['__' in o for o in ordering]): + if any(["__" in o for o in ordering]): # add distinct() to remove duplicates # in case of order-by-related queryset = queryset.distinct() @@ -683,16 +612,12 @@ def get_ordering(self, request, queryset, view): params = view.get_request_feature(view.SORT) if params: fields = [param.strip() for param in params] - valid_ordering, invalid_ordering = self.remove_invalid_fields( - queryset, fields, view - ) + valid_ordering, invalid_ordering = self.remove_invalid_fields(queryset, fields, view) # if any of the sort fields are invalid, throw an error. # else return the ordering if invalid_ordering: - raise ValidationError( - "Invalid filter field: %s" % invalid_ordering - ) + raise ValidationError("Invalid filter field: %s" % invalid_ordering) else: return valid_ordering @@ -712,9 +637,9 @@ def remove_invalid_fields(self, queryset, fields, view): # for each field sent down from the query param, # determine if its valid or invalid for term in fields: - stripped_term = term.lstrip('-') + stripped_term = term.lstrip("-") # add back the '-' add the end if necessary - reverse_sort_term = '' if len(stripped_term) is len(term) else '-' + reverse_sort_term = "" if len(stripped_term) is len(term) else "-" ordering = self.ordering_for(stripped_term, view) if ordering: @@ -735,15 +660,14 @@ def ordering_for(self, term, view): return None serializer = self._get_serializer_class(view)() - serializer_chain = term.split('.') + serializer_chain = term.split(".") model_chain = [] for segment in serializer_chain[:-1]: field = serializer.get_all_fields().get(segment) - if not (field and field.source != '*' and - isinstance(field, DynamicRelationField)): + if not (field and field.source != "*" and isinstance(field, DynamicRelationField)): return None model_chain.append(field.source or segment) @@ -753,22 +677,22 @@ def ordering_for(self, term, view): last_segment = serializer_chain[-1] last_field = serializer.get_all_fields().get(last_segment) - if not last_field or last_field.source == '*': + if not last_field or last_field.source == "*": return None model_chain.append(last_field.source or last_segment) - return '__'.join(model_chain) + return "__".join(model_chain) def _is_allowed_term(self, term, view): - valid_fields = getattr(view, 'ordering_fields', self.ordering_fields) - all_fields_allowed = valid_fields is None or valid_fields == '__all__' + valid_fields = getattr(view, "ordering_fields", self.ordering_fields) + all_fields_allowed = valid_fields is None or valid_fields == "__all__" return all_fields_allowed or term in valid_fields def _get_serializer_class(self, view): # prefer the overriding method - if hasattr(view, 'get_serializer_class'): + if hasattr(view, "get_serializer_class"): try: serializer_class = view.get_serializer_class() except AssertionError: @@ -777,7 +701,7 @@ def _get_serializer_class(self, view): serializer_class = None # use the attribute else: - serializer_class = getattr(view, 'serializer_class', None) + serializer_class = getattr(view, "serializer_class", None) # neither a method nor an attribute has been specified if serializer_class is None: diff --git a/tests/models.py b/tests/models.py index f2b495f5..da5d36cd 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,32 +1,22 @@ from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.models import ContentType +from django.contrib.postgres.fields import JSONField from django.db import models class User(models.Model): name = models.TextField() last_name = models.TextField() - groups = models.ManyToManyField('Group', related_name='users') - permissions = models.ManyToManyField('Permission', related_name='users') + groups = models.ManyToManyField("Group", related_name="users") + permissions = models.ManyToManyField("Permission", related_name="users") date_of_birth = models.DateField(null=True, blank=True) # 'related_name' intentionally left unset in location field below: - location = models.ForeignKey( - 'Location', - null=True, - blank=True, - on_delete=models.CASCADE - ) + location = models.ForeignKey("Location", null=True, blank=True, on_delete=models.CASCADE) favorite_pet_type = models.ForeignKey( - ContentType, - null=True, - blank=True, - on_delete=models.CASCADE + ContentType, null=True, blank=True, on_delete=models.CASCADE ) favorite_pet_id = models.TextField(null=True, blank=True) - favorite_pet = GenericForeignKey( - 'favorite_pet_type', - 'favorite_pet_id', - ) + favorite_pet = GenericForeignKey("favorite_pet_type", "favorite_pet_id",) is_dead = models.NullBooleanField(default=False) @@ -38,23 +28,15 @@ class Profile(models.Model): class Cat(models.Model): name = models.TextField() - home = models.ForeignKey('Location', on_delete=models.CASCADE) + home = models.ForeignKey("Location", on_delete=models.CASCADE) backup_home = models.ForeignKey( - 'Location', - related_name='friendly_cats', - on_delete=models.CASCADE + "Location", related_name="friendly_cats", on_delete=models.CASCADE ) hunting_grounds = models.ManyToManyField( - 'Location', - related_name='annoying_cats', - related_query_name='getoffmylawn' + "Location", related_name="annoying_cats", related_query_name="getoffmylawn" ) parent = models.ForeignKey( - 'Cat', - null=True, - blank=True, - related_name='kittens', - on_delete=models.CASCADE + "Cat", null=True, blank=True, related_name="kittens", on_delete=models.CASCADE ) @@ -76,7 +58,7 @@ class Zebra(models.Model): class Group(models.Model): name = models.TextField(unique=True) - permissions = models.ManyToManyField('Permission', related_name='groups') + permissions = models.ManyToManyField("Permission", related_name="groups") class Permission(models.Model): @@ -94,15 +76,11 @@ class Event(models.Model): Event model -- Intentionally missing serializer and viewset, so they can be added as part of a codelab. """ + name = models.TextField() status = models.TextField(default="current") - location = models.ForeignKey( - 'Location', - null=True, - blank=True, - on_delete=models.CASCADE - ) - users = models.ManyToManyField('User') + location = models.ForeignKey("Location", null=True, blank=True, on_delete=models.CASCADE) + users = models.ManyToManyField("User") class A(models.Model): @@ -110,12 +88,12 @@ class A(models.Model): class B(models.Model): - a = models.OneToOneField('A', related_name='b', on_delete=models.CASCADE) + a = models.OneToOneField("A", related_name="b", on_delete=models.CASCADE) class C(models.Model): - b = models.ForeignKey('B', related_name='cs', on_delete=models.CASCADE) - d = models.ForeignKey('D', on_delete=models.CASCADE) + b = models.ForeignKey("B", related_name="cs", on_delete=models.CASCADE) + d = models.ForeignKey("D", on_delete=models.CASCADE) class D(models.Model): @@ -136,3 +114,8 @@ class Part(models.Model): car = models.ForeignKey(Car, on_delete=models.CASCADE) name = models.CharField(max_length=60) country = models.ForeignKey(Country, on_delete=models.CASCADE) + + +class JsonFieldModel(models.Model): + name = models.CharField(max_length=60) + some_jsonfield = JSONField(default=dict) diff --git a/tests/setup.py b/tests/setup.py index c7e69abb..2ba9e5a5 100644 --- a/tests/setup.py +++ b/tests/setup.py @@ -8,12 +8,13 @@ Event, Group, Horse, + JsonFieldModel, Location, Part, Permission, User, - Zebra - ) + Zebra, +) def create_fixture(): @@ -25,31 +26,46 @@ def create_fixture(): # Create 4 dogs. # Create 2 Country # Create 1 Car has 2 Parts each from different Country + # Create 1 JsonFieldModel types = [ - 'users', 'groups', 'locations', 'permissions', - 'events', 'cats', 'dogs', 'horses', 'zebras', - 'cars', 'countries', 'parts', + "users", + "groups", + "locations", + "permissions", + "events", + "cats", + "dogs", + "horses", + "zebras", + "cars", + "countries", + "parts", + "json_field_models", ] - Fixture = namedtuple('Fixture', types) + Fixture = namedtuple("Fixture", types) fixture = Fixture( - users=[], groups=[], locations=[], permissions=[], - events=[], cats=[], dogs=[], horses=[], zebras=[], - cars=[], countries=[], parts=[] + users=[], + groups=[], + locations=[], + permissions=[], + events=[], + cats=[], + dogs=[], + horses=[], + zebras=[], + cars=[], + countries=[], + parts=[], + json_field_models=[], ) for i in range(0, 4): - fixture.users.append( - User.objects.create( - name=str(i), - last_name=str(i))) + fixture.users.append(User.objects.create(name=str(i), last_name=str(i))) for i in range(0, 4): - fixture.permissions.append( - Permission.objects.create( - name=str(i), - code=i)) + fixture.permissions.append(Permission.objects.create(name=str(i), code=i)) for i in range(0, 2): fixture.groups.append(Group.objects.create(name=str(i))) @@ -58,97 +74,63 @@ def create_fixture(): fixture.locations.append(Location.objects.create(name=str(i))) for i in range(0, 2): - fixture.cats.append(Cat.objects.create( - name=str(i), - home_id=fixture.locations[i].id, - backup_home_id=( - fixture.locations[len(fixture.locations) - 1 - i].id))) - - dogs = [{ - 'name': 'Clifford', - 'fur_color': 'red', - 'origin': 'Clifford the big red dog' - }, { - 'name': 'Air-Bud', - 'fur_color': 'gold', - 'origin': 'Air Bud 4: Seventh Inning Fetch' - }, { - 'name': 'Spike', - 'fur_color': 'brown', - 'origin': 'Rugrats' - }, { - 'name': 'Pluto', - 'fur_color': 'brown and white', - 'origin': 'Mickey Mouse' - }, { - 'name': 'Spike', - 'fur_color': 'light-brown', - 'origin': 'Tom and Jerry' - }] - - horses = [{ - 'name': 'Seabiscuit', - 'origin': 'LA' - }, { - 'name': 'Secretariat', - 'origin': 'Kentucky' - }] - - zebras = [{ - 'name': 'Ralph', - 'origin': 'new york' - }, { - 'name': 'Ted', - 'origin': 'africa' - }] - - events = [{ - 'name': 'Event 1', - 'status': 'archived', - 'location': 2 - }, { - 'name': 'Event 2', - 'status': 'current', - 'location': 1 - }, { - 'name': 'Event 3', - 'status': 'current', - 'location': 1 - }, { - 'name': 'Event 4', - 'status': 'archived', - 'location': 2 - }, { - 'name': 'Event 5', - 'status': 'current', - 'location': 2 - }] + fixture.cats.append( + Cat.objects.create( + name=str(i), + home_id=fixture.locations[i].id, + backup_home_id=(fixture.locations[len(fixture.locations) - 1 - i].id), + ) + ) + + fixture.json_field_models.append( + JsonFieldModel.objects.create( + name=str(i), some_jsonfield={"value": "string value for icontains testing"} + ) + ) + + dogs = [ + {"name": "Clifford", "fur_color": "red", "origin": "Clifford the big red dog"}, + {"name": "Air-Bud", "fur_color": "gold", "origin": "Air Bud 4: Seventh Inning Fetch"}, + {"name": "Spike", "fur_color": "brown", "origin": "Rugrats"}, + {"name": "Pluto", "fur_color": "brown and white", "origin": "Mickey Mouse"}, + {"name": "Spike", "fur_color": "light-brown", "origin": "Tom and Jerry"}, + ] + + horses = [{"name": "Seabiscuit", "origin": "LA"}, {"name": "Secretariat", "origin": "Kentucky"}] + + zebras = [{"name": "Ralph", "origin": "new york"}, {"name": "Ted", "origin": "africa"}] + + events = [ + {"name": "Event 1", "status": "archived", "location": 2}, + {"name": "Event 2", "status": "current", "location": 1}, + {"name": "Event 3", "status": "current", "location": 1}, + {"name": "Event 4", "status": "archived", "location": 2}, + {"name": "Event 5", "status": "current", "location": 2}, + ] for dog in dogs: - fixture.dogs.append(Dog.objects.create( - name=dog.get('name'), - fur_color=dog.get('fur_color'), - origin=dog.get('origin') - )) + fixture.dogs.append( + Dog.objects.create( + name=dog.get("name"), fur_color=dog.get("fur_color"), origin=dog.get("origin") + ) + ) for horse in horses: - fixture.horses.append(Horse.objects.create( - name=horse.get('name'), - origin=horse.get('origin') - )) + fixture.horses.append( + Horse.objects.create(name=horse.get("name"), origin=horse.get("origin")) + ) for zebra in zebras: - fixture.zebras.append(Zebra.objects.create( - name=zebra.get('name'), - origin=zebra.get('origin') - )) + fixture.zebras.append( + Zebra.objects.create(name=zebra.get("name"), origin=zebra.get("origin")) + ) for event in events: - fixture.events.append(Event.objects.create( - name=event['name'], - status=event['status'], - location_id=event['location'] - )) + fixture.events.append( + Event.objects.create( + name=event["name"], status=event["status"], location_id=event["location"] + ) + ) fixture.events[1].users.add(fixture.users[0]) fixture.events[1].users.add(fixture.users[1]) fixture.events[2].users.add(fixture.users[0]) @@ -158,7 +140,7 @@ def create_fixture(): fixture.events[4].users.add(fixture.users[1]) fixture.events[4].users.add(fixture.users[2]) - fixture.locations[0].blob = 'here' + fixture.locations[0].blob = "here" fixture.locations[0].save() fixture.users[0].location = fixture.locations[0] @@ -192,47 +174,30 @@ def create_fixture(): fixture.groups[0].permissions.add(fixture.permissions[0]) fixture.groups[1].permissions.add(fixture.permissions[1]) - countries = [{ - 'id': 1, - 'name': 'United States', - 'short_name': 'US', - }, { - 'id': 2, - 'name': 'China', - 'short_name': 'CN', - }] - - cars = [{ - 'id': 1, - 'name': 'Porshe', - 'country': 1 - }] - - parts = [{ - 'car': 1, - 'name': 'wheel', - 'country': 1 - }, { - 'car': 1, - 'name': 'tire', - 'country': 2 - }] + countries = [ + {"id": 1, "name": "United States", "short_name": "US",}, + {"id": 2, "name": "China", "short_name": "CN",}, + ] + + cars = [{"id": 1, "name": "Porshe", "country": 1}] + + parts = [{"car": 1, "name": "wheel", "country": 1}, {"car": 1, "name": "tire", "country": 2}] for country in countries: fixture.countries.append(Country.objects.create(**country)) for car in cars: - fixture.cars.append(Car.objects.create( - id=car.get('id'), - name=car.get('name'), - country_id=car.get('country') - )) + fixture.cars.append( + Car.objects.create( + id=car.get("id"), name=car.get("name"), country_id=car.get("country") + ) + ) for part in parts: - fixture.parts.append(Part.objects.create( - car_id=part.get('car'), - name=part.get('name'), - country_id=part.get('country') - )) + fixture.parts.append( + Part.objects.create( + car_id=part.get("car"), name=part.get("name"), country_id=part.get("country") + ) + ) return fixture diff --git a/tests/test_generic.py b/tests/test_generic.py index 05e74043..ab0b402f 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -1,16 +1,14 @@ import json -from rest_framework.test import APITestCase - from dynamic_rest.fields import DynamicGenericRelationField from dynamic_rest.routers import DynamicRouter -from tests.models import User, Zebra +from rest_framework.test import APITestCase +from tests.models import JsonFieldModel, User, Zebra from tests.serializers import UserSerializer from tests.setup import create_fixture class TestGenericRelationFieldAPI(APITestCase): - def setUp(self): self.fixture = create_fixture() f = self.fixture @@ -34,61 +32,47 @@ def test_id_only(self): } ``` """ - url = ( - '/users/?include[]=favorite_pet' - '&filter{favorite_pet_id.isnull}=false' - ) + url = "/users/?include[]=favorite_pet" "&filter{favorite_pet_id.isnull}=false" response = self.client.get(url) self.assertEqual(200, response.status_code) - content = json.loads(response.content.decode('utf-8')) - self.assertTrue( - all( - [_['favorite_pet'] for _ in content['users']] - ) - ) - self.assertFalse('cats' in content) - self.assertFalse('dogs' in content) - self.assertTrue('type' in content['users'][0]['favorite_pet']) - self.assertTrue('id' in content['users'][0]['favorite_pet']) + content = json.loads(response.content.decode("utf-8")) + self.assertTrue(all([_["favorite_pet"] for _ in content["users"]])) + self.assertFalse("cats" in content) + self.assertFalse("dogs" in content) + self.assertTrue("type" in content["users"][0]["favorite_pet"]) + self.assertTrue("id" in content["users"][0]["favorite_pet"]) def test_sideload(self): - url = ( - '/users/?include[]=favorite_pet.' - '&filter{favorite_pet_id.isnull}=false' - ) + url = "/users/?include[]=favorite_pet." "&filter{favorite_pet_id.isnull}=false" response = self.client.get(url) self.assertEqual(200, response.status_code) - content = json.loads(response.content.decode('utf-8')) - self.assertTrue( - all( - [_['favorite_pet'] for _ in content['users']] - ) - ) - self.assertTrue('cats' in content) - self.assertEqual(2, len(content['cats'])) - self.assertTrue('dogs' in content) - self.assertEqual(1, len(content['dogs'])) - self.assertTrue('type' in content['users'][0]['favorite_pet']) - self.assertTrue('id' in content['users'][0]['favorite_pet']) + content = json.loads(response.content.decode("utf-8")) + self.assertTrue(all([_["favorite_pet"] for _ in content["users"]])) + self.assertTrue("cats" in content) + self.assertEqual(2, len(content["cats"])) + self.assertTrue("dogs" in content) + self.assertEqual(1, len(content["dogs"])) + self.assertTrue("type" in content["users"][0]["favorite_pet"]) + self.assertTrue("id" in content["users"][0]["favorite_pet"]) def test_multi_sideload_include(self): url = ( - '/cars/1/?include[]=name&include[]=country.short_name' - '&include[]=parts.name&include[]=parts.country.name' + "/cars/1/?include[]=name&include[]=country.short_name" + "&include[]=parts.name&include[]=parts.country.name" ) response = self.client.get(url) self.assertEqual(200, response.status_code) - content = json.loads(response.content.decode('utf-8')) - self.assertTrue('countries' in content) + content = json.loads(response.content.decode("utf-8")) + self.assertTrue("countries" in content) country = None - for _ in content['countries']: - if _['id'] == 1: + for _ in content["countries"]: + if _["id"] == 1: country = _ self.assertTrue(country) - self.assertTrue('short_name' in country) - self.assertTrue('name' in country) + self.assertTrue("short_name" in country) + self.assertTrue("name" in country) def test_query_counts(self): # NOTE: Django doesn't seem to prefetch ContentType objects @@ -96,15 +80,12 @@ def test_query_counts(self): # this call could do 5 SQL queries if the Cat and Dog # ContentType objects haven't been cached. with self.assertNumQueries(3): - url = ( - '/users/?include[]=favorite_pet.' - '&filter{favorite_pet_id.isnull}=false' - ) + url = "/users/?include[]=favorite_pet." "&filter{favorite_pet_id.isnull}=false" response = self.client.get(url) self.assertEqual(200, response.status_code) with self.assertNumQueries(3): - url = '/users/?include[]=favorite_pet.' + url = "/users/?include[]=favorite_pet." response = self.client.get(url) self.assertEqual(200, response.status_code) @@ -113,10 +94,7 @@ def test_unknown_resource(self): which there is no known canonical serializer. """ - zork = Zebra.objects.create( - name='Zork', - origin='San Francisco Zoo' - ) + zork = Zebra.objects.create(name="Zork", origin="San Francisco Zoo") user = self.fixture.users[0] user.favorite_pet = zork @@ -124,26 +102,23 @@ def test_unknown_resource(self): self.assertIsNone(DynamicRouter.get_canonical_serializer(Zebra)) - url = '/users/%s/?include[]=favorite_pet' % user.pk + url = "/users/%s/?include[]=favorite_pet" % user.pk response = self.client.get(url) self.assertEqual(200, response.status_code) - content = json.loads(response.content.decode('utf-8')) - self.assertTrue('user' in content) - self.assertFalse('zebras' in content) # Not sideloaded - user_obj = content['user'] - self.assertTrue('favorite_pet' in user_obj) - self.assertEqual('Zebra', user_obj['favorite_pet']['type']) - self.assertEqual(zork.pk, user_obj['favorite_pet']['id']) + content = json.loads(response.content.decode("utf-8")) + self.assertTrue("user" in content) + self.assertFalse("zebras" in content) # Not sideloaded + user_obj = content["user"] + self.assertTrue("favorite_pet" in user_obj) + self.assertEqual("Zebra", user_obj["favorite_pet"]["type"]) + self.assertEqual(zork.pk, user_obj["favorite_pet"]["id"]) def test_dgrf_with_requires_raises(self): with self.assertRaises(Exception): - DynamicGenericRelationField(requires=['foo', 'bar']) + DynamicGenericRelationField(requires=["foo", "bar"]) def test_if_field_inclusion_then_error(self): - url = ( - '/users/?include[]=favorite_pet.name' - '&filter{favorite_pet_id.isnull}=false' - ) + url = "/users/?include[]=favorite_pet.name" "&filter{favorite_pet_id.isnull}=false" response = self.client.get(url) self.assertEqual(400, response.status_code) @@ -154,47 +129,45 @@ def test_patch_resource(self): """ user = self.fixture.users[0] - url = '/users/%s/?include[]=favorite_pet.' % user.pk + url = "/users/%s/?include[]=favorite_pet." % user.pk response = self.client.patch( url, - json.dumps({ - 'id': user.id, - 'favorite_pet': { - 'type': 'dog', - 'id': 1 - } - }), - content_type='application/json' + json.dumps({"id": user.id, "favorite_pet": {"type": "dog", "id": 1}}), + content_type="application/json", ) self.assertEqual(200, response.status_code) - content = json.loads(response.content.decode('utf-8')) - self.assertTrue('user' in content) - self.assertFalse('cats' in content) - self.assertTrue('dogs' in content) - self.assertEqual(1, content['dogs'][0]['id']) + content = json.loads(response.content.decode("utf-8")) + self.assertTrue("user" in content) + self.assertFalse("cats" in content) + self.assertTrue("dogs" in content) + self.assertEqual(1, content["dogs"][0]["id"]) def test_non_deferred_generic_field(self): class FooUserSerializer(UserSerializer): - class Meta: model = User - name = 'user' + name = "user" fields = ( - 'id', - 'favorite_pet', + "id", + "favorite_pet", ) - user = User.objects.filter( - favorite_pet_id__isnull=False - ).prefetch_related( - 'favorite_pet' - ).first() + user = ( + User.objects.filter(favorite_pet_id__isnull=False) + .prefetch_related("favorite_pet") + .first() + ) - data = FooUserSerializer(user, envelope=True).data['user'] + data = FooUserSerializer(user, envelope=True).data["user"] self.assertIsNotNone(data) - self.assertTrue('favorite_pet' in data) - self.assertTrue(isinstance(data['favorite_pet'], dict)) - self.assertEqual( - set(['id', 'type']), - set(data['favorite_pet'].keys()) - ) + self.assertTrue("favorite_pet" in data) + self.assertTrue(isinstance(data["favorite_pet"], dict)) + self.assertEqual(set(["id", "type"]), set(data["favorite_pet"].keys())) + + def test_jsonfield_filter_recurse(self): + """ + make the filter work with JSONField + """ + url = "/json_field_models/?&filter{some_jsonfield.value.icontains}=icontains" + response = self.client.get(url) + self.assertEqual(200, response.status_code)