diff --git a/aws_lambda_powertools/utilities/idempotency/serialization/dataclass.py b/aws_lambda_powertools/utilities/idempotency/serialization/dataclass.py index be5d7007ef3..fc8b72252c0 100644 --- a/aws_lambda_powertools/utilities/idempotency/serialization/dataclass.py +++ b/aws_lambda_powertools/utilities/idempotency/serialization/dataclass.py @@ -11,6 +11,7 @@ BaseIdempotencyModelSerializer, BaseIdempotencySerializer, ) +from aws_lambda_powertools.utilities.idempotency.serialization.functions import get_actual_type DataClass = Any @@ -37,6 +38,9 @@ def from_dict(self, data: dict) -> DataClass: @classmethod def instantiate(cls, model_type: Any) -> BaseIdempotencySerializer: + + model_type = get_actual_type(model_type=model_type) + if model_type is None: raise IdempotencyNoSerializationModelError("No serialization model was supplied") diff --git a/aws_lambda_powertools/utilities/idempotency/serialization/functions.py b/aws_lambda_powertools/utilities/idempotency/serialization/functions.py new file mode 100644 index 00000000000..72a8d6940c9 --- /dev/null +++ b/aws_lambda_powertools/utilities/idempotency/serialization/functions.py @@ -0,0 +1,57 @@ +import sys +from typing import Any, Optional, Union, get_args, get_origin + +# Conditionally import or define UnionType based on Python version +if sys.version_info >= (3, 10): + from types import UnionType # Available in Python 3.10+ +else: + UnionType = Union # Fallback for Python 3.8 and 3.9 + +from aws_lambda_powertools.utilities.idempotency.exceptions import ( + IdempotencyModelTypeError, +) + + +def get_actual_type(model_type: Any) -> Any: + """ + Extract the actual type from a potentially Optional or Union type. + This function handles types that may be wrapped in Optional or Union, + including the Python 3.10+ Union syntax (Type | None). + Parameters + ---------- + model_type: Any + The type to analyze. Can be a simple type, Optional[Type], BaseModel, dataclass + Returns + ------- + The actual type without Optional or Union wrappers. + Raises: + IdempotencyModelTypeError: If the type specification is invalid + (e.g., Union with multiple non-None types). + """ + + # Get the origin of the type (e.g., Union, Optional) + origin = get_origin(model_type) + + # Check if type is Union, Optional, or UnionType (Python 3.10+) + if origin in (Union, Optional) or (sys.version_info >= (3, 10) and origin in (Union, UnionType)): + # Get type arguments + args = get_args(model_type) + + # Filter out NoneType + actual_type = _extract_non_none_types(args) + + # Ensure only one non-None type exists + if len(actual_type) != 1: + raise IdempotencyModelTypeError( + "Invalid type: expected a single type, optionally wrapped in Optional or Union with None.", + ) + + return actual_type[0] + + # If not a Union/Optional type, return original type + return model_type + + +def _extract_non_none_types(args: tuple) -> list: + """Extract non-None types from type arguments.""" + return [arg for arg in args if arg is not type(None)] diff --git a/aws_lambda_powertools/utilities/idempotency/serialization/pydantic.py b/aws_lambda_powertools/utilities/idempotency/serialization/pydantic.py index 42ae179833f..8ba45a40583 100644 --- a/aws_lambda_powertools/utilities/idempotency/serialization/pydantic.py +++ b/aws_lambda_powertools/utilities/idempotency/serialization/pydantic.py @@ -12,6 +12,7 @@ BaseIdempotencyModelSerializer, BaseIdempotencySerializer, ) +from aws_lambda_powertools.utilities.idempotency.serialization.functions import get_actual_type class PydanticSerializer(BaseIdempotencyModelSerializer): @@ -34,6 +35,9 @@ def from_dict(self, data: dict) -> BaseModel: @classmethod def instantiate(cls, model_type: Any) -> BaseIdempotencySerializer: + + model_type = get_actual_type(model_type=model_type) + if model_type is None: raise IdempotencyNoSerializationModelError("No serialization model was supplied") diff --git a/docs/utilities/idempotency.md b/docs/utilities/idempotency.md index f263aa1cb6e..cfe85877961 100644 --- a/docs/utilities/idempotency.md +++ b/docs/utilities/idempotency.md @@ -212,7 +212,10 @@ By default, `idempotent_function` serializes, stores, and returns your annotated The output serializer supports any JSON serializable data, **Python Dataclasses** and **Pydantic Models**. -!!! info "When using the `output_serializer` parameter, the data will continue to be stored in your persistent storage as a JSON string." +!!! info + When using the `output_serializer` parameter, the data will continue to be stored in your persistent storage as a JSON string. + + Function returns must be annotated with a single type, optionally wrapped in `Optional` or `Union` with `None`. === "Pydantic" diff --git a/tests/functional/idempotency/_boto3/test_idempotency.py b/tests/functional/idempotency/_boto3/test_idempotency.py index 35f82333e9c..f2214e2fd65 100644 --- a/tests/functional/idempotency/_boto3/test_idempotency.py +++ b/tests/functional/idempotency/_boto3/test_idempotency.py @@ -1,7 +1,7 @@ import copy import datetime import warnings -from typing import Any +from typing import Any, Optional from unittest.mock import MagicMock, Mock import jmespath @@ -2014,3 +2014,50 @@ def lambda_handler(event, context): stubber.assert_no_pending_responses() stubber.deactivate() + + +@pytest.mark.parametrize("output_serializer_type", ["explicit", "deduced"]) +def test_idempotent_function_serialization_dataclass_with_optional_return(output_serializer_type: str): + # GIVEN + dataclasses = get_dataclasses_lib() + config = IdempotencyConfig(use_local_cache=True) + mock_event = {"customer_id": "fake", "transaction_id": "fake-id"} + idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_dataclass_with_optional_return..collect_payment#{hash_idempotency_key(mock_event)}" # noqa E501 + persistence_layer = MockPersistenceLayer(expected_idempotency_key=idempotency_key) + + @dataclasses.dataclass + class PaymentInput: + customer_id: str + transaction_id: str + + @dataclasses.dataclass + class PaymentOutput: + customer_id: str + transaction_id: str + + if output_serializer_type == "explicit": + output_serializer = DataclassSerializer( + model=PaymentOutput, + ) + else: + output_serializer = DataclassSerializer + + @idempotent_function( + data_keyword_argument="payment", + persistence_store=persistence_layer, + config=config, + output_serializer=output_serializer, + ) + def collect_payment(payment: PaymentInput) -> Optional[PaymentOutput]: + return PaymentOutput(**dataclasses.asdict(payment)) + + # WHEN + payment = PaymentInput(**mock_event) + first_call: PaymentOutput = collect_payment(payment=payment) + assert first_call.customer_id == payment.customer_id + assert first_call.transaction_id == payment.transaction_id + assert isinstance(first_call, PaymentOutput) + second_call: PaymentOutput = collect_payment(payment=payment) + assert isinstance(second_call, PaymentOutput) + assert second_call.customer_id == payment.customer_id + assert second_call.transaction_id == payment.transaction_id diff --git a/tests/functional/idempotency/_pydantic/test_idempotency_with_pydantic.py b/tests/functional/idempotency/_pydantic/test_idempotency_with_pydantic.py index aaac5948e63..f8e3debbc30 100644 --- a/tests/functional/idempotency/_pydantic/test_idempotency_with_pydantic.py +++ b/tests/functional/idempotency/_pydantic/test_idempotency_with_pydantic.py @@ -1,3 +1,5 @@ +from typing import Optional + import pytest from pydantic import BaseModel @@ -219,3 +221,47 @@ def collect_payment(payment: Payment): # THEN idempotency key assertion happens at MockPersistenceLayer assert result == payment.transaction_id + + +@pytest.mark.parametrize("output_serializer_type", ["explicit", "deduced"]) +def test_idempotent_function_serialization_pydantic_with_optional_return(output_serializer_type: str): + # GIVEN + config = IdempotencyConfig(use_local_cache=True) + mock_event = {"customer_id": "fake", "transaction_id": "fake-id"} + idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_pydantic_with_optional_return..collect_payment#{hash_idempotency_key(mock_event)}" # noqa E501 + persistence_layer = MockPersistenceLayer(expected_idempotency_key=idempotency_key) + + class PaymentInput(BaseModel): + customer_id: str + transaction_id: str + + class PaymentOutput(BaseModel): + customer_id: str + transaction_id: str + + if output_serializer_type == "explicit": + output_serializer = PydanticSerializer( + model=PaymentOutput, + ) + else: + output_serializer = PydanticSerializer + + @idempotent_function( + data_keyword_argument="payment", + persistence_store=persistence_layer, + config=config, + output_serializer=output_serializer, + ) + def collect_payment(payment: PaymentInput) -> Optional[PaymentOutput]: + return PaymentOutput(**payment.dict()) + + # WHEN + payment = PaymentInput(**mock_event) + first_call: PaymentOutput = collect_payment(payment=payment) + assert first_call.customer_id == payment.customer_id + assert first_call.transaction_id == payment.transaction_id + assert isinstance(first_call, PaymentOutput) + second_call: PaymentOutput = collect_payment(payment=payment) + assert isinstance(second_call, PaymentOutput) + assert second_call.customer_id == payment.customer_id + assert second_call.transaction_id == payment.transaction_id