diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..ab400f7b --- /dev/null +++ b/.flake8 @@ -0,0 +1,2 @@ +[flake8] +max-line-length=90 diff --git a/dynamic_rest/filters.py b/dynamic_rest/filters.py index 495c6d62..8ceffec4 100644 --- a/dynamic_rest/filters.py +++ b/dynamic_rest/filters.py @@ -4,6 +4,7 @@ from django.core.exceptions import ImproperlyConfigured from django.db.models import Q, Prefetch, Manager import six +from functools import reduce from rest_framework import serializers from rest_framework.exceptions import ValidationError from rest_framework.fields import BooleanField, NullBooleanField @@ -17,7 +18,7 @@ get_model_field, is_field_remote, is_model_field, - get_related_model + get_related_model, ) from dynamic_rest.patches import patch_prefetch_one_level from dynamic_rest.prefetch import FastQuery, FastPrefetch @@ -26,6 +27,14 @@ patch_prefetch_one_level() +def OR(a, b): + return a | b + + +def AND(a, b): + return a & b + + def has_joins(queryset): """Return True iff. a queryset includes joins. @@ -39,7 +48,6 @@ def has_joins(queryset): class FilterNode(object): - def __init__(self, field, operator, value): """Create an object representing a filter, to be stored in a TreeMap. @@ -67,7 +75,7 @@ def __init__(self, field, operator, value): def key(self): return '%s%s' % ( '__'.join(self.field), - '__' + self.operator if self.operator else '' + '__' + self.operator if self.operator else '', ) def generate_query_key(self, serializer): @@ -105,9 +113,7 @@ def generate_query_key(self, serializer): continue if field_name not in fields: - raise ValidationError( - "Invalid filter field: %s" % field_name - ) + raise ValidationError("Invalid filter field: %s" % field_name) field = fields[field_name] @@ -130,9 +136,7 @@ def generate_query_key(self, serializer): if isinstance(s, serializers.ListSerializer): s = s.child if not s: - raise ValidationError( - "Invalid nested filter field: %s" % field_name - ) + raise ValidationError("Invalid nested filter field: %s" % field_name) if self.operator: rewritten.append(self.operator) @@ -140,6 +144,38 @@ def generate_query_key(self, serializer): return ('__'.join(rewritten), field) +def rewrite_filters(fs, serializer): + out = {} + for node in fs.values(): + filter_key, field = node.generate_query_key(serializer) + if isinstance(field, (BooleanField, NullBooleanField)): + node.value = is_truthy(node.value) + out[filter_key] = node.value + + return out + + +def clause_to_q(clause, serializer): + key, value = clause + negate = False + q = {} + if key.startswith('-'): + negate = True + key = key[1:] + parts = key.split('.') + operator = 'eq' + if parts[-1] in DynamicFilterBackend.VALID_FILTER_OPERATORS: + operator = parts.pop() + if operator == 'eq': + operator = None + node = FilterNode(parts, operator, value) + key, _ = node.generate_query_key(serializer) + q = Q(**{key: node.value}) + if negate: + q = ~q + return q + + class DynamicFilterBackend(BaseFilterBackend): """A DRF filter backend that constructs DREST querysets. @@ -207,6 +243,7 @@ def filter_queryset(self, request, queryset, view): This function was renamed and broke downstream dependencies that haven't been updated to use the new naming convention. """ + def _extract_filters(self, **kwargs): return self._get_requested_filters(**kwargs) @@ -221,12 +258,12 @@ def _get_requested_filters(self, **kwargs): """ - filters_map = ( - kwargs.get('filters_map') or - self.view.get_request_feature(self.view.FILTER) - ) - out = TreeMap() + filters_map = kwargs.get('filters_map') or self.view.get_request_feature( + self.view.FILTER + ) + if getattr(self, 'view', None): + out['_complex'] = self.view.get_request_feature(self.view.FILTER, raw=True) for spec, value in six.iteritems(filters_map): @@ -260,10 +297,7 @@ def _get_requested_filters(self, **kwargs): pass elif operator in self.VALID_FILTER_OPERATORS: value = value[0] - if ( - operator == 'isnull' and - isinstance(value, six.string_types) - ): + if operator == 'isnull' and isinstance(value, six.string_types): value = is_truthy(value) elif operator == 'eq': operator = None @@ -277,7 +311,7 @@ def _get_requested_filters(self, **kwargs): return out - def _filters_to_query(self, includes, excludes, serializer, q=None): + def _filters_to_query(self, filters, serializer, q=None): """ Construct Django Query object from request. Arguments are dictionaries, which will be passed to Q() as kwargs. @@ -290,6 +324,7 @@ def _filters_to_query(self, includes, excludes, serializer, q=None): Arguments: includes: TreeMap representing inclusion filters. excludes: TreeMap representing exclusion filters. + filters: TreeMap with include/exclude filters OR query map serializer: serializer instance of top-level object q: Q() object (optional) @@ -298,39 +333,50 @@ def _filters_to_query(self, includes, excludes, serializer, q=None): were specified. """ - def rewrite_filters(filters, serializer): - out = {} - for k, node in six.iteritems(filters): - filter_key, field = node.generate_query_key(serializer) - if isinstance(field, (BooleanField, NullBooleanField)): - node.value = is_truthy(node.value) - out[filter_key] = node.value - - return out - - q = q or Q() + if ( + not filters.get('_complex') + ): + includes = filters.get('_include') + excludes = filters.get('_exclude') + q = q or Q() - if not includes and not excludes: - return None + if not includes and not excludes: + return None - if includes: - includes = rewrite_filters(includes, serializer) - q &= Q(**includes) - if excludes: - excludes = rewrite_filters(excludes, serializer) - for k, v in six.iteritems(excludes): - q &= ~Q(**{k: v}) - return q + if includes: + includes = rewrite_filters(includes, serializer) + q &= Q(**includes) + if excludes: + excludes = rewrite_filters(excludes, serializer) + for k, v in six.iteritems(excludes): + q &= ~Q(**{k: v}) + return q + else: + filters = filters.get('_complex') + ors = filters.get('.or') or filters.get('$or') + ands = filters.get('.and') or filters.get('$and') + if q is None: + q = Q() + if ors: + result = reduce( + OR, + [self._filters_to_query({"_complex": f}, serializer) for f in ors] + ) + return result + if ands: + return reduce( + AND, + [self._filters_to_query({"_complex": f}, serializer) for f in ands] + ) + clauses = [ + clause_to_q(clause, serializer) for clause in filters.items() + ] + return reduce(AND, clauses) if clauses else q def _create_prefetch(self, source, queryset): return Prefetch(source, queryset=queryset) - def _build_implicit_prefetches( - self, - model, - prefetches, - requirements - ): + def _build_implicit_prefetches(self, model, prefetches, requirements): """Build a prefetch dictionary based on internal requirements.""" for source, remainder in six.iteritems(requirements): @@ -341,16 +387,14 @@ def _build_implicit_prefetches( related_field = get_model_field(model, source) related_model = get_related_model(related_field) - queryset = self._build_implicit_queryset( - related_model, - remainder - ) if related_model else None - - prefetches[source] = self._create_prefetch( - source, - queryset + queryset = ( + self._build_implicit_queryset(related_model, remainder) + if related_model + else None ) + prefetches[source] = self._create_prefetch(source, queryset) + return prefetches def _make_model_queryset(self, model): @@ -361,11 +405,7 @@ def _build_implicit_queryset(self, model, requirements): queryset = self._make_model_queryset(model) prefetches = {} - self._build_implicit_prefetches( - model, - prefetches, - requirements - ) + self._build_implicit_prefetches(model, prefetches, requirements) prefetch = prefetches.values() queryset = queryset.prefetch_related(*prefetch).distinct() if self.DEBUG: @@ -373,12 +413,7 @@ def _build_implicit_queryset(self, model, requirements): return queryset def _build_requested_prefetches( - self, - prefetches, - requirements, - model, - fields, - filters + self, prefetches, requirements, model, fields, filters ): """Build a prefetch dictionary based on request requirements.""" @@ -393,10 +428,7 @@ def _build_requested_prefetches( source = field.source or name if '.' in source: - raise ValidationError( - 'nested relationship values ' - 'are not supported' - ) + raise ValidationError('nested relationship values are not supported') if source in prefetches: # ignore duplicated sources @@ -422,7 +454,7 @@ def _build_requested_prefetches( serializer=field, filters=filters.get(name, {}), queryset=related_queryset, - requirements=required + requirements=required, ) # Note: There can only be one prefetch per source, even @@ -430,18 +462,11 @@ def _build_requested_prefetches( # the same source. This could break in some cases, # but is mostly an issue on writes when we use all # fields by default. - prefetches[source] = self._create_prefetch( - source, - prefetch_queryset - ) + prefetches[source] = self._create_prefetch(source, prefetch_queryset) return prefetches - def _get_implicit_requirements( - self, - fields, - requirements - ): + def _get_implicit_requirements(self, fields, requirements): """Extract internal prefetch requirements from serializer fields.""" for name, field in six.iteritems(fields): source = field.source @@ -513,10 +538,7 @@ def _build_queryset( if requirements is None: requirements = TreeMap() - self._get_implicit_requirements( - fields, - requirements - ) + self._get_implicit_requirements(fields, requirements) # Implicit requirements (i.e. via `requires`) can potentially # include fields that haven't been explicitly included. @@ -524,56 +546,45 @@ def _build_queryset( implicitly_included = set(requirements.keys()) - set(fields.keys()) if implicitly_included: all_fields = serializer.get_all_fields() - fields.update({ - field: all_fields[field] - for field in implicitly_included - if field in all_fields - }) + fields.update( + { + field: all_fields[field] + for field in implicitly_included + if field in all_fields + } + ) if filters is None: filters = self._get_requested_filters() # build nested Prefetch queryset self._build_requested_prefetches( - prefetches, - requirements, - model, - fields, - filters + prefetches, requirements, model, fields, filters ) # build remaining prefetches out of internal requirements # that are not already covered by request requirements - self._build_implicit_prefetches( - model, - prefetches, - requirements - ) + self._build_implicit_prefetches(model, prefetches, requirements) # use requirements at this level to limit fields selected # only do this for GET requests where we are not requesting the # entire fieldset if ( - '*' not in requirements and - not self.view.is_update() and - not self.view.is_delete() + '*' not in requirements + and not self.view.is_update() + and not self.view.is_delete() ): id_fields = getattr(serializer, 'get_id_fields', lambda: [])() # only include local model fields only = [ - field for field in set( - id_fields + list(requirements.keys()) - ) if is_model_field(model, field) and - not is_field_remote(model, field) + field + for field in set(id_fields + list(requirements.keys())) + if is_model_field(model, field) and not is_field_remote(model, field) ] queryset = queryset.only(*only) # add request filters - query = self._filters_to_query( - includes=filters.get('_include'), - excludes=filters.get('_exclude'), - serializer=serializer - ) + query = self._filters_to_query(filters=filters, serializer=serializer) # add additional filters specified by calling view if extra_filters: @@ -586,9 +597,7 @@ def _build_queryset( try: queryset = queryset.filter(query) except InternalValidationError as e: - raise ValidationError( - dict(e) if hasattr(e, 'error_dict') else list(e) - ) + raise ValidationError(dict(e) if hasattr(e, 'error_dict') else list(e)) except Exception as e: # Some other Django error in parsing the filter. # Very likely a bad query, so throw a ValidationError. @@ -602,10 +611,7 @@ def _build_queryset( # serializers for different subsets of a model or to # implement permissions which work even in sideloads if hasattr(serializer, 'filter_queryset'): - queryset = self._serializer_filter( - serializer=serializer, - queryset=queryset - ) + queryset = self._serializer_filter(serializer=serializer, queryset=queryset) # add prefetches and remove duplicates if necessary prefetch = prefetches.values() @@ -627,8 +633,7 @@ def _create_prefetch(self, source, queryset): def _get_queryset(self, queryset=None, serializer=None): queryset = super(FastDynamicFilterBackend, self)._get_queryset( - queryset=queryset, - serializer=serializer + queryset=queryset, serializer=serializer ) if not isinstance(queryset, FastQuery): @@ -637,15 +642,11 @@ def _get_queryset(self, queryset=None, serializer=None): return queryset def _make_model_queryset(self, model): - queryset = super(FastDynamicFilterBackend, self)._make_model_queryset( - model - ) + queryset = super(FastDynamicFilterBackend, self)._make_model_queryset(model) return FastQuery(queryset) def _serializer_filter(self, serializer=None, queryset=None): - queryset.queryset = serializer.filter_queryset( - queryset.queryset - ) + queryset.queryset = serializer.filter_queryset(queryset.queryset) return queryset @@ -690,9 +691,7 @@ def get_ordering(self, request, queryset, view): # if any of the sort fields are invalid, throw an error. # else return the ordering if invalid_ordering: - raise ValidationError( - "Invalid filter field: %s" % invalid_ordering - ) + raise ValidationError("Invalid filter field: %s" % invalid_ordering) else: return valid_ordering @@ -742,8 +741,11 @@ def ordering_for(self, term, view): for segment in serializer_chain[:-1]: field = serializer.get_all_fields().get(segment) - if not (field and field.source != '*' and - isinstance(field, DynamicRelationField)): + if not ( + field + and field.source != '*' + and isinstance(field, DynamicRelationField) + ): return None model_chain.append(field.source or segment) diff --git a/dynamic_rest/viewsets.py b/dynamic_rest/viewsets.py index 7fbdb169..2781bed8 100644 --- a/dynamic_rest/viewsets.py +++ b/dynamic_rest/viewsets.py @@ -2,6 +2,7 @@ from django.core.exceptions import ObjectDoesNotExist from django.http import QueryDict import six +import json from django.db import transaction, IntegrityError from rest_framework import exceptions, status, viewsets from rest_framework.exceptions import ValidationError @@ -86,7 +87,7 @@ class WithDynamicViewSetMixin(object): PER_PAGE, SORT, SIDELOADING, - PATCH_ALL + PATCH_ALL, ) meta = None filter_backends = (DynamicFilterBackend, DynamicSortingFilter) @@ -153,7 +154,7 @@ def get_renderers(self): else: return renderers - def get_request_feature(self, name): + def get_request_feature(self, name, raw=False): """Parses the request for a particular feature. Arguments: @@ -169,22 +170,31 @@ def get_request_feature(self, name): elif '{}' in name: # object-type (keys are not consistent) return self._extract_object_params( - name) if name in self.features else {} + name, raw=raw) if name in self.features else {} else: # single-type return self.request.query_params.get( name) if name in self.features else None - def _extract_object_params(self, name): + def _extract_object_params(self, name, raw=False): """ Extract object params, return as dict """ - params = self.request.query_params.lists() params_map = {} + original_name = name prefix = name[:-1] offset = len(prefix) + for name, value in params: + name_match = name == original_name + if name_match: + if raw and value: + # filter{} as object + return json.loads(value[0]) + else: + continue + if name.startswith(prefix): if name.endswith('}'): name = name[offset:-1] @@ -202,7 +212,7 @@ def _extract_object_params(self, name): continue params_map[name] = value - return params_map + return params_map if not raw else None def get_queryset(self, queryset=None): """ diff --git a/setup.py b/setup.py index ab2fe906..d0fcb245 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ NAME = 'dynamic-rest' DESCRIPTION = 'Dynamic API support to Django REST Framework.' URL = 'http://github.com/AltSchool/dynamic-rest' -VERSION = '2.1.6' +VERSION = '2.1.7' SCRIPTS = ['manage.py'] setup( diff --git a/tests/test_api.py b/tests/test_api.py index 0a144577..147600f5 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -6,6 +6,7 @@ from django.test import override_settings import six from rest_framework.test import APITestCase +from urllib.parse import quote from tests.models import Cat, Group, Location, Permission, Profile, User from tests.serializers import NestedEphemeralSerializer, PermissionSerializer @@ -405,6 +406,33 @@ def test_get_with_filter_in(self): }, json.loads(response.content.decode('utf-8'))) + def test_get_with_complex_filter(self): + # same filter as the above case + f = { + ".or": [{ + "name": "1" + }, { + ".and": [{ + "name": "2" + }, { + "location": 2 + }] + }] + } + f = quote(json.dumps(f)) + url = f'/users/?filter{{}}={f}' + with self.assertNumQueries(1): + response = self.client.get(url) + self.assertEqual(200, response.status_code) + self.assertEqual( + { + 'users': [ + {'id': 2, 'location': 1, 'name': '1'}, + {'id': 3, 'location': 2, 'name': '2'}, + ] + }, + json.loads(response.content.decode('utf-8'))) + def test_get_with_filter_exclude(self): url = '/users/?filter{-name}=1' with self.assertNumQueries(1): diff --git a/tests/viewsets.py b/tests/viewsets.py index cd04c0dc..7433aa5b 100644 --- a/tests/viewsets.py +++ b/tests/viewsets.py @@ -33,7 +33,7 @@ class UserViewSet(DynamicModelViewSet): features = ( DynamicModelViewSet.INCLUDE, DynamicModelViewSet.EXCLUDE, DynamicModelViewSet.FILTER, DynamicModelViewSet.SORT, - DynamicModelViewSet.SIDELOADING, DynamicModelViewSet.DEBUG + DynamicModelViewSet.SIDELOADING, DynamicModelViewSet.DEBUG, ) model = User serializer_class = UserSerializer @@ -86,7 +86,7 @@ def create(self, request, *args, **kwargs): class GroupViewSet(DynamicModelViewSet): features = ( DynamicModelViewSet.INCLUDE, DynamicModelViewSet.EXCLUDE, - DynamicModelViewSet.FILTER, DynamicModelViewSet.SORT + DynamicModelViewSet.FILTER, DynamicModelViewSet.SORT, ) model = Group serializer_class = GroupSerializer @@ -97,7 +97,7 @@ class LocationViewSet(DynamicModelViewSet): features = ( DynamicModelViewSet.INCLUDE, DynamicModelViewSet.EXCLUDE, DynamicModelViewSet.FILTER, DynamicModelViewSet.SORT, - DynamicModelViewSet.DEBUG, DynamicModelViewSet.SIDELOADING + DynamicModelViewSet.DEBUG, DynamicModelViewSet.SIDELOADING, ) model = Location serializer_class = LocationSerializer @@ -135,7 +135,7 @@ class ProfileViewSet(DynamicModelViewSet): DynamicModelViewSet.EXCLUDE, DynamicModelViewSet.FILTER, DynamicModelViewSet.INCLUDE, - DynamicModelViewSet.SORT + DynamicModelViewSet.SORT, ) model = Profile serializer_class = ProfileSerializer