From c99f1bc006ca9410f5c82e7fc45d33cb025436d9 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Thu, 16 Nov 2023 15:31:56 -0300 Subject: [PATCH] Fix optional field conversion to str --- pyproject.toml | 2 +- retrack/engine/request_manager.py | 23 ++++++++++++++++++++++- tests/test_engine/test_request_manager.py | 8 ++++++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ba6cf5d..1ad71ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "retrack" -version = "1.0.0" +version = "1.0.1" description = "A business rules engine" authors = ["Gabriel Guarisa ", "Nathalia Trotte "] license = "MIT" diff --git a/retrack/engine/request_manager.py b/retrack/engine/request_manager.py index 7facb7f..f3d2222 100644 --- a/retrack/engine/request_manager.py +++ b/retrack/engine/request_manager.py @@ -7,6 +7,20 @@ from retrack.nodes.base import BaseNode, NodeKind +class StrFieldValidator: + def __init__(self, default: typing.Optional[typing.Any] = None): + self.default = default + + def __call__(self, value: typing.Any) -> typing.Any: + if value in [None, "None", "", "null"]: + if self.default is None: + raise ValueError("value cannot be None") + else: + return self.default + + return str(value) if type(value) != str else value + + class RequestManager: def __init__(self, inputs: typing.List[BaseNode]): self._model = None @@ -71,11 +85,18 @@ def __create_model( fields = {} for input_field in self.inputs: fields[input_field.data.name] = ( - typing.Annotated[str, pydantic.BeforeValidator(str)], + typing.Annotated[ + str if input_field.data.default is None else typing.Optional[str], + pydantic.BeforeValidator( + StrFieldValidator(input_field.data.default) + ), + ], pydantic.Field( default=Ellipsis if input_field.data.default is None else input_field.data.default, + optional=input_field.data.default is not None, + validate_default=False, ), ) diff --git a/tests/test_engine/test_request_manager.py b/tests/test_engine/test_request_manager.py index 3d30ac3..71211e5 100644 --- a/tests/test_engine/test_request_manager.py +++ b/tests/test_engine/test_request_manager.py @@ -55,3 +55,11 @@ def test_validate_dict_with_model(valid_input_dict_before_validation): assert issubclass(rm.model, pydantic.BaseModel) assert rm.model.model_validate({"example": 1111}) == rm.model(example="1111") + + +def test_validate_dict_with_none_value(valid_input_dict_before_validation): + rm = RequestManager([Input(**valid_input_dict_before_validation)]) + + assert issubclass(rm.model, pydantic.BaseModel) + assert rm.model(example=None) == rm.model(example="Hello World") + assert rm.model() == rm.model(example="Hello World")