Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(idempotency): fix output serialization when using Optional type #5590

Merged
merged 7 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
BaseIdempotencyModelSerializer,
BaseIdempotencySerializer,
)
from aws_lambda_powertools.utilities.idempotency.serialization.functions import get_actual_type

DataClass = Any

Expand All @@ -37,6 +38,9 @@ def from_dict(self, data: dict) -> DataClass:

@classmethod
def instantiate(cls, model_type: Any) -> BaseIdempotencySerializer:

model_type = get_actual_type(model_type=model_type)

if model_type is None:
raise IdempotencyNoSerializationModelError("No serialization model was supplied")

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import sys
from typing import Any, Optional, Union, get_args, get_origin

# Conditionally import or define UnionType based on Python version
if sys.version_info >= (3, 10):
from types import UnionType # Available in Python 3.10+
else:
UnionType = Union # Fallback for Python 3.8 and 3.9

from aws_lambda_powertools.utilities.idempotency.exceptions import (
IdempotencyModelTypeError,
)


def get_actual_type(model_type: Any) -> Any:
"""
Extract the actual type from a potentially Optional or Union type.
This function handles types that may be wrapped in Optional or Union,
including the Python 3.10+ Union syntax (Type | None).
Parameters
----------
model_type: Any
The type to analyze. Can be a simple type, Optional[Type], BaseModel, dataclass
Returns
-------
The actual type without Optional or Union wrappers.
Raises:
IdempotencyModelTypeError: If the type specification is invalid
(e.g., Union with multiple non-None types).
"""

# Get the origin of the type (e.g., Union, Optional)
origin = get_origin(model_type)

# Check if type is Union, Optional, or UnionType (Python 3.10+)
if origin in (Union, Optional) or (sys.version_info >= (3, 10) and origin in (Union, UnionType)):
# Get type arguments
args = get_args(model_type)

# Filter out NoneType
actual_type = _extract_non_none_types(args)

# Ensure only one non-None type exists
if len(actual_type) != 1:
raise IdempotencyModelTypeError(

Check warning on line 45 in aws_lambda_powertools/utilities/idempotency/serialization/functions.py

View check run for this annotation

Codecov / codecov/patch

aws_lambda_powertools/utilities/idempotency/serialization/functions.py#L45

Added line #L45 was not covered by tests
"Invalid type: expected a single type, optionally wrapped in Optional or Union with None.",
)

return actual_type[0]

# If not a Union/Optional type, return original type
return model_type


def _extract_non_none_types(args: tuple) -> list:
"""Extract non-None types from type arguments."""
return [arg for arg in args if arg is not type(None)]
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
BaseIdempotencyModelSerializer,
BaseIdempotencySerializer,
)
from aws_lambda_powertools.utilities.idempotency.serialization.functions import get_actual_type


class PydanticSerializer(BaseIdempotencyModelSerializer):
Expand All @@ -34,6 +35,9 @@ def from_dict(self, data: dict) -> BaseModel:

@classmethod
def instantiate(cls, model_type: Any) -> BaseIdempotencySerializer:

model_type = get_actual_type(model_type=model_type)

if model_type is None:
raise IdempotencyNoSerializationModelError("No serialization model was supplied")

Expand Down
5 changes: 4 additions & 1 deletion docs/utilities/idempotency.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,10 @@ By default, `idempotent_function` serializes, stores, and returns your annotated

The output serializer supports any JSON serializable data, **Python Dataclasses** and **Pydantic Models**.

!!! info "When using the `output_serializer` parameter, the data will continue to be stored in your persistent storage as a JSON string."
!!! info
When using the `output_serializer` parameter, the data will continue to be stored in your persistent storage as a JSON string.

The function's return annotation must be a single type, optionally wrapped in `Optional` or in a `Union` with `None`.

=== "Pydantic"

Expand Down
49 changes: 48 additions & 1 deletion tests/functional/idempotency/_boto3/test_idempotency.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import datetime
import warnings
from typing import Any
from typing import Any, Optional
from unittest.mock import MagicMock, Mock

import jmespath
Expand Down Expand Up @@ -2014,3 +2014,50 @@ def lambda_handler(event, context):

stubber.assert_no_pending_responses()
stubber.deactivate()


