diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index c1b36096..4fa4ce36 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -796,13 +796,18 @@ def load_enum_name_overrides(): choices = [c.value for c in choices] normalized_choices = [] for choice in choices: - if isinstance(choice, str): + # Allow None values in the simple values list case + if isinstance(choice, str) or choice is None: normalized_choices.append((choice, choice)) # simple choice list elif isinstance(choice[1], (list, tuple)): normalized_choices.extend(choice[1]) # categorized nested choices else: normalized_choices.append(choice) # normal 2-tuple form - overrides[list_hash(list(dict(normalized_choices).keys()))] = name + + # Get all of choice values that should be used in the hash, blank and None values get excluded + # in the post-processing hook for enum overrides, so we do the same here to ensure the hashes match + hashable_values = [value for value, _ in normalized_choices if value not in ['', None]] + overrides[list_hash(hashable_values)] = name if len(spectacular_settings.ENUM_NAME_OVERRIDES) != len(overrides): error( diff --git a/tests/test_postprocessing.py b/tests/test_postprocessing.py index a3c1990a..3ef0c3fc 100644 --- a/tests/test_postprocessing.py +++ b/tests/test_postprocessing.py @@ -24,6 +24,15 @@ ('cn', 'cn'), ) +blank_null_language_choices = ( + ('en', 'en'), + ('es', 'es'), + ('ru', 'ru'), + ('cn', 'cn'), + ('', 'not provided'), + (None, 'unknown'), +) + vote_choices = ( (1, 'Positive'), (0, 'Neutral'), @@ -45,6 +54,29 @@ class LanguageChoices(TextChoices): EN = 'en' +blank_null_language_list = ['en', '', None] + + +class BlankNullLanguageEnum(Enum): + EN = 'en' + BLANK = '' + NULL = None + + +class BlankNullLanguageStrEnum(str, Enum): + EN = 'en' + BLANK = '' + # These will still be included since the values get cast to strings so 'None' != None + NULL = None + + +class BlankNullLanguageChoices(TextChoices): + EN = 'en' + BLANK = '' + # These will still be included since the values get cast to strings so 'None' != None + NULL = None + + class ASerializer(serializers.Serializer): language = serializers.ChoiceField(choices=language_choices) vote = serializers.ChoiceField(choices=vote_choices) @@ -96,6 +128,40 @@ class XView(generics.RetrieveAPIView): assert len(schema['components']['schemas']) == 2 +@mock.patch('drf_spectacular.settings.spectacular_settings.ENUM_NAME_OVERRIDES', { + 'LanguageEnum': 'tests.test_postprocessing.blank_null_language_choices' +}) +def test_global_enum_naming_override_with_blank_and_none(no_warnings, clear_caches): + """Test that choices with blank values can still have their name overridden.""" + class XSerializer(serializers.Serializer): + foo = serializers.ChoiceField(choices=blank_null_language_choices) + bar = serializers.ChoiceField(choices=blank_null_language_choices) + + class XView(generics.RetrieveAPIView): + serializer_class = XSerializer + + schema = generate_schema('/x', view=XView) + foo_data = schema['components']['schemas']['X']['properties']['foo'] + bar_data = schema['components']['schemas']['X']['properties']['bar'] + + assert len(foo_data['oneOf']) == 3 + assert len(bar_data['oneOf']) == 3 + + foo_ref_values = [ref_object['$ref'] for ref_object in foo_data['oneOf']] + bar_ref_values = [ref_object['$ref'] for ref_object in bar_data['oneOf']] + + assert foo_ref_values == [ + '#/components/schemas/LanguageEnum', + '#/components/schemas/BlankEnum', + '#/components/schemas/NullEnum' + ] + assert bar_ref_values == [ + '#/components/schemas/LanguageEnum', + '#/components/schemas/BlankEnum', + '#/components/schemas/NullEnum' + ] + + def test_enum_name_reuse_warning(capsys): class XSerializer(serializers.Serializer): foo = serializers.ChoiceField(choices=language_choices) @@ -191,6 +257,31 @@ def test_enum_override_variations(no_warnings): assert list_hash(['en']) in load_enum_name_overrides() +def test_enum_override_variations_with_blank_and_null(no_warnings): + enum_override_variations = [ + 'blank_null_language_list', + 'BlankNullLanguageEnum', + ('BlankNullLanguageStrEnum', ['en', 'None']) + ] + if DJANGO_VERSION > '3': + enum_override_variations += [ + ('BlankNullLanguageChoices', ['en', 'None']), + ('BlankNullLanguageChoices.choices', ['en', 'None']) + ] + + for variation in enum_override_variations: + expected_hashed_keys = ['en'] + if isinstance(variation, (list, tuple, )): + variation, expected_hashed_keys = variation + with mock.patch( + 'drf_spectacular.settings.spectacular_settings.ENUM_NAME_OVERRIDES', + {'LanguageEnum': f'tests.test_postprocessing.{variation}'} + ): + load_enum_name_overrides.cache_clear() + # Should match after None and blank strings are removed + assert list_hash(expected_hashed_keys) in load_enum_name_overrides() + + @mock.patch('drf_spectacular.settings.spectacular_settings.ENUM_NAME_OVERRIDES', { 'LanguageEnum': 'tests.test_postprocessing.NOTEXISTING' })