diff --git a/modelkit/core/model.py b/modelkit/core/model.py index d72aafbd..94150af5 100644 --- a/modelkit/core/model.py +++ b/modelkit/core/model.py @@ -28,7 +28,7 @@ from rich.markup import escape from rich.tree import Tree from structlog import get_logger -from typing_extensions import Protocol +from typing_extensions import Annotated, Protocol from modelkit.core import errors from modelkit.core.settings import LibrarySettings @@ -229,7 +229,16 @@ def initialize_validation_models(self): ] if len(generic_aliases): _item_type, _return_type = generic_aliases[0].__args__ - if _item_type != ItemType: + if ( + _item_type != ItemType + and not ( + typing.get_origin(_item_type) is Annotated + and isinstance( + typing.get_args(_item_type)[1], pydantic.SkipValidation + ) + ) + and not self.service_settings.disable_validation + ): self._item_type = _item_type type_name = self.__class__.__name__ + "ItemTypeModel" self._item_model = pydantic.create_model( @@ -240,7 +249,16 @@ def initialize_validation_models(self): data=(self._item_type, ...), __base__=InternalDataModel, ) - if _return_type != ReturnType: + if ( + _return_type != ReturnType + and not ( + typing.get_origin(_item_type) is Annotated + and isinstance( + typing.get_args(_item_type)[1], pydantic.SkipValidation + ) + ) + and not self.service_settings.disable_validation + ): self._return_type = _return_type type_name = self.__class__.__name__ + "ReturnTypeModel" self._return_model = pydantic.create_model( diff --git a/modelkit/core/settings.py b/modelkit/core/settings.py index 2e8d53be..ae90bc97 100644 --- a/modelkit/core/settings.py +++ b/modelkit/core/settings.py @@ -119,6 +119,12 @@ class LibrarySettings(ModelkitSettings): False, validation_alias=pydantic.AliasChoices("lazy_loading", "MODELKIT_LAZY_LOADING"), ) + disable_validation: bool = pydantic.Field( + False, + validation_alias=pydantic.AliasChoices( + "disable_validation", "MODELKIT_DISABLE_VALIDATION" + ), + ) override_assets_dir: Optional[str] = pydantic.Field( None, validation_alias=pydantic.AliasChoices( diff --git a/tests/test_validate.py b/tests/test_validate.py index 55e6b19a..859fba90 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -188,3 +188,48 @@ class ListModel(pydantic.BaseModel): raise ModelkitDataValidationException( "test error", pydantic_exc=exc ) from exc + + +def test_validate_skipped_annotated(): + class Example(pydantic.BaseModel): + data: str + + class ValidatedModel(Model[Example, Example]): + def _predict(self, item): + return item + + class NotValidatedModel( + Model[pydantic.SkipValidation[Example], pydantic.SkipValidation[Example]] + ): + def _predict(self, item): + return item + + validated_model = ValidatedModel(service_settings=LibrarySettings()) + not_validated_model = NotValidatedModel(service_settings=LibrarySettings()) + + example = Example(data="test") + + assert validated_model(example) == example + assert validated_model({"data": "test"}) == example + with pytest.raises(ItemValidationException): + validated_model({"data": 1}) + + assert not_validated_model(example) == example + assert not_validated_model({"data": "test"}) == {"data": "test"} + assert not_validated_model({"data": 1}) == {"data": 1} + + +def test_validate_skipped_library_settings(): + class Example(pydantic.BaseModel): + data: str + + class MyModel(Model[Example, Example]): + def _predict(self, item): + return item + + model = MyModel(service_settings=LibrarySettings(disable_validation=True)) + example = Example(data="test") + + assert model(example) == example + assert model({"data": "test"}) == {"data": "test"} + assert model({"data": 1}) == {"data": 1}