From 57db4e8616684ad65aa12c105c2a540b2ccf1cd0 Mon Sep 17 00:00:00 2001 From: Craig de Stigter Date: Fri, 4 Dec 2015 14:11:55 +1300 Subject: [PATCH] Add support for overriding fields in serializer subclasses. This is an updated version of tomchristie/django-rest-framework#1053 applied to current master. --- rest_framework/serializers.py | 20 ++++++--- tests/test_serializer.py | 82 +++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 7 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index e2ea0d7440..677611acd4 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -303,15 +303,21 @@ def _get_declared_fields(cls, bases, attrs): # If this class is subclassing another Serializer, add that Serializer's # fields. Note that we loop over the bases in *reverse*. This is necessary # in order to maintain the correct order of fields. + all_fields = OrderedDict() for base in reversed(bases): if hasattr(base, '_declared_fields'): - fields = [ - (field_name, obj) for field_name, obj - in base._declared_fields.items() - if field_name not in attrs - ] + fields - - return OrderedDict(fields) + for name, field in base._declared_fields.items(): + if name in attrs: + continue + if name in all_fields: + # Throw away old ordering, then replace with new one + all_fields.pop(name) + all_fields[name] = field + + # if there are fields in both base_fields and fields, OrderedDict + # uses the *last* one defined. So fields needs to go last. + all_fields.update(OrderedDict(fields)) + return all_fields def __new__(cls, name, bases, attrs): attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs) diff --git a/tests/test_serializer.py b/tests/test_serializer.py index da1394515c..9b9076aeab 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -598,3 +598,85 @@ class Grandchild(Child): assert len(Parent().get_fields()) == 2 assert len(Child().get_fields()) == 2 assert len(Grandchild().get_fields()) == 2 + + +class TestSerializerSupportsOverriddenFields: + def setup(self): + class Base1(serializers.Serializer): + a_field = serializers.CharField() + self.Base1 = Base1 + + class Base2(serializers.Serializer): + a_field = serializers.IntegerField() + self.Base2 = Base2 + + def test_base_fields_unchanged(self): + """ + Overriding a field in a subclassed serializer shouldn't change the + field on the superclass + """ + class OverriddenFields(self.Base1): + a_field = serializers.FloatField() + + assert isinstance( + self.Base1._declared_fields['a_field'], + serializers.CharField, + ) + s = self.Base1() + assert isinstance(s.fields['a_field'], serializers.CharField) + + def test_overridden_fields_single_base(self): + """ + Subclassing a serializer and overriding a field should mean the field + on the subclass wins. + """ + class OverriddenFieldsWithSingleBase(self.Base1): + a_field = serializers.FloatField() + + assert isinstance( + OverriddenFieldsWithSingleBase._declared_fields['a_field'], + serializers.FloatField, + ) + s = OverriddenFieldsWithSingleBase() + assert isinstance(s.fields['a_field'], serializers.FloatField) + + def test_overridden_fields_multiple_bases(self): + """ + For serializers with multiple bases, the field on the first base wins + (as per normal python method resolution order) + """ + class OverriddenFieldsMultipleBases1(self.Base1, self.Base2): + # first base takes precedence; a_field should be a CharField. + pass + + assert isinstance( + OverriddenFieldsMultipleBases1._declared_fields['a_field'], + serializers.CharField, + ) + s = OverriddenFieldsMultipleBases1() + assert isinstance(s.fields['a_field'], serializers.CharField) + + class OverriddenFieldsMultipleBases2(self.Base2, self.Base1): + # first base takes precedence; a_field should be a IntegerField. + pass + + assert isinstance( + OverriddenFieldsMultipleBases2._declared_fields['a_field'], + serializers.IntegerField, + ) + s = OverriddenFieldsMultipleBases2() + assert isinstance(s.fields['a_field'], serializers.IntegerField) + + def test_overridden_fields_multiple_bases_overridden(self): + """ + For serializers with multiple bases, locally defined fields still win. + """ + class OverriddenFieldsMultipleBasesOverridden(self.Base1, self.Base2): + a_field = serializers.FloatField() + + assert isinstance( + OverriddenFieldsMultipleBasesOverridden._declared_fields['a_field'], + serializers.FloatField, + ) + s = OverriddenFieldsMultipleBasesOverridden() + assert isinstance(s.fields['a_field'], serializers.FloatField)