From db3bda91db0d4e15f94c0e6cb1bff61047ef24c6 Mon Sep 17 00:00:00 2001 From: Ryo Chijiiwa Date: Thu, 11 Jun 2015 12:38:16 -0700 Subject: [PATCH 1/2] fix filter rewrites for m2o/m2m fields --- dynamic_rest/filters.py | 9 ++++++++- tests/test_api.py | 7 +++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/dynamic_rest/filters.py b/dynamic_rest/filters.py index ae5bee50..2f0e85e2 100644 --- a/dynamic_rest/filters.py +++ b/dynamic_rest/filters.py @@ -53,9 +53,16 @@ def generate_query_key(self, serializer): field = fields[field_name] + # For remote fields, strip off '_set' for filtering. This is a + # weird Django inconsistency. + model_field_name = field.source or field_name + if model_field_name.endswith('_set') and field_is_remote( + s.get_model(), model_field_name): + model_field_name = model_field_name[:-4] + # If get_all_fields() was used above, field could be unbound, # and field.source would be None - rewritten.append(field.source or field_name) + rewritten.append(model_field_name) if i == last: break diff --git a/tests/test_api.py b/tests/test_api.py index b088bc44..5c4c920d 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -491,6 +491,13 @@ def testCreate(self): content = json.loads(response.content) self.assertEqual(content['location']['metadata'], data['metadata']) + def testFilterByUser(self): + url = '/locations/?filter{users}=1' + response = self.client.get(url) + self.assertEqual(200, response.status_code) + content = json.loads(response.content) + self.assertEqual(1, len(content['locations'])) + class TestUserLocationsAPI(APITestCase): """ From 3302dbc684d557c7387c65c5fa26d7a913d6b04c Mon Sep 17 00:00:00 2001 From: Ryo Chijiiwa Date: Thu, 11 Jun 2015 15:29:28 -0700 Subject: [PATCH 2/2] more robust filter rewrites on remote fields, more tests --- dynamic_rest/fields.py | 39 ++++++++++++++++++++++++++------------- dynamic_rest/filters.py | 11 +++++++---- tests/models.py | 11 +++++++++++ tests/serializers.py | 21 +++++++++++++++++++-- tests/test_api.py | 11 +++++++++++ 5 files changed, 74 insertions(+), 19 deletions(-) diff --git a/dynamic_rest/fields.py b/dynamic_rest/fields.py index 1c352a2a..c8d4badc 100644 --- a/dynamic_rest/fields.py +++ b/dynamic_rest/fields.py @@ -1,43 +1,52 @@ import importlib from itertools import chain +import os from rest_framework import fields from rest_framework.exceptions import ParseError, NotFound +from django.conf import settings from django.db.models.related import RelatedObject from django.db.models import ManyToManyField from dynamic_rest.bases import DynamicSerializerBase -def field_is_remote(model, field_name): +def get_model_field(model, field_name): """ - Helper function to determine whether model field is remote or not. - Remote fields are many-to-many or many-to-one. + Helper function to get model field, including related fields. """ - if not hasattr(model, '_meta'): - # ephemeral model with no metaclass - return False - meta = model._meta try: - model_field = meta.get_field_by_name(field_name)[0] - return isinstance(model_field, (ManyToManyField, RelatedObject)) + return meta.get_field_by_name(field_name)[0] except: - related_object_names = { - o.get_accessor_name() + related_objects = { + o.get_accessor_name(): o for o in chain( meta.get_all_related_objects(), meta.get_all_related_many_to_many_objects() ) } - if field_name in related_object_names: - return True + if field_name in related_objects: + return related_objects[field_name] else: raise AttributeError( '%s is not a valid field for %s' % (field_name, model) ) +def field_is_remote(model, field_name): + """ + Helper function to determine whether model field is remote or not. + Remote fields are many-to-many or many-to-one. + """ + if not hasattr(model, '_meta'): + # ephemeral model with no metaclass + return False + + model_field = get_model_field(model, field_name) + return isinstance(model_field, (ManyToManyField, RelatedObject)) + + class DynamicField(fields.Field): """ @@ -187,6 +196,10 @@ def to_representation(self, instance): return serializer.to_representation(related) except Exception as e: # Provide more context to help debug these cases + if getattr(settings, 'DEBUG', False) or os.environ.get( + 'DREST_DEBUG', False): + import traceback + traceback.print_exc() raise Exception( "Failed to serialize %s.%s: %s\nObj: %s" % (self.parent.__class__.__name__, diff --git a/dynamic_rest/filters.py b/dynamic_rest/filters.py index 2f0e85e2..89f09c2a 100644 --- a/dynamic_rest/filters.py +++ b/dynamic_rest/filters.py @@ -1,6 +1,9 @@ from django.db.models import Q, Prefetch +from django.db.models.related import RelatedObject from dynamic_rest.datastructures import TreeMap -from dynamic_rest.fields import DynamicRelationField, field_is_remote +from dynamic_rest.fields import ( + DynamicRelationField, field_is_remote, get_model_field +) from rest_framework.exceptions import ValidationError from rest_framework import serializers @@ -56,9 +59,9 @@ def generate_query_key(self, serializer): # For remote fields, strip off '_set' for filtering. This is a # weird Django inconsistency. model_field_name = field.source or field_name - if model_field_name.endswith('_set') and field_is_remote( - s.get_model(), model_field_name): - model_field_name = model_field_name[:-4] + model_field = get_model_field(s.get_model(), model_field_name) + if isinstance(model_field, RelatedObject): + model_field_name = model_field.field.related_query_name() # If get_all_fields() was used above, field could be unbound, # and field.source would be None diff --git a/tests/models.py b/tests/models.py index 165ecd9c..b2680199 100644 --- a/tests/models.py +++ b/tests/models.py @@ -11,6 +11,17 @@ class User(models.Model): location = models.ForeignKey('Location', null=True, blank=True) +class Cat(models.Model): + name = models.TextField() + home = models.ForeignKey('Location') + backup_home = models.ForeignKey('Location', related_name='friendly_cats') + hunting_grounds = models.ManyToManyField( + 'Location', + related_name='annoying_cats', + related_query_name='getoffmylawn' + ) + + class Group(models.Model): name = models.TextField() permissions = models.ManyToManyField('Permission', related_name='groups') diff --git a/tests/serializers.py b/tests/serializers.py index 988b10ca..0c51fb46 100644 --- a/tests/serializers.py +++ b/tests/serializers.py @@ -1,15 +1,26 @@ -from tests.models import Location, Permission, Group, User +from tests.models import Location, Permission, Group, User, Cat from dynamic_rest.serializers import DynamicModelSerializer from dynamic_rest.serializers import DynamicEphemeralSerializer from dynamic_rest.fields import DynamicRelationField, CountField, DynamicField +class CatSerializer(DynamicModelSerializer): + + class Meta: + model = Cat + name = 'cat' + deferred_fields = ('home', 'backup_home', 'hunting_grous') + + class LocationSerializer(DynamicModelSerializer): class Meta: model = Location name = 'location' - fields = ('id', 'name', 'users', 'user_count', 'address', 'metadata') + fields = ( + 'id', 'name', 'users', 'user_count', 'address', 'metadata', + 'cats', 'friendly_cats', 'bad_cats' + ) users = DynamicRelationField( 'UserSerializer', @@ -19,6 +30,12 @@ class Meta: user_count = CountField('users', required=False, deferred=True) address = DynamicField(source='blob', required=False, deferred=True) metadata = DynamicField(deferred=True, required=False) + cats = DynamicRelationField( + 'CatSerializer', source='cat_set', many=True, deferred=True) + friendly_cats = DynamicRelationField( + 'CatSerializer', many=True, deferred=True) + bad_cats = DynamicRelationField( + 'CatSerializer', source='annoying_cats', many=True, deferred=True) class PermissionSerializer(DynamicModelSerializer): diff --git a/tests/test_api.py b/tests/test_api.py index 5c4c920d..22e8827e 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -498,6 +498,17 @@ def testFilterByUser(self): content = json.loads(response.content) self.assertEqual(1, len(content['locations'])) + def testCatFilters(self): + """Tests various filter rewrite scenarios""" + urls = [ + '/locations/?filter{cats}=1', + '/locations/?filter{friendly_cats}=1', + '/locations/?filter{bad_cats}=1' + ] + for url in urls: + response = self.client.get(url) + self.assertEqual(200, response.status_code) + class TestUserLocationsAPI(APITestCase): """