Skip to content

Commit

Permalink
Add support for overriding fields in serializer subclasses.
Browse files Browse the repository at this point in the history
This is an updated version of encode#1053 applied to current master.
  • Loading branch information
craigds committed Jun 14, 2018
1 parent 28fcdc2 commit 57db4e8
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 7 deletions.
20 changes: 13 additions & 7 deletions rest_framework/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,15 +303,21 @@ def _get_declared_fields(cls, bases, attrs):
# If this class is subclassing another Serializer, add that Serializer's
# fields. Note that we loop over the bases in *reverse*. This is necessary
# in order to maintain the correct order of fields.
all_fields = OrderedDict()
for base in reversed(bases):
if hasattr(base, '_declared_fields'):
fields = [
(field_name, obj) for field_name, obj
in base._declared_fields.items()
if field_name not in attrs
] + fields

return OrderedDict(fields)
for name, field in base._declared_fields.items():
if name in attrs:
continue
if name in all_fields:
# Throw away old ordering, then replace with new one
all_fields.pop(name)
all_fields[name] = field

# if there are fields in both base_fields and fields, OrderedDict
# uses the *last* one defined. So fields needs to go last.
all_fields.update(OrderedDict(fields))
return all_fields

def __new__(cls, name, bases, attrs):
attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
Expand Down
82 changes: 82 additions & 0 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,3 +598,85 @@ class Grandchild(Child):
assert len(Parent().get_fields()) == 2
assert len(Child().get_fields()) == 2
assert len(Grandchild().get_fields()) == 2


class TestSerializerSupportsOverriddenFields:
def setup(self):
class Base1(serializers.Serializer):
a_field = serializers.CharField()
self.Base1 = Base1

class Base2(serializers.Serializer):
a_field = serializers.IntegerField()
self.Base2 = Base2

def test_base_fields_unchanged(self):
"""
Overriding a field in a subclassed serializer shouldn't change the
field on the superclass
"""
class OverriddenFields(self.Base1):
a_field = serializers.FloatField()

assert isinstance(
self.Base1._declared_fields['a_field'],
serializers.CharField,
)
s = self.Base1()
assert isinstance(s.fields['a_field'], serializers.CharField)

def test_overridden_fields_single_base(self):
"""
Subclassing a serializer and overriding a field should mean the field
on the subclass wins.
"""
class OverriddenFieldsWithSingleBase(self.Base1):
a_field = serializers.FloatField()

assert isinstance(
OverriddenFieldsWithSingleBase._declared_fields['a_field'],
serializers.FloatField,
)
s = OverriddenFieldsWithSingleBase()
assert isinstance(s.fields['a_field'], serializers.FloatField)

def test_overridden_fields_multiple_bases(self):
"""
For serializers with multiple bases, the field on the first base wins
(as per normal python method resolution order)
"""
class OverriddenFieldsMultipleBases1(self.Base1, self.Base2):
# first base takes precedence; a_field should be a CharField.
pass

assert isinstance(
OverriddenFieldsMultipleBases1._declared_fields['a_field'],
serializers.CharField,
)
s = OverriddenFieldsMultipleBases1()
assert isinstance(s.fields['a_field'], serializers.CharField)

class OverriddenFieldsMultipleBases2(self.Base2, self.Base1):
# first base takes precedence; a_field should be a IntegerField.
pass

assert isinstance(
OverriddenFieldsMultipleBases2._declared_fields['a_field'],
serializers.IntegerField,
)
s = OverriddenFieldsMultipleBases2()
assert isinstance(s.fields['a_field'], serializers.IntegerField)

def test_overridden_fields_multiple_bases_overridden(self):
"""
For serializers with multiple bases, locally defined fields still win.
"""
class OverriddenFieldsMultipleBasesOverridden(self.Base1, self.Base2):
a_field = serializers.FloatField()

assert isinstance(
OverriddenFieldsMultipleBasesOverridden._declared_fields['a_field'],
serializers.FloatField,
)
s = OverriddenFieldsMultipleBasesOverridden()
assert isinstance(s.fields['a_field'], serializers.FloatField)

0 comments on commit 57db4e8

Please sign in to comment.