From ef6dea70a801dd2da8629c3f5d06dbd5a64480d0 Mon Sep 17 00:00:00 2001 From: Craig de Stigter Date: Fri, 4 Dec 2015 14:11:55 +1300 Subject: [PATCH] Add support for overriding fields in serializer subclasses. This is an updated version of tomchristie/django-rest-framework#1053 applied to current master. --- rest_framework/serializers.py | 14 ++++-- tests/test_serializer.py | 82 +++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 3 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 4d1ed63aef..ae9f902ea1 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -281,11 +281,19 @@ def _get_declared_fields(cls, bases, attrs): # If this class is subclassing another Serializer, add that Serializer's # fields. Note that we loop over the bases in *reverse*. This is necessary # in order to maintain the correct order of fields. + all_fields = OrderedDict() for base in reversed(bases): if hasattr(base, '_declared_fields'): - fields = list(base._declared_fields.items()) + fields - - return OrderedDict(fields) + for name, field in base._declared_fields.items(): + if name in all_fields: + # Throw away old ordering, then replace with new one + all_fields.pop(name) + all_fields[name] = field + + # if there are fields in both base_fields and fields, OrderedDict + # uses the *last* one defined. So fields needs to go last. + all_fields.update(OrderedDict(fields)) + return all_fields def __new__(cls, name, bases, attrs): attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs) diff --git a/tests/test_serializer.py b/tests/test_serializer.py index bd9ef95002..9efb4cccef 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -351,3 +351,85 @@ def test_validation_success(self): assert serializer.is_valid() assert serializer.validated_data == {'name': '2'} assert serializer.errors == {} + + +class TestSerializerSupportsOverriddenFields: + def setup(self): + class Base1(serializers.Serializer): + a_field = serializers.CharField() + self.Base1 = Base1 + + class Base2(serializers.Serializer): + a_field = serializers.IntegerField() + self.Base2 = Base2 + + def test_base_fields_unchanged(self): + """ + Overriding a field in a subclassed serializer shouldn't change the + field on the superclass + """ + class OverriddenFields(self.Base1): + a_field = serializers.FloatField() + + assert isinstance( + self.Base1._declared_fields['a_field'], + serializers.CharField, + ) + s = self.Base1() + assert isinstance(s.fields['a_field'], serializers.CharField) + + def test_overridden_fields_single_base(self): + """ + Subclassing a serializer and overriding a field should mean the field + on the subclass wins. + """ + class OverriddenFieldsWithSingleBase(self.Base1): + a_field = serializers.FloatField() + + assert isinstance( + OverriddenFieldsWithSingleBase._declared_fields['a_field'], + serializers.FloatField, + ) + s = OverriddenFieldsWithSingleBase() + assert isinstance(s.fields['a_field'], serializers.FloatField) + + def test_overridden_fields_multiple_bases(self): + """ + For serializers with multiple bases, the field on the first base wins + (as per normal python method resolution order) + """ + class OverriddenFieldsMultipleBases1(self.Base1, self.Base2): + # first base takes precedence; a_field should be a CharField. + pass + + assert isinstance( + OverriddenFieldsMultipleBases1._declared_fields['a_field'], + serializers.CharField, + ) + s = OverriddenFieldsMultipleBases1() + assert isinstance(s.fields['a_field'], serializers.CharField) + + class OverriddenFieldsMultipleBases2(self.Base2, self.Base1): + # first base takes precedence; a_field should be a IntegerField. + pass + + assert isinstance( + OverriddenFieldsMultipleBases2._declared_fields['a_field'], + serializers.IntegerField, + ) + s = OverriddenFieldsMultipleBases2() + assert isinstance(s.fields['a_field'], serializers.IntegerField) + + def test_overridden_fields_multiple_bases_overridden(self): + """ + For serializers with multiple bases, locally defined fields still win. + """ + class OverriddenFieldsMultipleBasesOverridden(self.Base1, self.Base2): + a_field = serializers.FloatField() + + assert isinstance( + OverriddenFieldsMultipleBasesOverridden._declared_fields['a_field'], + serializers.FloatField, + ) + s = OverriddenFieldsMultipleBasesOverridden() + assert isinstance(s.fields['a_field'], serializers.FloatField)