Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

parametrize component registry identity #1288 #1290

Merged
merged 1 commit into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions drf_spectacular/contrib/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def get_name(self, auto_schema, direction):
# of the entry model, we simply use the class name as string for object. This hack may
# create false positive warnings, so turn it off. However, this may suppress correct
# warnings involving the entry class.
# TODO suppression may be migrated to new ComponentIdentity system
set_override(self.target, 'suppress_collision_warning', True)
return self.target.__name__

Expand Down
7 changes: 6 additions & 1 deletion drf_spectacular/contrib/rest_framework_dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any

from drf_spectacular.drainage import get_override, has_override
from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.plumbing import get_doc
from drf_spectacular.plumbing import ComponentIdentity, get_doc
from drf_spectacular.utils import Direction


Expand All @@ -18,6 +20,9 @@ def get_name(self):
return get_override(self.target.dataclass_definition.dataclass_type, 'component_name')
return self.target.dataclass_definition.dataclass_type.__name__

def get_identity(self, auto_schema, direction: Direction) -> Any:
return ComponentIdentity(self.target.dataclass_definition.dataclass_type)

def strip_library_doc(self, schema):
"""Strip the DataclassSerializer library documentation from the schema."""
from rest_framework_dataclasses.serializers import DataclassSerializer
Expand Down
5 changes: 3 additions & 2 deletions drf_spectacular/contrib/rest_polymorphic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from drf_spectacular.drainage import warn
from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.plumbing import (
ResolvedComponent, build_basic_type, build_object_type, is_patched_serializer,
ComponentIdentity, ResolvedComponent, build_basic_type, build_object_type,
is_patched_serializer,
)
from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes
Expand All @@ -25,7 +26,7 @@ def map_serializer(self, auto_schema, direction):
component = ResolvedComponent(
name=auto_schema._get_serializer_name(sub_serializer, direction),
type=ResolvedComponent.SCHEMA,
object='virtual'
object=ComponentIdentity('virtual')
)
typed_component = self.build_typed_component(
auto_schema=auto_schema,
Expand Down
4 changes: 4 additions & 0 deletions drf_spectacular/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def get_name(self, auto_schema: 'AutoSchema', direction: Direction) -> Optional[
""" return str for overriding default name extraction """
return None

def get_identity(self, auto_schema: 'AutoSchema', direction: Direction) -> Any:
""" return anything to compare instances of target. Target will be used by default. """
return None

def map_serializer(self, auto_schema: 'AutoSchema', direction: Direction) -> _SchemaType:
""" override for customized serializer mapping """
return auto_schema._map_serializer(self.target_class, direction, bypass_extensions=True)
Expand Down
19 changes: 15 additions & 4 deletions drf_spectacular/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,12 +1478,13 @@ def _get_response_for_code(self, serializer, status_code, media_types=None, dire
and is_serializer(serializer)
and (not is_list_serializer(serializer) or is_serializer(serializer.child))
):
paginated_name = self.get_paginated_name(self._get_serializer_name(serializer, "response"))
component = ResolvedComponent(
name=paginated_name,
name=self.get_paginated_name(self._get_serializer_name(serializer, 'response')),
type=ResolvedComponent.SCHEMA,
schema=paginator.get_paginated_response_schema(schema),
object=serializer.child if is_list_serializer(serializer) else serializer,
object=self.get_serializer_identity(
serializer.child if is_list_serializer(serializer) else serializer, 'response'
)
)
self.registry.register_on_missing(component)
schema = component.ref
Expand Down Expand Up @@ -1556,7 +1557,17 @@ def _get_response_headers_for_code(self, status_code, direction='response') -> _

return result

def get_serializer_identity(self, serializer, direction: Direction) -> Any:
serializer_extension = OpenApiSerializerExtension.get_match(serializer)
if serializer_extension:
identity = serializer_extension.get_identity(self, direction)
if identity is not None:
return identity

return serializer

def get_serializer_name(self, serializer: serializers.Serializer, direction: Direction) -> str:
""" override this for custom behaviour """
return serializer.__class__.__name__

def _get_serializer_name(self, serializer, direction, bypass_extensions=False) -> str:
Expand Down Expand Up @@ -1612,7 +1623,7 @@ def resolve_serializer(
component = ResolvedComponent(
name=self._get_serializer_name(serializer, direction, bypass_extensions),
type=ResolvedComponent.SCHEMA,
object=serializer,
object=self.get_serializer_identity(serializer, direction),
)
if component in self.registry:
return self.registry[component] # return component with schema
Expand Down
31 changes: 25 additions & 6 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,17 @@ def ref(self) -> _SchemaType:
return {'$ref': f'#/components/{self.type}/{self.name}'}


class ComponentIdentity:
""" A container class to make object/component comparison explicit """
def __init__(self, obj):
self.obj = obj

def __eq__(self, other):
if isinstance(other, ComponentIdentity):
return self.obj == other.obj
return self.obj == other


class ComponentRegistry:
def __init__(self) -> None:
self._components: Dict[Tuple[str, str], ResolvedComponent] = {}
Expand All @@ -746,17 +757,25 @@ def __contains__(self, component):

query_obj = component.object
registry_obj = self._components[component.key].object
query_class = query_obj if inspect.isclass(query_obj) else query_obj.__class__
registry_class = query_obj if inspect.isclass(registry_obj) else registry_obj.__class__

if isinstance(query_obj, ComponentIdentity) or inspect.isclass(query_obj):
query_id = query_obj
else:
query_id = query_obj.__class__

if isinstance(registry_obj, ComponentIdentity) or inspect.isclass(registry_obj):
registry_id = registry_obj
else:
registry_id = registry_obj.__class__

suppress_collision_warning = (
get_override(registry_class, 'suppress_collision_warning', False)
or get_override(query_class, 'suppress_collision_warning', False)
get_override(registry_id, 'suppress_collision_warning', False)
or get_override(query_id, 'suppress_collision_warning', False)
)
if query_class != registry_class and not suppress_collision_warning:
if query_id != registry_id and not suppress_collision_warning:
warn(
f'Encountered 2 components with identical names "{component.name}" and '
f'different classes {query_class} and {registry_class}. This will very '
f'different identities {query_id} and {registry_id}. This will very '
f'likely result in an incorrect schema. Try renaming one.'
)
return True
Expand Down
46 changes: 46 additions & 0 deletions tests/contrib/test_rest_framework_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,49 @@ def custom_name_via_serializer_decoration(request):
generate_schema(None, patterns=urlpatterns),
'tests/contrib/test_rest_framework_dataclasses.yml'
)


@pytest.mark.contrib('rest_framework_dataclasses')
@pytest.mark.skipif(sys.version_info < (3, 7), reason='dataclass required by package')
def test_rest_framework_dataclasses_class_reuse(no_warnings):
from dataclasses import dataclass

from rest_framework_dataclasses.serializers import DataclassSerializer

@dataclass
class Person:
name: str
age: int

@dataclass
class Party:
person: Person
num_persons: int

class PartySerializer(DataclassSerializer[Party]):
class Meta:
dataclass = Party

class PersonSerializer(DataclassSerializer[Person]):
class Meta:
dataclass = Person

@extend_schema(responses=PartySerializer)
@api_view()
def party(request):
pass # pragma: no cover

@extend_schema(responses=PersonSerializer)
@api_view()
def person(request):
pass # pragma: no cover

urlpatterns = [
path('party', person),
path('person', party),
]

schema = generate_schema(None, patterns=urlpatterns)
# just existence is enough to check since its about no_warnings
assert 'Person' in schema['components']['schemas']
assert 'Party' in schema['components']['schemas']
2 changes: 1 addition & 1 deletion tests/test_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class X2Viewset(mixins.ListModelMixin, viewsets.GenericViewSet):
generate_schema(None, patterns=router.urls)

stderr = capsys.readouterr().err
assert 'Encountered 2 components with identical names "X" and different classes' in stderr
assert 'Encountered 2 components with identical names "X" and different identities' in stderr


def test_owned_serializer_naming_override_with_ref_name_collision(warnings):
Expand Down
Loading