diff --git a/README.md b/README.md index f7dd9e3..16b1f0a 100644 --- a/README.md +++ b/README.md @@ -183,6 +183,33 @@ class Person(BaseModel): ### Custom Serializer +In example below, `Person` will use `MyCustomSerializer` as its drf serializer. +`Employee` will have its own serializer generated by `drf_pydantic` because it +does not have a user-defined `drf_serializer` attribute (it's never inherited). +`Company` will have its own serializer generated by `drf_pydantic` and it will use +`Person`'s manually-defined serializer for its `ceo` field. + ```python -# TODO +from drf_pydantic import BaseModel +from rest_framework.serializers import Serializer + + +class MyCustomSerializer(Serializer): + name = CharField(allow_null=False, required=True) + age = IntegerField(allow_null=False, required=True) + + +class Person(BaseModel): + name: str + age: float + + drf_serializer = MyCustomSerializer + + +class Employee(Person): + salary: float + + +class Company(BaseModel): + ceo: Person ``` diff --git a/src/drf_pydantic/base_model.py b/src/drf_pydantic/base_model.py index 2190d74..c3669d5 100644 --- a/src/drf_pydantic/base_model.py +++ b/src/drf_pydantic/base_model.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from typing import Any, ClassVar, Optional import pydantic @@ -34,11 +34,19 @@ def __new__( _create_model_module, **kwargs, ) - setattr(cls, "drf_serializer", create_serializer_from_model(cls)) + # Create serializer only if it's not already set by the user + # Serializer should never be inherited from the parent classes + if not hasattr(cls, "drf_serializer") or getattr(cls, "drf_serializer") in ( + getattr(base, "drf_serializer", None) for base in cls.__mro__[1:] + ): + setattr( + cls, + "drf_serializer", + create_serializer_from_model(cls), + ) return cls class BaseModel(pydantic.BaseModel, metaclass=ModelMetaclass): - if TYPE_CHECKING: - # Populated by the metaclass, defined here to help IDEs only - drf_serializer: ClassVar[type[serializers.Serializer]] + # Populated by the metaclass or manually set by the user + drf_serializer: ClassVar[type[serializers.Serializer]] diff --git a/tests/test_models.py b/tests/test_models.py index 0064b95..b311f23 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -166,3 +166,87 @@ class Building(BaseModel): assert len(apartment.fields) == 2 assert isinstance(apartment.fields["floor"], serializers.IntegerField) assert isinstance(apartment.fields["owner"], serializers.CharField) + + +def test_manual_serializer(): + class MyCustomSerializer(serializers.Serializer): + gender = serializers.ChoiceField(choices=["male", "female"]) + title = serializers.CharField() + peers = serializers.ListField(child=serializers.IntegerField()) + + class Person(BaseModel): + name: str + age: int + + drf_serializer = MyCustomSerializer + + serializer = Person.drf_serializer() + assert serializer.__class__.__name__ == "MyCustomSerializer" + assert len(serializer.fields) == 3 + assert isinstance(serializer.fields["gender"], serializers.ChoiceField) + assert isinstance(serializer.fields["title"], serializers.CharField) + assert isinstance(serializer.fields["peers"], serializers.ListField) + + +def test_manual_serializer_inheritance(): + """Ensure that manual serializer is not inherited from the parent class.""" + + class MyCustomSerializer(serializers.Serializer): + gender = serializers.ChoiceField(choices=["male", "female"]) + title = serializers.CharField() + peers = serializers.ListField(child=serializers.IntegerField()) + + class Person(BaseModel): + name: str + age: int + + drf_serializer = MyCustomSerializer + + class Employee(Person): + salary: float + office: str + + person_serializer = Person.drf_serializer() + assert person_serializer.__class__.__name__ == "MyCustomSerializer" + assert len(person_serializer.fields) == 3 + assert isinstance(person_serializer.fields["gender"], serializers.ChoiceField) + assert isinstance(person_serializer.fields["title"], serializers.CharField) + assert isinstance(person_serializer.fields["peers"], serializers.ListField) + + employee_serializer = Employee.drf_serializer() + assert employee_serializer.__class__.__name__ == "EmployeeSerializer" + assert len(employee_serializer.fields) == 4 + assert isinstance(employee_serializer.fields["name"], serializers.CharField) + assert isinstance(employee_serializer.fields["age"], serializers.IntegerField) + assert isinstance(employee_serializer.fields["salary"], serializers.FloatField) + assert isinstance(employee_serializer.fields["office"], serializers.CharField) + + +def test_nested_manual_serializer(): + class MyCustomSerializer(serializers.Serializer): + gender = serializers.ChoiceField(choices=["male", "female"]) + title = serializers.CharField() + peers = serializers.ListField(child=serializers.IntegerField()) + + class Job(BaseModel): + title: str + salary: float + + drf_serializer = MyCustomSerializer + + class Person(BaseModel): + name: str + job: Job + + serializer = Person.drf_serializer() + assert serializer.__class__.__name__ == "PersonSerializer" + assert len(serializer.fields) == 2 + assert isinstance(serializer.fields["name"], serializers.CharField) + assert isinstance(serializer.fields["job"], serializers.Serializer) + + job_serializer = serializer.fields["job"] + assert job_serializer.__class__.__name__ == "MyCustomSerializer" + assert len(job_serializer.fields) == 3 + assert isinstance(job_serializer.fields["gender"], serializers.ChoiceField) + assert isinstance(job_serializer.fields["title"], serializers.CharField) + assert isinstance(job_serializer.fields["peers"], serializers.ListField)