Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Enum collision with same choices & varying labels #790 #1104 #1113

Merged
merged 1 commit into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions drf_spectacular/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ def create_enum_component(name, schema):
generator.registry.register_on_missing(component)
return component

def extract_hash(schema):
    """
    return the hash that identifies the choice set of an enum schema.

    the injected 'x-spec-enum-id' wins when present; otherwise the hash is
    derived from the raw 'enum' list of the schema.
    """
    if 'x-spec-enum-id' not in schema:
        # enum was not generated by us, so there is no injected id. hash the plain
        # values instead, pairing each value with itself to mimic the (name, value)
        # tuple layout. blank/null entries are left out of the hash and get
        # reconstructed in the last step.
        meaningful_values = [value for value in schema['enum'] if value not in ('', None)]
        return list_hash([(value, value) for value in meaningful_values])

    # the injected enum hash is generated from (name, value) tuples, which prevents
    # collisions on choice sets that only differ in labels and not in values.
    return schema['x-spec-enum-id']

schemas = result.get('components', {}).get('schemas', {})

overrides = load_enum_name_overrides()
Expand All @@ -58,8 +68,8 @@ def create_enum_component(name, schema):
prop_schema = prop_schema.get('items', {})
if 'enum' not in prop_schema:
continue
# remove blank/null entry for hashing. will be reconstructed in the last step
prop_enum_cleaned_hash = list_hash([i for i in prop_schema['enum'] if i not in ['', None]])

prop_enum_cleaned_hash = extract_hash(prop_schema)
prop_hash_mapping[prop_name].add(prop_enum_cleaned_hash)
hash_name_mapping[prop_enum_cleaned_hash].add((component_name, prop_name))

Expand Down Expand Up @@ -110,14 +120,14 @@ def create_enum_component(name, schema):

prop_enum_original_list = prop_schema['enum']
prop_schema['enum'] = [i for i in prop_schema['enum'] if i not in ['', None]]
prop_hash = list_hash(prop_schema['enum'])
prop_hash = extract_hash(prop_schema)
# when choice sets are reused under multiple names, the generated name cannot be
# resolved from the hash alone. fall back to prop_name and hash for resolution.
enum_name = enum_name_mapping.get(prop_hash) or enum_name_mapping[prop_hash, prop_name]

# split property into remaining property and enum component parts
enum_schema = {k: v for k, v in prop_schema.items() if k in ['type', 'enum']}
prop_schema = {k: v for k, v in prop_schema.items() if k not in ['type', 'enum']}
prop_schema = {k: v for k, v in prop_schema.items() if k not in ['type', 'enum', 'x-spec-enum-id']}

# separate actual description from name-value tuples
if spectacular_settings.ENUM_GENERATE_CHOICE_DESCRIPTION:
Expand Down Expand Up @@ -148,6 +158,31 @@ def create_enum_component(name, schema):

# sort again with additional components
result['components'] = generator.registry.build(spectacular_settings.APPEND_COMPONENTS)

# remove remaining ids that were not part of this hook (operation parameters mainly)
postprocess_schema_enum_id_removal(result, generator)

return result


def postprocess_schema_enum_id_removal(result, generator, **kwargs):
    """
    Iterative modifying approach to scanning the whole schema and removing the
    temporary helper ids that allowed us to distinguish similar enums.

    The scan uses an explicit work stack instead of recursion, so arbitrarily deep
    schemas cannot exhaust the interpreter's recursion limit.

    :param result: schema structure (nested dicts/lists/tuples); modified in place.
    :param generator: not used here (presumably kept to match the hook call signature).
    :param kwargs: not used here.
    :return: the very same ``result`` object, with every 'x-spec-enum-id' key removed.
    """
    pending = [result]
    # ids of containers already handled; guards against shared or self-referencing
    # sub-structures being scanned more than once.
    visited = set()

    while pending:
        node = pending.pop()
        if not isinstance(node, (dict, list, tuple)):
            continue  # scalars cannot carry the helper id
        if id(node) in visited:
            continue
        visited.add(id(node))

        if isinstance(node, dict):
            # the value stored under the helper key is dropped wholesale and
            # deliberately not scanned any further.
            node.pop('x-spec-enum-id', None)
            pending.extend(node.values())
        else:
            pending.extend(node)

    return result


Expand Down
15 changes: 11 additions & 4 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,8 @@ def build_choice_field(field):
if spectacular_settings.ENUM_GENERATE_CHOICE_DESCRIPTION:
schema['description'] = build_choice_description_list(field.choices.items())