@pytest.mark.parametrize("output_serializer_type", ["explicit", "deduced"])
def test_idempotent_function_serialization_dataclass_with_optional_return(output_serializer_type: str):
    """Check that a function annotated ``-> Optional[<dataclass>]`` yields dataclass instances.

    The test runs twice: "explicit" passes a ``DataclassSerializer`` instance built with the
    output model, "deduced" passes the ``DataclassSerializer`` class itself (presumably the
    model is then resolved from the return annotation - confirm against ``idempotent_function``).
    """
    # GIVEN an idempotent function whose return annotation wraps a dataclass in Optional
    dataclasses = get_dataclasses_lib()
    config = IdempotencyConfig(use_local_cache=True)
    mock_event = {"customer_id": "fake", "transaction_id": "fake-id"}
    idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_dataclass_with_optional_return.<locals>.collect_payment#{hash_idempotency_key(mock_event)}"  # noqa E501
    persistence_layer = MockPersistenceLayer(expected_idempotency_key=idempotency_key)

    @dataclasses.dataclass
    class PaymentInput:
        customer_id: str
        transaction_id: str

    @dataclasses.dataclass
    class PaymentOutput:
        customer_id: str
        transaction_id: str

    # Either a configured serializer instance or the bare serializer class
    output_serializer = (
        DataclassSerializer(model=PaymentOutput) if output_serializer_type == "explicit" else DataclassSerializer
    )

    @idempotent_function(
        data_keyword_argument="payment",
        persistence_store=persistence_layer,
        config=config,
        output_serializer=output_serializer,
    )
    def collect_payment(payment: PaymentInput) -> Optional[PaymentOutput]:
        payment_fields = dataclasses.asdict(payment)
        return PaymentOutput(**payment_fields)

    # WHEN the function is invoked twice with the same payload
    payment = PaymentInput(**mock_event)

    # THEN the first invocation returns a PaymentOutput mirroring the input
    first_result: PaymentOutput = collect_payment(payment=payment)
    assert first_result.customer_id == payment.customer_id
    assert first_result.transaction_id == payment.transaction_id
    assert isinstance(first_result, PaymentOutput)

    # AND the repeated invocation also returns a PaymentOutput with the same fields
    second_result: PaymentOutput = collect_payment(payment=payment)
    assert isinstance(second_result, PaymentOutput)
    assert second_result.customer_id == payment.customer_id
    assert second_result.transaction_id == payment.transaction_id
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import pytest
from pydantic import BaseModel

Expand Down Expand Up @@ -219,3 +221,47 @@ def collect_payment(payment: Payment):

# THEN idempotency key assertion happens at MockPersistenceLayer
assert result == payment.transaction_id


@pytest.mark.parametrize("output_serializer_type", ["explicit", "deduced"])
def test_idempotent_function_serialization_pydantic_with_optional_return(output_serializer_type: str):
    """Check that a function annotated ``-> Optional[<BaseModel>]`` yields model instances.

    The test runs twice: "explicit" passes a ``PydanticSerializer`` instance built with the
    output model, "deduced" passes the ``PydanticSerializer`` class itself (presumably the
    model is then resolved from the return annotation - confirm against ``idempotent_function``).
    """
    # GIVEN an idempotent function whose return annotation wraps a Pydantic model in Optional
    config = IdempotencyConfig(use_local_cache=True)
    mock_event = {"customer_id": "fake", "transaction_id": "fake-id"}
    idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_pydantic_with_optional_return.<locals>.collect_payment#{hash_idempotency_key(mock_event)}"  # noqa E501
    persistence_layer = MockPersistenceLayer(expected_idempotency_key=idempotency_key)

    class PaymentInput(BaseModel):
        customer_id: str
        transaction_id: str

    class PaymentOutput(BaseModel):
        customer_id: str
        transaction_id: str

    # Either a configured serializer instance or the bare serializer class
    output_serializer = (
        PydanticSerializer(model=PaymentOutput) if output_serializer_type == "explicit" else PydanticSerializer
    )

    @idempotent_function(
        data_keyword_argument="payment",
        persistence_store=persistence_layer,
        config=config,
        output_serializer=output_serializer,
    )
    def collect_payment(payment: PaymentInput) -> Optional[PaymentOutput]:
        payment_fields = payment.dict()
        return PaymentOutput(**payment_fields)

    # WHEN the function is invoked twice with the same payload
    payment = PaymentInput(**mock_event)

    # THEN the first invocation returns a PaymentOutput mirroring the input
    first_result: PaymentOutput = collect_payment(payment=payment)
    assert first_result.customer_id == payment.customer_id
    assert first_result.transaction_id == payment.transaction_id
    assert isinstance(first_result, PaymentOutput)

    # AND the repeated invocation also returns a PaymentOutput with the same fields
    second_result: PaymentOutput = collect_payment(payment=payment)
    assert isinstance(second_result, PaymentOutput)
    assert second_result.customer_id == payment.customer_id
    assert second_result.transaction_id == payment.transaction_id
Loading