diff --git a/.github/workflows/python_build.yml b/.github/workflows/python_build.yml
index 17aa08ead81..7af7bb8e4ba 100644
--- a/.github/workflows/python_build.yml
+++ b/.github/workflows/python_build.yml
@@ -30,6 +30,8 @@ jobs:
         run: make dev
       - name: Formatting and Linting
         run: make lint
+      - name: Static type checking
+        run: make mypy
       - name: Test with pytest
         run: make test
       - name: Security baseline
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 0941fbc535b..060726ec11a 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -49,13 +49,14 @@ You might find useful to run both the documentation website and the API referenc
 Category | Convention
 ------------------------------------------------- | ---------------------------------------------------------------------------------
-**Docstring** | We use a slight variation of numpy convention with markdown to help generate more readable API references.
-**Style guide** | We use black as well as flake8 extensions to enforce beyond good practices [PEP8](https://pep8.org/). We strive to make use of type annotation as much as possible, but don't overdo in creating custom types.
+**Docstring** | We use a slight variation of Numpy convention with markdown to help generate more readable API references.
+**Style guide** | We use black as well as flake8 extensions to enforce beyond good practices [PEP8](https://pep8.org/). We use type annotations and enforce static type checking at CI (mypy).
 **Core utilities** | Core utilities use a Class, always accept `service` as a constructor parameter, can work in isolation, and are also available in other languages implementation.
 **Utilities** | Utilities are not as strict as core and focus on solving a developer experience problem while following the project [Tenets](https://awslabs.github.io/aws-lambda-powertools-python/#tenets).
 **Exceptions** | Specific exceptions live within utilities themselves and use `Error` suffix e.g. `MetricUnitError`.
-**Git commits** | We follow [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/). These are not enforced as we squash and merge PRs, but PR titles are enforced during CI.
-**Documentation** | API reference docs are generated from docstrings which should have Examples section to allow developers to have what they need within their own IDE. Documentation website covers the wider usage, tips, and strive to be concise.
+**Git commits** | We follow [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/). We do not enforce conventional commits on contributors to lower the entry bar. Instead, we enforce a conventional PR title so our label automation and changelog are generated correctly.
+**API documentation** | API reference docs are generated from docstrings which should have Examples section to allow developers to have what they need within their own IDE. Documentation website covers the wider usage, tips, and strive to be concise.
+**Documentation** | We treat it like a product. We sub-divide content aimed at getting started (80% of customers) vs advanced usage (20%). We also ensure customers know how to unit test their code when using our features.
 
 ## Finding contributions to work on
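As context for the new `make mypy` CI step and the updated style-guide row above, here is a minimal, hypothetical sketch (not part of the patch; function and field names are made up) of the class of defect that a blocking mypy check is meant to catch before merge:

    from typing import Optional

    def build_greeting(name: str) -> str:
        return f"hello {name}"

    def handler(event: dict, context: object) -> str:
        user: Optional[str] = event.get("user")  # may be missing from the event
        # mypy reports: Argument 1 to "build_greeting" has incompatible type
        # "Optional[str]"; expected "str" -- forcing the None case to be handled
        return build_greeting(user)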
diff --git a/Makefile b/Makefile
index 5b8e9b0d689..fc350ac6923 100644
--- a/Makefile
+++ b/Makefile
@@ -29,7 +29,7 @@ coverage-html:
 pre-commit:
 	pre-commit run --show-diff-on-failure
 
-pr: lint pre-commit test security-baseline complexity-baseline
+pr: lint mypy pre-commit test security-baseline complexity-baseline
 
 build: pr
 	poetry build
diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py
index 5017597c0f1..f423f1291fa 100644
--- a/aws_lambda_powertools/event_handler/api_gateway.py
+++ b/aws_lambda_powertools/event_handler/api_gateway.py
@@ -10,7 +10,7 @@
 from enum import Enum
 from functools import partial
 from http import HTTPStatus
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Match, Optional, Pattern, Set, Tuple, Type, Union
 
 from aws_lambda_powertools.event_handler import content_types
 from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
@@ -167,7 +167,7 @@ class Route:
     """Internally used Route Configuration"""
 
     def __init__(
-        self, method: str, rule: Any, func: Callable, cors: bool, compress: bool, cache_control: Optional[str]
+        self, method: str, rule: Pattern, func: Callable, cors: bool, compress: bool, cache_control: Optional[str]
     ):
         self.method = method.upper()
         self.rule = rule
@@ -555,7 +555,7 @@ def _resolve(self) -> ResponseBuilder:
         for route in self._routes:
             if method != route.method:
                 continue
-            match_results: Optional[re.Match] = route.rule.match(path)
+            match_results: Optional[Match] = route.rule.match(path)
             if match_results:
                 logger.debug("Found a registered route. Calling function")
                 return self._call_route(route, match_results.groupdict())  # pass fn args
diff --git a/aws_lambda_powertools/metrics/metrics.py b/aws_lambda_powertools/metrics/metrics.py
index 927b6873648..00e083d4a7f 100644
--- a/aws_lambda_powertools/metrics/metrics.py
+++ b/aws_lambda_powertools/metrics/metrics.py
@@ -53,7 +53,7 @@ def lambda_handler():
     ----------
     service : str, optional
         service name to be used as metric dimension, by default "service_undefined"
-    namespace : str
+    namespace : str, optional
        Namespace for metrics
 
     Raises
@@ -209,5 +209,6 @@ def __add_cold_start_metric(self, context: Any) -> None:
             logger.debug("Adding cold start metric and function_name dimension")
             with single_metric(name="ColdStart", unit=MetricUnit.Count, value=1, namespace=self.namespace) as metric:
                 metric.add_dimension(name="function_name", value=context.function_name)
-                metric.add_dimension(name="service", value=self.service)
+                if self.service:
+                    metric.add_dimension(name="service", value=str(self.service))
                 is_cold_start = False
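The cold-start change above guards the `service` dimension because the attribute is optional. A minimal standalone sketch (class and method names are illustrative, not the library's) of why the guard satisfies both mypy and the metric format: the dimension is only emitted when a service name actually exists, and the truthiness check narrows `Optional[str]` to `str`:

    from typing import Dict, Optional

    class ColdStartExample:
        def __init__(self, service: Optional[str] = None) -> None:
            self.service = service  # may legitimately be None or empty

        def dimensions(self) -> Dict[str, str]:
            dims = {"function_name": "example_fn"}
            if self.service:  # narrows Optional[str] for mypy, skips empty values
                dims["service"] = str(self.service)
            return dims

    print(ColdStartExample().dimensions())            # {'function_name': 'example_fn'}
    print(ColdStartExample("payment").dimensions())   # includes the service dimension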
diff --git a/aws_lambda_powertools/shared/functions.py b/aws_lambda_powertools/shared/functions.py
index 11c4e4ce77c..37621f8274a 100644
--- a/aws_lambda_powertools/shared/functions.py
+++ b/aws_lambda_powertools/shared/functions.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Optional, Union
 
 
 def strtobool(value: str) -> bool:
@@ -38,21 +38,23 @@ def resolve_truthy_env_var_choice(env: str, choice: Optional[bool] = None) -> bo
     return choice if choice is not None else strtobool(env)
 
 
-def resolve_env_var_choice(env: Any, choice: Optional[Any] = None) -> Union[bool, Any]:
+def resolve_env_var_choice(
+    env: Optional[str] = None, choice: Optional[Union[str, float]] = None
+) -> Optional[Union[str, float]]:
     """Pick explicit choice over env, if available, otherwise return env value received
 
     NOTE: Environment variable should be resolved by the caller.
 
     Parameters
     ----------
-    env : Any
+    env : str, Optional
         environment variable actual value
-    choice : bool
+    choice : str|float, optional
         explicit choice
 
     Returns
     -------
-    choice : str
+    choice : str, Optional
         resolved choice as either bool or environment value
     """
     return choice if choice is not None else env
diff --git a/aws_lambda_powertools/utilities/batch/exceptions.py b/aws_lambda_powertools/utilities/batch/exceptions.py
index dc4ca300c7c..d90c25f12bc 100644
--- a/aws_lambda_powertools/utilities/batch/exceptions.py
+++ b/aws_lambda_powertools/utilities/batch/exceptions.py
@@ -5,7 +5,7 @@
 from types import TracebackType
 from typing import List, Optional, Tuple, Type
 
-ExceptionInfo = Tuple[Type[BaseException], BaseException, TracebackType]
+ExceptionInfo = Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]
 
 
 class BaseBatchProcessingError(Exception):
diff --git a/aws_lambda_powertools/utilities/batch/sqs.py b/aws_lambda_powertools/utilities/batch/sqs.py
index 38773a399dd..411e400615d 100644
--- a/aws_lambda_powertools/utilities/batch/sqs.py
+++ b/aws_lambda_powertools/utilities/batch/sqs.py
@@ -5,11 +5,13 @@
 """
 import logging
 import sys
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, cast
 
 import boto3
 from botocore.config import Config
 
+from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
+
 from ...middleware_factory import lambda_handler_decorator
 from .base import BasePartialProcessor
 from .exceptions import SQSBatchProcessingError
@@ -84,11 +86,17 @@ def _get_queue_url(self) -> Optional[str]:
         *_, account_id, queue_name = self.records[0]["eventSourceARN"].split(":")
         return f"{self.client._endpoint.host}/{account_id}/{queue_name}"
 
-    def _get_entries_to_clean(self) -> List:
+    def _get_entries_to_clean(self) -> List[Dict[str, str]]:
         """
         Format messages to use in batch deletion
         """
-        return [{"Id": msg["messageId"], "ReceiptHandle": msg["receiptHandle"]} for msg in self.success_messages]
+        entries = []
+        # success_messages has generic type of union of SQS, Dynamodb and Kinesis Streams records or Pydantic models.
+        # Here we get SQS Record only
+        messages = cast(List[SQSRecord], self.success_messages)
+        for msg in messages:
+            entries.append({"Id": msg["messageId"], "ReceiptHandle": msg["receiptHandle"]})
+        return entries
 
     def _process_record(self, record) -> Tuple:
         """
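For context on the `ExceptionInfo` widening above: `sys.exc_info()` is typed as returning either a populated triple or `(None, None, None)`, so an alias without `Optional` members cannot describe it. A small self-contained sketch, not part of the patch:

    import sys
    from types import TracebackType
    from typing import Optional, Tuple, Type

    ExceptionInfo = Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]

    def capture() -> ExceptionInfo:
        try:
            raise ValueError("boom")
        except ValueError:
            # inside an except block the triple is populated; elsewhere it is (None, None, None)
            return sys.exc_info()

    print(capture()[0])  # <class 'ValueError'>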
diff --git a/aws_lambda_powertools/utilities/idempotency/idempotency.py b/aws_lambda_powertools/utilities/idempotency/idempotency.py
index 42b8052fd32..4a7d8e71e1d 100644
--- a/aws_lambda_powertools/utilities/idempotency/idempotency.py
+++ b/aws_lambda_powertools/utilities/idempotency/idempotency.py
@@ -112,7 +112,7 @@ def process_order(customer_id: str, order: dict, **kwargs):
         return {"StatusCode": 200}
     """
 
-    if function is None:
+    if not function:
         return cast(
             AnyCallableT,
             functools.partial(
@@ -132,7 +132,7 @@ def decorate(*args, **kwargs):
 
         payload = kwargs.get(data_keyword_argument)
 
-        if payload is None:
+        if not payload:
             raise RuntimeError(
                 f"Unable to extract '{data_keyword_argument}' from keyword arguments."
                 f" Ensure this exists in your function's signature as well as the caller used it as a keyword argument"
diff --git a/aws_lambda_powertools/utilities/idempotency/persistence/base.py b/aws_lambda_powertools/utilities/idempotency/persistence/base.py
index b07662e6432..e6ffea10de8 100644
--- a/aws_lambda_powertools/utilities/idempotency/persistence/base.py
+++ b/aws_lambda_powertools/utilities/idempotency/persistence/base.py
@@ -92,16 +92,16 @@ def status(self) -> str:
         else:
             raise IdempotencyInvalidStatusError(self._status)
 
-    def response_json_as_dict(self) -> dict:
+    def response_json_as_dict(self) -> Optional[dict]:
         """
         Get response data deserialized to python dict
 
         Returns
         -------
-        dict
+        Optional[dict]
             previous response data deserialized
         """
-        return json.loads(self.response_data)
+        return json.loads(self.response_data) if self.response_data else None
 
 
 class BasePersistenceLayer(ABC):
@@ -121,7 +121,6 @@ def __init__(self):
         self.raise_on_no_idempotency_key = False
         self.expires_after_seconds: int = 60 * 60  # 1 hour default
         self.use_local_cache = False
-        self._cache: Optional[LRUDict] = None
         self.hash_function = None
 
     def configure(self, config: IdempotencyConfig, function_name: Optional[str] = None) -> None:
diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py
index b059a3b2483..7e8588eb895 100644
--- a/aws_lambda_powertools/utilities/parameters/base.py
+++ b/aws_lambda_powertools/utilities/parameters/base.py
@@ -44,7 +44,7 @@ def get(
         transform: Optional[str] = None,
         force_fetch: bool = False,
         **sdk_options,
-    ) -> Union[str, list, dict, bytes]:
+    ) -> Optional[Union[str, dict, bytes]]:
         """
         Retrieve a parameter value or return the cached value
 
@@ -81,6 +81,7 @@ def get(
         # of supported transform is small and the probability that a given
         # parameter will always be used in a specific transform, this should be
         # an acceptable tradeoff.
+        value: Optional[Union[str, bytes, dict]] = None
         key = (name, transform)
 
         if not force_fetch and self._has_not_expired(key):
@@ -92,7 +93,7 @@ def get(
         except Exception as exc:
             raise GetParameterError(str(exc))
 
-        if transform is not None:
+        if transform:
             if isinstance(value, bytes):
                 value = value.decode("utf-8")
 
             value = transform_value(value, transform)
@@ -146,26 +147,25 @@ def get_multiple(
         TransformParameterError
             When the parameter provider fails to transform a parameter value.
         """
-
         key = (path, transform)
 
         if not force_fetch and self._has_not_expired(key):
             return self.store[key].value
 
         try:
-            values: Dict[str, Union[str, bytes, dict, None]] = self._get_multiple(path, **sdk_options)
+            values = self._get_multiple(path, **sdk_options)
         # Encapsulate all errors into a generic GetParameterError
         except Exception as exc:
             raise GetParameterError(str(exc))
 
-        if transform is not None:
-            for (key, value) in values.items():
-                _transform = get_transform_method(key, transform)
-                if _transform is None:
+        if transform:
+            transformed_values: dict = {}
+            for (item, value) in values.items():
+                _transform = get_transform_method(item, transform)
+                if not _transform:
                     continue
-
-                values[key] = transform_value(value, _transform, raise_on_transform_error)
-
+                transformed_values[item] = transform_value(value, _transform, raise_on_transform_error)
+            values.update(transformed_values)
         self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age))
 
         return values
@@ -217,7 +217,9 @@ def get_transform_method(key: str, transform: Optional[str] = None) -> Optional[
     return None
 
 
-def transform_value(value: str, transform: str, raise_on_transform_error: bool = True) -> Union[dict, bytes, None]:
+def transform_value(
+    value: str, transform: str, raise_on_transform_error: Optional[bool] = True
+) -> Optional[Union[dict, bytes]]:
     """
     Apply a transform to a value
diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py
index 4cbb16354c7..fd55e40a95f 100644
--- a/aws_lambda_powertools/utilities/parameters/ssm.py
+++ b/aws_lambda_powertools/utilities/parameters/ssm.py
@@ -87,7 +87,9 @@ def __init__(self, config: Optional[Config] = None, boto3_session: Optional[boto
 
         super().__init__()
 
-    def get(
+    # We break Liskov substitution principle due to differences in signatures of this method and superclass get method
+    # We ignore mypy error, as changes to the signature here or in a superclass is a breaking change to users
+    def get(  # type: ignore[override]
         self,
         name: str,
         max_age: int = DEFAULT_MAX_AGE_SECS,
@@ -95,7 +97,7 @@ def get(
         decrypt: bool = False,
         force_fetch: bool = False,
         **sdk_options
-    ) -> Union[str, list, dict, bytes]:
+    ) -> Optional[Union[str, dict, bytes]]:
         """
         Retrieve a parameter value or return the cached value
diff --git a/aws_lambda_powertools/utilities/parser/envelopes/apigw.py b/aws_lambda_powertools/utilities/parser/envelopes/apigw.py
index 6b74a3037e9..a9af93e9b9c 100644
--- a/aws_lambda_powertools/utilities/parser/envelopes/apigw.py
+++ b/aws_lambda_powertools/utilities/parser/envelopes/apigw.py
@@ -27,6 +27,6 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
             Parsed detail payload with model provided
         """
         logger.debug(f"Parsing incoming data with Api Gateway model {APIGatewayProxyEventModel}")
-        parsed_envelope = APIGatewayProxyEventModel.parse_obj(data)
+        parsed_envelope: APIGatewayProxyEventModel = APIGatewayProxyEventModel.parse_obj(data)
         logger.debug(f"Parsing event payload in `detail` with {model}")
         return self._parse(data=parsed_envelope.body, model=model)
diff --git a/aws_lambda_powertools/utilities/parser/envelopes/apigwv2.py b/aws_lambda_powertools/utilities/parser/envelopes/apigwv2.py
index a627e4da0e5..336645a2b73 100644
--- a/aws_lambda_powertools/utilities/parser/envelopes/apigwv2.py
+++ b/aws_lambda_powertools/utilities/parser/envelopes/apigwv2.py
@@ -27,6 +27,6 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
             Parsed detail payload with model provided
         """
         logger.debug(f"Parsing incoming data with Api Gateway model V2 {APIGatewayProxyEventV2Model}")
-        parsed_envelope = APIGatewayProxyEventV2Model.parse_obj(data)
+        parsed_envelope: APIGatewayProxyEventV2Model = APIGatewayProxyEventV2Model.parse_obj(data)
         logger.debug(f"Parsing event payload in `detail` with {model}")
         return self._parse(data=parsed_envelope.body, model=model)
diff --git a/aws_lambda_powertools/utilities/parser/envelopes/event_bridge.py b/aws_lambda_powertools/utilities/parser/envelopes/event_bridge.py
index ad1df09a65e..239bfd72025 100644
--- a/aws_lambda_powertools/utilities/parser/envelopes/event_bridge.py
+++ b/aws_lambda_powertools/utilities/parser/envelopes/event_bridge.py
@@ -27,6 +27,6 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
             Parsed detail payload with model provided
         """
         logger.debug(f"Parsing incoming data with EventBridge model {EventBridgeModel}")
-        parsed_envelope = EventBridgeModel.parse_obj(data)
+        parsed_envelope: EventBridgeModel = EventBridgeModel.parse_obj(data)
         logger.debug(f"Parsing event payload in `detail` with {model}")
         return self._parse(data=parsed_envelope.detail, model=model)
diff --git a/aws_lambda_powertools/utilities/parser/envelopes/kinesis.py b/aws_lambda_powertools/utilities/parser/envelopes/kinesis.py
index 9db8e4450f2..9ff221a7b7b 100644
--- a/aws_lambda_powertools/utilities/parser/envelopes/kinesis.py
+++ b/aws_lambda_powertools/utilities/parser/envelopes/kinesis.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Type, Union, cast
 
 from ..models import KinesisDataStreamModel
 from ..types import Model
@@ -37,6 +37,9 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
         logger.debug(f"Parsing incoming data with Kinesis model {KinesisDataStreamModel}")
         parsed_envelope: KinesisDataStreamModel = KinesisDataStreamModel.parse_obj(data)
         logger.debug(f"Parsing Kinesis records in `body` with {model}")
-        return [
-            self._parse(data=record.kinesis.data.decode("utf-8"), model=model) for record in parsed_envelope.Records
-        ]
+        models = []
+        for record in parsed_envelope.Records:
+            # We allow either AWS expected contract (bytes) or a custom Model, see #943
+            data = cast(bytes, record.kinesis.data)
+            models.append(self._parse(data=data.decode("utf-8"), model=model))
+        return models
diff --git a/aws_lambda_powertools/utilities/parser/envelopes/sns.py b/aws_lambda_powertools/utilities/parser/envelopes/sns.py
index d4fa7c2f663..50b9d406c23 100644
--- a/aws_lambda_powertools/utilities/parser/envelopes/sns.py
+++ b/aws_lambda_powertools/utilities/parser/envelopes/sns.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Type, Union, cast
 
 from ..models import SnsModel, SnsNotificationModel, SqsModel
 from ..types import Model
@@ -69,6 +69,8 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
         parsed_envelope = SqsModel.parse_obj(data)
         output = []
         for record in parsed_envelope.Records:
-            sns_notification = SnsNotificationModel.parse_raw(record.body)
+            # We allow either AWS expected contract (str) or a custom Model, see #943
+            body = cast(str, record.body)
+            sns_notification = SnsNotificationModel.parse_raw(body)
             output.append(self._parse(data=sns_notification.Message, model=model))
         return output
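The `cast()` calls added in the envelopes above only inform the type checker; they perform no runtime check or conversion. A minimal standalone sketch of the same pattern (names are illustrative, not from the library):

    from typing import Union, cast

    def decode_payload(data: Union[bytes, object]) -> str:
        # At this call site we rely on the AWS contract (bytes); cast() narrows
        # the type for mypy without touching the value at runtime.
        raw = cast(bytes, data)
        return raw.decode("utf-8")

    print(decode_payload(b"hello"))  # hello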
diff --git a/aws_lambda_powertools/utilities/parser/models/alb.py b/aws_lambda_powertools/utilities/parser/models/alb.py
index 1112d0c04e4..d903e9f0fd8 100644
--- a/aws_lambda_powertools/utilities/parser/models/alb.py
+++ b/aws_lambda_powertools/utilities/parser/models/alb.py
@@ -1,9 +1,7 @@
-from typing import Dict, Union
+from typing import Dict, Type, Union
 
 from pydantic import BaseModel
 
-from aws_lambda_powertools.utilities.parser.types import Model
-
 
 class AlbRequestContextData(BaseModel):
     targetGroupArn: str
@@ -16,7 +14,7 @@ class AlbRequestContext(BaseModel):
 class AlbModel(BaseModel):
     httpMethod: str
     path: str
-    body: Union[str, Model]
+    body: Union[str, Type[BaseModel]]
     isBase64Encoded: bool
     headers: Dict[str, str]
     queryStringParameters: Dict[str, str]
diff --git a/aws_lambda_powertools/utilities/parser/models/apigw.py b/aws_lambda_powertools/utilities/parser/models/apigw.py
index ce519b8e0e3..78b40cd2c0c 100644
--- a/aws_lambda_powertools/utilities/parser/models/apigw.py
+++ b/aws_lambda_powertools/utilities/parser/models/apigw.py
@@ -1,10 +1,10 @@
 from datetime import datetime
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Type, Union
 
 from pydantic import BaseModel, root_validator
 from pydantic.networks import IPvAnyNetwork
 
-from aws_lambda_powertools.utilities.parser.types import Literal, Model
+from aws_lambda_powertools.utilities.parser.types import Literal
 
 
 class ApiGatewayUserCertValidity(BaseModel):
@@ -89,4 +89,4 @@ class APIGatewayProxyEventModel(BaseModel):
     pathParameters: Optional[Dict[str, str]]
     stageVariables: Optional[Dict[str, str]]
     isBase64Encoded: bool
-    body: Optional[Union[str, Model]]
+    body: Optional[Union[str, Type[BaseModel]]]
diff --git a/aws_lambda_powertools/utilities/parser/models/apigwv2.py b/aws_lambda_powertools/utilities/parser/models/apigwv2.py
index ddaf2d7ef82..f97dad3bcb0 100644
--- a/aws_lambda_powertools/utilities/parser/models/apigwv2.py
+++ b/aws_lambda_powertools/utilities/parser/models/apigwv2.py
@@ -1,10 +1,10 @@
 from datetime import datetime
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Type, Union
 
 from pydantic import BaseModel, Field
 from pydantic.networks import IPvAnyNetwork
 
-from aws_lambda_powertools.utilities.parser.types import Literal, Model
+from aws_lambda_powertools.utilities.parser.types import Literal
 
 
 class RequestContextV2AuthorizerIamCognito(BaseModel):
@@ -67,5 +67,5 @@ class APIGatewayProxyEventV2Model(BaseModel):
     pathParameters: Optional[Dict[str, str]]
     stageVariables: Optional[Dict[str, str]]
     requestContext: RequestContextV2
-    body: Optional[Union[str, Model]]
+    body: Optional[Union[str, Type[BaseModel]]]
     isBase64Encoded: bool
diff --git a/aws_lambda_powertools/utilities/parser/models/cloudwatch.py b/aws_lambda_powertools/utilities/parser/models/cloudwatch.py
index 9b954ec3b13..71e560276a4 100644
--- a/aws_lambda_powertools/utilities/parser/models/cloudwatch.py
+++ b/aws_lambda_powertools/utilities/parser/models/cloudwatch.py
@@ -3,19 +3,17 @@
 import logging
 import zlib
 from datetime import datetime
-from typing import List, Union
+from typing import List, Type, Union
 
 from pydantic import BaseModel, Field, validator
 
-from aws_lambda_powertools.utilities.parser.types import Model
-
 logger = logging.getLogger(__name__)
 
 
 class CloudWatchLogsLogEvent(BaseModel):
     id: str  # noqa AA03 VNE003
     timestamp: datetime
-    message: Union[str, Model]
+    message: Union[str, Type[BaseModel]]
 
 
 class CloudWatchLogsDecode(BaseModel):
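The model fields changed above keep accepting either the raw AWS payload (`str`) or a user-supplied pydantic model, now spelled `Union[str, Type[BaseModel]]` instead of the bound `Model` TypeVar. The usual consumer pattern is to subclass an event model and narrow the field; a hedged sketch (the `Order` and `EventModel` classes are hypothetical, not the library's models):

    from typing import Optional, Type, Union

    from pydantic import BaseModel

    class Order(BaseModel):
        order_id: int

    class EventModel(BaseModel):
        # same shape as the fields changed above: raw payload string or a parsed model
        body: Optional[Union[str, Type[BaseModel]]]

    class OrderEventModel(EventModel):
        body: Order  # subclasses narrow the field to a concrete model

    event = OrderEventModel(body={"order_id": 123})
    print(event.body.order_id)  # 123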
diff --git a/aws_lambda_powertools/utilities/parser/models/dynamodb.py b/aws_lambda_powertools/utilities/parser/models/dynamodb.py
index fe7514bada0..772b8fb580f 100644
--- a/aws_lambda_powertools/utilities/parser/models/dynamodb.py
+++ b/aws_lambda_powertools/utilities/parser/models/dynamodb.py
@@ -1,16 +1,16 @@
 from datetime import date
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Type, Union
 
 from pydantic import BaseModel
 
-from aws_lambda_powertools.utilities.parser.types import Literal, Model
+from aws_lambda_powertools.utilities.parser.types import Literal
 
 
 class DynamoDBStreamChangedRecordModel(BaseModel):
     ApproximateCreationDateTime: Optional[date]
     Keys: Dict[str, Dict[str, Any]]
-    NewImage: Optional[Union[Dict[str, Any], Model]]
-    OldImage: Optional[Union[Dict[str, Any], Model]]
+    NewImage: Optional[Union[Dict[str, Any], Type[BaseModel]]]
+    OldImage: Optional[Union[Dict[str, Any], Type[BaseModel]]]
     SequenceNumber: str
     SizeBytes: int
     StreamViewType: Literal["NEW_AND_OLD_IMAGES", "KEYS_ONLY", "NEW_IMAGE", "OLD_IMAGE"]
diff --git a/aws_lambda_powertools/utilities/parser/models/event_bridge.py b/aws_lambda_powertools/utilities/parser/models/event_bridge.py
index f98a263c680..68359f867bd 100644
--- a/aws_lambda_powertools/utilities/parser/models/event_bridge.py
+++ b/aws_lambda_powertools/utilities/parser/models/event_bridge.py
@@ -1,10 +1,8 @@
 from datetime import datetime
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Type, Union
 
 from pydantic import BaseModel, Field
 
-from aws_lambda_powertools.utilities.parser.types import Model
-
 
 class EventBridgeModel(BaseModel):
     version: str
@@ -15,5 +13,5 @@ class EventBridgeModel(BaseModel):
     region: str
     resources: List[str]
     detail_type: str = Field(None, alias="detail-type")
-    detail: Union[Dict[str, Any], Model]
+    detail: Union[Dict[str, Any], Type[BaseModel]]
     replay_name: Optional[str] = Field(None, alias="replay-name")
diff --git a/aws_lambda_powertools/utilities/parser/models/kinesis.py b/aws_lambda_powertools/utilities/parser/models/kinesis.py
index 1c7c31c97b4..be868ca44ba 100644
--- a/aws_lambda_powertools/utilities/parser/models/kinesis.py
+++ b/aws_lambda_powertools/utilities/parser/models/kinesis.py
@@ -1,11 +1,11 @@
 import base64
 import logging
 from binascii import Error as BinAsciiError
-from typing import List, Union
+from typing import List, Type, Union
 
 from pydantic import BaseModel, validator
 
-from aws_lambda_powertools.utilities.parser.types import Literal, Model
+from aws_lambda_powertools.utilities.parser.types import Literal
 
 logger = logging.getLogger(__name__)
 
@@ -14,7 +14,7 @@ class KinesisDataStreamRecordPayload(BaseModel):
     kinesisSchemaVersion: str
     partitionKey: str
     sequenceNumber: str
-    data: Union[bytes, Model]  # base64 encoded str is parsed into bytes
+    data: Union[bytes, Type[BaseModel]]  # base64 encoded str is parsed into bytes
     approximateArrivalTimestamp: float
 
     @validator("data", pre=True, allow_reuse=True)
diff --git a/aws_lambda_powertools/utilities/parser/models/s3_object_event.py b/aws_lambda_powertools/utilities/parser/models/s3_object_event.py
index 778786bc8cb..ef59e9c2f98 100644
--- a/aws_lambda_powertools/utilities/parser/models/s3_object_event.py
+++ b/aws_lambda_powertools/utilities/parser/models/s3_object_event.py
@@ -1,9 +1,7 @@
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Type, Union
 
 from pydantic import BaseModel, HttpUrl
 
-from aws_lambda_powertools.utilities.parser.types import Model
-
 
 class S3ObjectContext(BaseModel):
     inputS3Url: HttpUrl
@@ -14,7 +12,7 @@ class S3ObjectContext(BaseModel):
 class S3ObjectConfiguration(BaseModel):
     accessPointArn: str
     supportingAccessPointArn: str
-    payload: Union[str, Model]
+    payload: Union[str, Type[BaseModel]]
 
 
 class S3ObjectUserRequest(BaseModel):
diff --git a/aws_lambda_powertools/utilities/parser/models/sns.py b/aws_lambda_powertools/utilities/parser/models/sns.py
index cdcd9549a98..e329162e5c8 100644
--- a/aws_lambda_powertools/utilities/parser/models/sns.py
+++ b/aws_lambda_powertools/utilities/parser/models/sns.py
@@ -1,10 +1,12 @@
 from datetime import datetime
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional
+from typing import Type as TypingType
+from typing import Union
 
 from pydantic import BaseModel, root_validator
 from pydantic.networks import HttpUrl
 
-from aws_lambda_powertools.utilities.parser.types import Literal, Model
+from aws_lambda_powertools.utilities.parser.types import Literal
 
 
 class SnsMsgAttributeModel(BaseModel):
@@ -18,7 +20,7 @@ class SnsNotificationModel(BaseModel):
     UnsubscribeUrl: HttpUrl
     Type: Literal["Notification"]
     MessageAttributes: Optional[Dict[str, SnsMsgAttributeModel]]
-    Message: Union[str, Model]
+    Message: Union[str, TypingType[BaseModel]]
     MessageId: str
     SigningCertUrl: HttpUrl
     Signature: str
diff --git a/aws_lambda_powertools/utilities/parser/models/sqs.py b/aws_lambda_powertools/utilities/parser/models/sqs.py
index 47871ab8840..1d56c4f8e34 100644
--- a/aws_lambda_powertools/utilities/parser/models/sqs.py
+++ b/aws_lambda_powertools/utilities/parser/models/sqs.py
@@ -1,9 +1,9 @@
 from datetime import datetime
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Type, Union
 
 from pydantic import BaseModel
 
-from aws_lambda_powertools.utilities.parser.types import Literal, Model
+from aws_lambda_powertools.utilities.parser.types import Literal
 
 
 class SqsAttributesModel(BaseModel):
@@ -52,7 +52,7 @@ class SqsMsgAttributeModel(BaseModel):
 class SqsRecordModel(BaseModel):
     messageId: str
     receiptHandle: str
-    body: Union[str, Model]
+    body: Union[str, Type[BaseModel]]
     attributes: SqsAttributesModel
     messageAttributes: Dict[str, SqsMsgAttributeModel]
     md5OfBody: str
diff --git a/mypy.ini b/mypy.ini
index faf6014a54d..3061cc4a2d9 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -31,3 +31,6 @@ ignore_missing_imports = True
 
 [mypy-aws_xray_sdk.ext.aiohttp.client]
 ignore_missing_imports = True
+
+[mypy-dataclasses]
+ignore_missing_imports = True
\ No newline at end of file
diff --git a/tests/functional/idempotency/test_idempotency.py b/tests/functional/idempotency/test_idempotency.py
index 0732f1d58b1..5b76cda0475 100644
--- a/tests/functional/idempotency/test_idempotency.py
+++ b/tests/functional/idempotency/test_idempotency.py
@@ -648,6 +648,29 @@ def test_data_record_invalid_status_value():
     assert e.value.args[0] == "UNSUPPORTED_STATUS"
 
 
+def test_data_record_json_to_dict_mapping():
+    # GIVEN a data record with status "INPROGRESS" and provided response data
+    data_record = DataRecord(
+        "key", status="INPROGRESS", response_data='{"body": "execution finished","statusCode": "200"}'
+    )
+
+    # WHEN translating response data to dictionary
+    response_data = data_record.response_json_as_dict()
+
+    # THEN return dictionary
+    assert isinstance(response_data, dict)
+
+
+def test_data_record_json_to_dict_mapping_when_response_data_none():
+    # GIVEN a data record with status "INPROGRESS" and not set response data
+    data_record = DataRecord("key", status="INPROGRESS", response_data=None)
+    # WHEN translating response data to dictionary
+    response_data = data_record.response_json_as_dict()
+
+    # THEN return null value
+    assert response_data is None
+
+
 @pytest.mark.parametrize("idempotency_config", [{"use_local_cache": True}], indirect=True)
 def test_in_progress_never_saved_to_cache(
     idempotency_config: IdempotencyConfig, persistence_store: DynamoDBPersistenceLayer
diff --git a/tests/functional/test_metrics.py b/tests/functional/test_metrics.py
index ae160c65d87..9d03a25e8b6 100644
--- a/tests/functional/test_metrics.py
+++ b/tests/functional/test_metrics.py
@@ -493,9 +493,29 @@ def lambda_handler(evt, context):
 
     output = capture_metrics_output(capsys)
 
+    # THEN ColdStart metric and function_name and service dimension should be logged
+    assert output["ColdStart"] == [1.0]
+    assert output["function_name"] == "example_fn"
+    assert output['service'] == service
+
+def test_log_metrics_capture_cold_start_metric_no_service(capsys, namespace):
+    # GIVEN Metrics is initialized without service
+    my_metrics = Metrics(namespace=namespace)
+
+    # WHEN log_metrics is used with capture_cold_start_metric
+    @my_metrics.log_metrics(capture_cold_start_metric=True)
+    def lambda_handler(evt, context):
+        pass
+
+    LambdaContext = namedtuple("LambdaContext", "function_name")
+    lambda_handler({}, LambdaContext("example_fn"))
+
+    output = capture_metrics_output(capsys)
+
     # THEN ColdStart metric and function_name dimension should be logged
     assert output["ColdStart"] == [1.0]
     assert output["function_name"] == "example_fn"
+    assert output.get('service') is None
 
 
 def test_emit_cold_start_metric_only_once(capsys, namespace, service, metric):