schema['x-spec-enum-id'] = list_hash([(k, v) for k, v in field.choices.items() if k not in ('', None)])

return schema


Expand Down Expand Up @@ -499,10 +501,12 @@ def build_root_object(paths, components, version):
def safe_ref(schema):
    """
    ensure that $ref has its own context and does not remove potential sibling
    entries when $ref is substituted. also remove useless singular "allOf".

    the given schema is never modified. a new dict is returned when a $ref with
    siblings gets wrapped, the single inner entry is returned when a lone
    one-element "allOf" gets unwrapped, and the schema itself is returned otherwise.
    """
    if '$ref' in schema and len(schema) > 1:
        # build the wrapper from a filtered copy instead of popping '$ref' off
        # the caller's dict, so the input stays intact.
        siblings = {key: value for key, value in schema.items() if key != '$ref'}
        return {'allOf': [{'$ref': schema['$ref']}], **siblings}
    if 'allOf' in schema and len(schema) == 1 and len(schema['allOf']) == 1:
        return schema['allOf'][0]
    return schema


Expand Down Expand Up @@ -815,11 +819,12 @@ def load_enum_name_overrides():
if inspect.isclass(choices) and issubclass(choices, Choices):
choices = choices.choices
if inspect.isclass(choices) and issubclass(choices, Enum):
choices = [c.value for c in choices]
choices = [(c.value, c.name) for c in choices]
normalized_choices = []
for choice in choices:
# Allow None values in the simple values list case
if isinstance(choice, str) or choice is None:
# TODO warning
normalized_choices.append((choice, choice)) # simple choice list
elif isinstance(choice[1], (list, tuple)):
normalized_choices.extend(choice[1]) # categorized nested choices
Expand All @@ -828,7 +833,9 @@ def load_enum_name_overrides():

# Get all choice values that should be used in the hash. Blank and None values get excluded
# in the post-processing hook for enum overrides, so we do the same here to ensure the hashes match.
hashable_values = [value for value, _ in normalized_choices if value not in ['', None]]
hashable_values = [
(value, label) for value, label in normalized_choices if value not in ['', None]
]
overrides[list_hash(hashable_values)] = name

if len(spectacular_settings.ENUM_NAME_OVERRIDES) != len(overrides):
Expand All @@ -840,7 +847,7 @@ def load_enum_name_overrides():


def list_hash(lst):
    """
    return a short, deterministic hash for the given iterable of items.

    the items are materialized into a list, dumped to JSON (with sorted dict keys)
    and hashed with SHA-256. only the first 16 hex characters of the digest are
    returned.
    """
    serialized = json.dumps(list(lst), sort_keys=True, cls=JSONEncoder)
    return hashlib.sha256(serialized.encode()).hexdigest()[:16]


def anchor_pattern(pattern: str) -> str:
Expand Down
18 changes: 17 additions & 1 deletion tests/test_plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from drf_spectacular.plumbing import (
analyze_named_regex_pattern, build_basic_type, build_choice_field, detype_pattern,
follow_field_source, force_instance, get_list_serializer, is_field, is_serializer,
resolve_type_hint,
resolve_type_hint, safe_ref,
)
from drf_spectacular.validation import validate_schema
from tests import generate_schema
Expand Down Expand Up @@ -377,3 +377,19 @@ def test_choicefield_choices_enum():
))
assert schema['enum'] == ['bluepill', 'redpill', '', None]
assert 'type' not in schema


def test_safe_ref():
    """safe_ref wraps a $ref that has siblings and unwraps a lone single-entry allOf."""
    schema = build_basic_type(str)
    schema['$ref'] = '#/components/schemas/Foo'

    # $ref next to a sibling ('type') must be moved into its own allOf context
    schema = safe_ref(schema)
    assert schema == {
        'allOf': [{'$ref': '#/components/schemas/Foo'}],
        'type': 'string'
    }

    # once the sibling is gone, the singular allOf wrapper is useless and gets removed
    del schema['type']
    schema = safe_ref(schema)
    assert schema == {'$ref': '#/components/schemas/Foo'}

    # a plain $ref is already in its final form: applying safe_ref again is a no-op
    assert safe_ref(schema) == {'$ref': '#/components/schemas/Foo'}
74 changes: 60 additions & 14 deletions tests/test_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@

import pytest
from django import __version__ as DJANGO_VERSION
from django.utils.translation import gettext_lazy as _
from rest_framework import generics, mixins, serializers, viewsets
from rest_framework.decorators import action
from rest_framework.views import APIView

