From 6066e829af8a9ad06ba7a446b72e7abd58c878e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20=C3=98stvik?= Date: Sun, 1 Dec 2024 15:58:45 +0100 Subject: [PATCH] correctly set required and nullable json schema values --- pydantic_partial/_compat.py | 10 +++++++--- pydantic_partial/partial.py | 2 +- tests/test_partial_without_mixin.py | 13 ++++++++++++- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/pydantic_partial/_compat.py b/pydantic_partial/_compat.py index 4ca0517..a768ac6 100644 --- a/pydantic_partial/_compat.py +++ b/pydantic_partial/_compat.py @@ -14,7 +14,7 @@ if PYDANTIC_V1: # pragma: no cover from pydantic.fields import ModelField # type: ignore - NULLABLE_KWARGS = {"nullable": True} + NULLABLE_KWARGS = {"nullable": True, "required": False} class PydanticCompat: # type: ignore model_class: type[pydantic.BaseModel] @@ -46,7 +46,7 @@ def copy_model_field_info(self, model_field: ModelField, **kwargs: Any) -> Field return copy_field_info(model_field.field_info, **kwargs) elif PYDANTIC_V2: # pragma: no cover - NULLABLE_KWARGS = {"json_schema_extra": {"nullable": True}} + NULLABLE_KWARGS = {"json_schema_extra": {"nullable": True, "required": False}} class PydanticCompat: # type: ignore model_class: type[pydantic.BaseModel] @@ -65,7 +65,11 @@ def get_model_field_info_annotation(self, field_info: FieldInfo) -> Optional[typ return field_info.annotation def is_model_field_info_required(self, field_info: FieldInfo) -> bool: - return field_info.is_required() # type: ignore + json_required = ( + field_info.json_schema_extra is not None + and field_info.json_schema_extra.get("required", False) + ) + return field_info.is_required() or json_required # type: ignore def copy_model_field_info(self, field_info: FieldInfo, **kwargs: Any) -> FieldInfo: return copy_field_info(field_info, **kwargs) diff --git a/pydantic_partial/partial.py b/pydantic_partial/partial.py index 7e6868c..e003ee0 100644 --- a/pydantic_partial/partial.py +++ b/pydantic_partial/partial.py @@ -110,7 +110,7 @@ def _partial_annotation_arg(field_name_: str, field_annotation: type) -> type: field_info, default=None, # Set default to None default_factory=None, # Remove default_factory if set - **NULLABLE_KWARGS, # For API usage: set field as nullable + **NULLABLE_KWARGS, # For API usage: set field as nullable and not required ), ) elif recursive or sub_fields_requested: diff --git a/tests/test_partial_without_mixin.py b/tests/test_partial_without_mixin.py index 783b84a..52876dc 100644 --- a/tests/test_partial_without_mixin.py +++ b/tests/test_partial_without_mixin.py @@ -12,7 +12,11 @@ def _field_is_required(model: Union[type[pydantic.BaseModel], pydantic.BaseModel elif PYDANTIC_V2: def _field_is_required(model: Union[type[pydantic.BaseModel], pydantic.BaseModel], field_name: str) -> bool: """Check if a field is required on a pydantic V2 model.""" - return model.model_fields[field_name].is_required() + json_required = ( + model.model_fields[field_name].json_schema_extra is not None + and model.model_fields[field_name].json_schema_extra.get("required", False) + ) + return model.model_fields[field_name].is_required() or json_required else: raise DeprecationWarning("Pydantic has to be in version 1 or 2.") @@ -21,6 +25,7 @@ class Something(pydantic.BaseModel): name: str age: int already_optional: None = None + already_required: int = pydantic.Field(default=1, json_schema_extra={"required": True}) class SomethingWithMixin(PartialModelMixin, pydantic.BaseModel): @@ -61,3 +66,9 @@ def test_partial_model_will_be_the_same_on_mixin(): SomethingWithMixinPartial2 = SomethingWithMixin.model_as_partial() assert SomethingWithMixinPartial1 is SomethingWithMixinPartial2 + +def test_partial_model_will_override_json_required(): + SomethingPartial = create_partial_model(Something) + assert _field_is_required(SomethingPartial, "already_required") is False + SomethingPartial.model_json_schema()["properties"]["already_required"]["nullable"] is True + SomethingPartial.model_json_schema()["properties"]["already_required"]["required"] is False