Skip to content

Commit

Permalink
Merge pull request #27 from AltSchool/fix/filter-remote
Browse files Browse the repository at this point in the history
fix filter rewrites for m2o/m2m fields
  • Loading branch information
aleontiev committed Jun 11, 2015
2 parents 8ac131d + 3302dbc commit 85d7f91
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 17 deletions.
39 changes: 26 additions & 13 deletions dynamic_rest/fields.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,52 @@
import importlib
from itertools import chain
import os

from rest_framework import fields
from rest_framework.exceptions import ParseError, NotFound
from django.conf import settings
from django.db.models.related import RelatedObject
from django.db.models import ManyToManyField

from dynamic_rest.bases import DynamicSerializerBase


def field_is_remote(model, field_name):
def get_model_field(model, field_name):
"""
Helper function to determine whether model field is remote or not.
Remote fields are many-to-many or many-to-one.
Helper function to get model field, including related fields.
"""
if not hasattr(model, '_meta'):
# ephemeral model with no metaclass
return False

meta = model._meta
try:
model_field = meta.get_field_by_name(field_name)[0]
return isinstance(model_field, (ManyToManyField, RelatedObject))
return meta.get_field_by_name(field_name)[0]
except:
related_object_names = {
o.get_accessor_name()
related_objects = {
o.get_accessor_name(): o
for o in chain(
meta.get_all_related_objects(),
meta.get_all_related_many_to_many_objects()
)
}
if field_name in related_object_names:
return True
if field_name in related_objects:
return related_objects[field_name]
else:
raise AttributeError(
'%s is not a valid field for %s' % (field_name, model)
)


def field_is_remote(model, field_name):
"""
Helper function to determine whether model field is remote or not.
Remote fields are many-to-many or many-to-one.
"""
if not hasattr(model, '_meta'):
# ephemeral model with no metaclass
return False

model_field = get_model_field(model, field_name)
return isinstance(model_field, (ManyToManyField, RelatedObject))


class DynamicField(fields.Field):

"""
Expand Down Expand Up @@ -187,6 +196,10 @@ def to_representation(self, instance):
return serializer.to_representation(related)
except Exception as e:
# Provide more context to help debug these cases
if getattr(settings, 'DEBUG', False) or os.environ.get(
'DREST_DEBUG', False):
import traceback
traceback.print_exc()
raise Exception(
"Failed to serialize %s.%s: %s\nObj: %s" %
(self.parent.__class__.__name__,
Expand Down
14 changes: 12 additions & 2 deletions dynamic_rest/filters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from django.db.models import Q, Prefetch
from django.db.models.related import RelatedObject
from dynamic_rest.datastructures import TreeMap
from dynamic_rest.fields import DynamicRelationField, field_is_remote
from dynamic_rest.fields import (
DynamicRelationField, field_is_remote, get_model_field
)

from rest_framework.exceptions import ValidationError
from rest_framework import serializers
Expand Down Expand Up @@ -53,9 +56,16 @@ def generate_query_key(self, serializer):

field = fields[field_name]

# For remote fields, strip off '_set' for filtering. This is a
# weird Django inconsistency.
model_field_name = field.source or field_name
model_field = get_model_field(s.get_model(), model_field_name)
if isinstance(model_field, RelatedObject):
model_field_name = model_field.field.related_query_name()

# If get_all_fields() was used above, field could be unbound,
# and field.source would be None
rewritten.append(field.source or field_name)
rewritten.append(model_field_name)

if i == last:
break
Expand Down
11 changes: 11 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@ class User(models.Model):
location = models.ForeignKey('Location', null=True, blank=True)


class Cat(models.Model):
name = models.TextField()
home = models.ForeignKey('Location')
backup_home = models.ForeignKey('Location', related_name='friendly_cats')
hunting_grounds = models.ManyToManyField(
'Location',
related_name='annoying_cats',
related_query_name='getoffmylawn'
)


class Group(models.Model):
name = models.TextField()
permissions = models.ManyToManyField('Permission', related_name='groups')
Expand Down
21 changes: 19 additions & 2 deletions tests/serializers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
from tests.models import Location, Permission, Group, User
from tests.models import Location, Permission, Group, User, Cat
from dynamic_rest.serializers import DynamicModelSerializer
from dynamic_rest.serializers import DynamicEphemeralSerializer
from dynamic_rest.fields import DynamicRelationField, CountField, DynamicField


class CatSerializer(DynamicModelSerializer):

class Meta:
model = Cat
name = 'cat'
deferred_fields = ('home', 'backup_home', 'hunting_grous')


class LocationSerializer(DynamicModelSerializer):

class Meta:
model = Location
name = 'location'
fields = ('id', 'name', 'users', 'user_count', 'address', 'metadata')
fields = (
'id', 'name', 'users', 'user_count', 'address', 'metadata',
'cats', 'friendly_cats', 'bad_cats'
)

users = DynamicRelationField(
'UserSerializer',
Expand All @@ -19,6 +30,12 @@ class Meta:
user_count = CountField('users', required=False, deferred=True)
address = DynamicField(source='blob', required=False, deferred=True)
metadata = DynamicField(deferred=True, required=False)
cats = DynamicRelationField(
'CatSerializer', source='cat_set', many=True, deferred=True)
friendly_cats = DynamicRelationField(
'CatSerializer', many=True, deferred=True)
bad_cats = DynamicRelationField(
'CatSerializer', source='annoying_cats', many=True, deferred=True)


class PermissionSerializer(DynamicModelSerializer):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,24 @@ def testCreate(self):
content = json.loads(response.content)
self.assertEqual(content['location']['metadata'], data['metadata'])

def testFilterByUser(self):
url = '/locations/?filter{users}=1'
response = self.client.get(url)
self.assertEqual(200, response.status_code)
content = json.loads(response.content)
self.assertEqual(1, len(content['locations']))

def testCatFilters(self):
"""Tests various filter rewrite scenarios"""
urls = [
'/locations/?filter{cats}=1',
'/locations/?filter{friendly_cats}=1',
'/locations/?filter{bad_cats}=1'
]
for url in urls:
response = self.client.get(url)
self.assertEqual(200, response.status_code)


class TestUserLocationsAPI(APITestCase):
"""
Expand Down

0 comments on commit 85d7f91

Please sign in to comment.