try:
from django.db.models.enums import TextChoices
from django.db.models.enums import IntegerChoices, TextChoices
except ImportError:
TextChoices = object # type: ignore # django < 3.0 handling
IntegerChoices = object # type: ignore # django < 3.0 handling

from drf_spectacular.plumbing import list_hash, load_enum_name_overrides
from drf_spectacular.utils import OpenApiParameter, extend_schema
Expand Down Expand Up @@ -244,35 +246,39 @@ def partial_update(self, request):


def test_enum_override_variations(no_warnings):
    """
    every supported way of declaring an enum name override must end up registered
    under the hash of its (value, label) pairs.
    """
    # (import target inside this test module, pairs expected to be hashed)
    enum_override_variations = [
        ('language_list', [('en', 'en')]),
        ('LanguageEnum', [('en', 'EN')]),
        ('LanguageStrEnum', [('en', 'EN')]),
    ]
    if DJANGO_VERSION > '3':
        enum_override_variations += [
            ('LanguageChoices', [('en', 'En')]),
            ('LanguageChoices.choices', [('en', 'En')])
        ]

    for variation, expected_hashed_keys in enum_override_variations:
        with mock.patch(
            'drf_spectacular.settings.spectacular_settings.ENUM_NAME_OVERRIDES',
            {'LanguageEnum': f'tests.test_postprocessing.{variation}'}
        ):
            # drop the cached overrides so the patched setting is actually evaluated
            load_enum_name_overrides.cache_clear()
            assert list_hash(expected_hashed_keys) in load_enum_name_overrides()


def test_enum_override_variations_with_blank_and_null(no_warnings):
enum_override_variations = [
'blank_null_language_list',
'BlankNullLanguageEnum',
('BlankNullLanguageStrEnum', ['en', 'None'])
('blank_null_language_list', [('en', 'en')]),
('BlankNullLanguageEnum', [('en', 'EN')]),
('BlankNullLanguageStrEnum', [('en', 'EN'), ('None', 'NULL')])
]
if DJANGO_VERSION > '3':
enum_override_variations += [
('BlankNullLanguageChoices', ['en', 'None']),
('BlankNullLanguageChoices.choices', ['en', 'None'])
('BlankNullLanguageChoices', [('en', 'En'), ('None', 'Null')]),
('BlankNullLanguageChoices.choices', [('en', 'En'), ('None', 'Null')])
]

for variation in enum_override_variations:
expected_hashed_keys = ['en']
if isinstance(variation, (list, tuple, )):
variation, expected_hashed_keys = variation
for variation, expected_hashed_keys in enum_override_variations:
with mock.patch(
'drf_spectacular.settings.spectacular_settings.ENUM_NAME_OVERRIDES',
{'LanguageEnum': f'tests.test_postprocessing.{variation}'}
Expand Down Expand Up @@ -340,3 +346,43 @@ def get(self, request):
uuid.UUID('93d7527f-de3c-4a76-9cc2-5578675630d4'),
uuid.UUID('47a4b873-409e-4e43-81d5-fafc3faeb849')
]


@pytest.mark.skipif(DJANGO_VERSION < '3', reason='Not available before Django 3.0')
def test_equal_choices_different_semantics(no_warnings):
    """choice sets sharing the values [0, 1] but not the labels yield separate enum components."""

    class Health(IntegerChoices):
        OK = 0
        FAIL = 1

    class Status(IntegerChoices):
        GREEN = 0
        RED = 1

    class Test(IntegerChoices):
        A = 0, _("test group A")
        B = 1, _("test group B")

    class XSerializer(serializers.Serializer):
        some_health = serializers.ChoiceField(choices=Health.choices)
        some_status = serializers.ChoiceField(choices=Status.choices)
        some_test = serializers.ChoiceField(choices=Test.choices)

    class XAPIView(APIView):
        @extend_schema(responses=XSerializer)
        def get(self, request):
            pass  # pragma: no cover

    # This should not generate a warning even though the enum list is identical
    # in all Enums. The Enums are also differentiated by their labels.
    schema = generate_schema('x', view=XAPIView)
    components = schema['components']['schemas']

    # enum component name -> description rendered from its own labels
    expected_descriptions = {
        'SomeHealthEnum': '* `0` - Ok\n* `1` - Fail',
        'SomeStatusEnum': '* `0` - Green\n* `1` - Red',
        'SomeTestEnum': '* `0` - test group A\n* `1` - test group B',
    }
    for enum_name, description in expected_descriptions.items():
        assert components[enum_name] == {
            'enum': [0, 1], 'type': 'integer', 'description': description
        }