Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(mypy): complete mypy support for the entire codebase #943

Merged
merged 12 commits into from
Jan 31, 2022
2 changes: 2 additions & 0 deletions .github/workflows/python_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ jobs:
run: make dev
- name: Formatting and Linting
run: make lint
- name: Typing
heitorlessa marked this conversation as resolved.
Show resolved Hide resolved
run: make mypy
- name: Test with pytest
run: make test
- name: Security baseline
Expand Down
6 changes: 3 additions & 3 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from enum import Enum
from functools import partial
from http import HTTPStatus
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Match, Optional, Pattern, Set, Tuple, Type, Union

from aws_lambda_powertools.event_handler import content_types
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
Expand Down Expand Up @@ -167,7 +167,7 @@ class Route:
"""Internally used Route Configuration"""

def __init__(
self, method: str, rule: Any, func: Callable, cors: bool, compress: bool, cache_control: Optional[str]
self, method: str, rule: Pattern, func: Callable, cors: bool, compress: bool, cache_control: Optional[str]
):
self.method = method.upper()
self.rule = rule
Expand Down Expand Up @@ -555,7 +555,7 @@ def _resolve(self) -> ResponseBuilder:
for route in self._routes:
if method != route.method:
continue
match_results: Optional[re.Match] = route.rule.match(path)
match_results: Optional[Match] = route.rule.match(path)
if match_results:
logger.debug("Found a registered route. Calling function")
return self._call_route(route, match_results.groupdict()) # pass fn args
Expand Down
5 changes: 3 additions & 2 deletions aws_lambda_powertools/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def lambda_handler():
----------
service : str, optional
service name to be used as metric dimension, by default "service_undefined"
namespace : str
namespace : str, optional
Namespace for metrics

Raises
Expand Down Expand Up @@ -209,5 +209,6 @@ def __add_cold_start_metric(self, context: Any) -> None:
logger.debug("Adding cold start metric and function_name dimension")
with single_metric(name="ColdStart", unit=MetricUnit.Count, value=1, namespace=self.namespace) as metric:
metric.add_dimension(name="function_name", value=context.function_name)
metric.add_dimension(name="service", value=self.service)
if self.service:
metric.add_dimension(name="service", value=str(self.service))
is_cold_start = False
12 changes: 7 additions & 5 deletions aws_lambda_powertools/shared/functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Union
from typing import Optional, Union


def strtobool(value: str) -> bool:
Expand Down Expand Up @@ -38,21 +38,23 @@ def resolve_truthy_env_var_choice(env: str, choice: Optional[bool] = None) -> bo
return choice if choice is not None else strtobool(env)


def resolve_env_var_choice(env: Any, choice: Optional[Any] = None) -> Union[bool, Any]:
def resolve_env_var_choice(
env: Optional[str] = None, choice: Optional[Union[str, float]] = None
) -> Optional[Union[str, float]]:
"""Pick explicit choice over env, if available, otherwise return env value received

NOTE: Environment variable should be resolved by the caller.

Parameters
----------
env : Any
env : str, Optional
environment variable actual value
choice : bool
choice : str|float, optional
explicit choice

Returns
-------
choice : str
choice : str, Optional
resolved choice as either bool or environment value
"""
return choice if choice is not None else env
2 changes: 1 addition & 1 deletion aws_lambda_powertools/utilities/batch/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from types import TracebackType
from typing import List, Optional, Tuple, Type

ExceptionInfo = Tuple[Type[BaseException], BaseException, TracebackType]
ExceptionInfo = Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]


class BaseBatchProcessingError(Exception):
Expand Down
14 changes: 11 additions & 3 deletions aws_lambda_powertools/utilities/batch/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
"""
import logging
import sys
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple, cast

import boto3
from botocore.config import Config

from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord

from ...middleware_factory import lambda_handler_decorator
from .base import BasePartialProcessor
from .exceptions import SQSBatchProcessingError
Expand Down Expand Up @@ -84,11 +86,17 @@ def _get_queue_url(self) -> Optional[str]:
*_, account_id, queue_name = self.records[0]["eventSourceARN"].split(":")
return f"{self.client._endpoint.host}/{account_id}/{queue_name}"

def _get_entries_to_clean(self) -> List:
def _get_entries_to_clean(self) -> List[Dict[str, str]]:
"""
Format messages to use in batch deletion
"""
return [{"Id": msg["messageId"], "ReceiptHandle": msg["receiptHandle"]} for msg in self.success_messages]
entries = []
# success_messages has generic type of union of SQS, Dynamodb and Kinesis Streams records or Pydantic models.
# Here we get SQS Record only
messages = cast(List[SQSRecord], self.success_messages)
for msg in messages:
entries.append({"Id": msg["messageId"], "ReceiptHandle": msg["receiptHandle"]})
return entries

def _process_record(self, record) -> Tuple:
"""
Expand Down
4 changes: 2 additions & 2 deletions aws_lambda_powertools/utilities/idempotency/idempotency.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def process_order(customer_id: str, order: dict, **kwargs):
return {"StatusCode": 200}
"""

if function is None:
if not function:
return cast(
AnyCallableT,
functools.partial(
Expand All @@ -132,7 +132,7 @@ def decorate(*args, **kwargs):

payload = kwargs.get(data_keyword_argument)

if payload is None:
if not payload:
raise RuntimeError(
f"Unable to extract '{data_keyword_argument}' from keyword arguments."
f" Ensure this exists in your function's signature as well as the caller used it as a keyword argument"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,16 @@ def status(self) -> str:
else:
raise IdempotencyInvalidStatusError(self._status)

def response_json_as_dict(self) -> dict:
def response_json_as_dict(self) -> Optional[dict]:
"""
Get response data deserialized to python dict

Returns
-------
dict
Optional[dict]
previous response data deserialized
"""
return json.loads(self.response_data)
return json.loads(self.response_data) if self.response_data else None


class BasePersistenceLayer(ABC):
Expand All @@ -121,7 +121,6 @@ def __init__(self):
self.raise_on_no_idempotency_key = False
self.expires_after_seconds: int = 60 * 60 # 1 hour default
self.use_local_cache = False
self._cache: Optional[LRUDict] = None
self.hash_function = None

def configure(self, config: IdempotencyConfig, function_name: Optional[str] = None) -> None:
Expand Down
26 changes: 14 additions & 12 deletions aws_lambda_powertools/utilities/parameters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get(
transform: Optional[str] = None,
force_fetch: bool = False,
**sdk_options,
) -> Union[str, list, dict, bytes]:
) -> Optional[Union[str, dict, bytes]]:
"""
Retrieve a parameter value or return the cached value

Expand Down Expand Up @@ -81,6 +81,7 @@ def get(
# of supported transform is small and the probability that a given
# parameter will always be used in a specific transform, this should be
# an acceptable tradeoff.
value: Optional[Union[str, bytes, dict]] = None
key = (name, transform)

if not force_fetch and self._has_not_expired(key):
Expand All @@ -92,7 +93,7 @@ def get(
except Exception as exc:
raise GetParameterError(str(exc))

if transform is not None:
if transform:
if isinstance(value, bytes):
value = value.decode("utf-8")
value = transform_value(value, transform)
Expand Down Expand Up @@ -146,26 +147,25 @@ def get_multiple(
TransformParameterError
When the parameter provider fails to transform a parameter value.
"""

key = (path, transform)

if not force_fetch and self._has_not_expired(key):
return self.store[key].value

try:
values: Dict[str, Union[str, bytes, dict, None]] = self._get_multiple(path, **sdk_options)
values = self._get_multiple(path, **sdk_options)
# Encapsulate all errors into a generic GetParameterError
except Exception as exc:
raise GetParameterError(str(exc))

if transform is not None:
for (key, value) in values.items():
_transform = get_transform_method(key, transform)
if _transform is None:
if transform:
transformed_values: dict = {}
for (item, value) in values.items():
_transform = get_transform_method(item, transform)
if not _transform:
continue

values[key] = transform_value(value, _transform, raise_on_transform_error)

transformed_values[item] = transform_value(value, _transform, raise_on_transform_error)
values.update(transformed_values)
self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age))

return values
Expand Down Expand Up @@ -217,7 +217,9 @@ def get_transform_method(key: str, transform: Optional[str] = None) -> Optional[
return None


def transform_value(value: str, transform: str, raise_on_transform_error: bool = True) -> Union[dict, bytes, None]:
def transform_value(
value: str, transform: str, raise_on_transform_error: Optional[bool] = True
) -> Optional[Union[dict, bytes]]:
"""
Apply a transform to a value

Expand Down
6 changes: 4 additions & 2 deletions aws_lambda_powertools/utilities/parameters/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,17 @@ def __init__(self, config: Optional[Config] = None, boto3_session: Optional[boto

super().__init__()

def get(
# We break Liskov substitution principle due to differences in signatures of this method and superclass get method
# We ignore mypy error, as changes to the signature here or in a superclass is a breaking change to users
def get( # type: ignore[override]
heitorlessa marked this conversation as resolved.
Show resolved Hide resolved
self,
name: str,
max_age: int = DEFAULT_MAX_AGE_SECS,
transform: Optional[str] = None,
decrypt: bool = False,
force_fetch: bool = False,
**sdk_options
) -> Union[str, list, dict, bytes]:
) -> Optional[Union[str, dict, bytes]]:
"""
Retrieve a parameter value or return the cached value

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
Parsed detail payload with model provided
"""
logger.debug(f"Parsing incoming data with Api Gateway model {APIGatewayProxyEventModel}")
parsed_envelope = APIGatewayProxyEventModel.parse_obj(data)
parsed_envelope: APIGatewayProxyEventModel = APIGatewayProxyEventModel.parse_obj(data)
logger.debug(f"Parsing event payload in `detail` with {model}")
return self._parse(data=parsed_envelope.body, model=model)
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
Parsed detail payload with model provided
"""
logger.debug(f"Parsing incoming data with Api Gateway model V2 {APIGatewayProxyEventV2Model}")
parsed_envelope = APIGatewayProxyEventV2Model.parse_obj(data)
parsed_envelope: APIGatewayProxyEventV2Model = APIGatewayProxyEventV2Model.parse_obj(data)
logger.debug(f"Parsing event payload in `detail` with {model}")
return self._parse(data=parsed_envelope.body, model=model)
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
Parsed detail payload with model provided
"""
logger.debug(f"Parsing incoming data with EventBridge model {EventBridgeModel}")
parsed_envelope = EventBridgeModel.parse_obj(data)
parsed_envelope: EventBridgeModel = EventBridgeModel.parse_obj(data)
logger.debug(f"Parsing event payload in `detail` with {model}")
return self._parse(data=parsed_envelope.detail, model=model)
11 changes: 7 additions & 4 deletions aws_lambda_powertools/utilities/parser/envelopes/kinesis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Dict, List, Optional, Type, Union, cast

from ..models import KinesisDataStreamModel
from ..types import Model
Expand Down Expand Up @@ -37,6 +37,9 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
logger.debug(f"Parsing incoming data with Kinesis model {KinesisDataStreamModel}")
parsed_envelope: KinesisDataStreamModel = KinesisDataStreamModel.parse_obj(data)
logger.debug(f"Parsing Kinesis records in `body` with {model}")
return [
self._parse(data=record.kinesis.data.decode("utf-8"), model=model) for record in parsed_envelope.Records
]
models = []
for record in parsed_envelope.Records:
# We allow either AWS expected contract (bytes) or a custom Model, see #943
data = cast(bytes, record.kinesis.data)
models.append(self._parse(data=data.decode("utf-8"), model=model))
return models
6 changes: 4 additions & 2 deletions aws_lambda_powertools/utilities/parser/envelopes/sns.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Dict, List, Optional, Type, Union, cast

from ..models import SnsModel, SnsNotificationModel, SqsModel
from ..types import Model
Expand Down Expand Up @@ -69,6 +69,8 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
parsed_envelope = SqsModel.parse_obj(data)
output = []
for record in parsed_envelope.Records:
sns_notification = SnsNotificationModel.parse_raw(record.body)
# We allow either AWS expected contract (str) or a custom Model, see #943
body = cast(str, record.body)
sns_notification = SnsNotificationModel.parse_raw(body)
output.append(self._parse(data=sns_notification.Message, model=model))
return output
6 changes: 2 additions & 4 deletions aws_lambda_powertools/utilities/parser/models/alb.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from typing import Dict, Union
from typing import Dict, Type, Union

from pydantic import BaseModel

from aws_lambda_powertools.utilities.parser.types import Model


class AlbRequestContextData(BaseModel):
targetGroupArn: str
Expand All @@ -16,7 +14,7 @@ class AlbRequestContext(BaseModel):
class AlbModel(BaseModel):
httpMethod: str
path: str
body: Union[str, Model]
body: Union[str, Type[BaseModel]]
isBase64Encoded: bool
headers: Dict[str, str]
queryStringParameters: Dict[str, str]
Expand Down
6 changes: 3 additions & 3 deletions aws_lambda_powertools/utilities/parser/models/apigw.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Type, Union

from pydantic import BaseModel, root_validator
from pydantic.networks import IPvAnyNetwork

from aws_lambda_powertools.utilities.parser.types import Literal, Model
from aws_lambda_powertools.utilities.parser.types import Literal


class ApiGatewayUserCertValidity(BaseModel):
Expand Down Expand Up @@ -89,4 +89,4 @@ class APIGatewayProxyEventModel(BaseModel):
pathParameters: Optional[Dict[str, str]]
stageVariables: Optional[Dict[str, str]]
isBase64Encoded: bool
body: Optional[Union[str, Model]]
body: Optional[Union[str, Type[BaseModel]]]
6 changes: 3 additions & 3 deletions aws_lambda_powertools/utilities/parser/models/apigwv2.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Type, Union

from pydantic import BaseModel, Field
from pydantic.networks import IPvAnyNetwork

from aws_lambda_powertools.utilities.parser.types import Literal, Model
from aws_lambda_powertools.utilities.parser.types import Literal


class RequestContextV2AuthorizerIamCognito(BaseModel):
Expand Down Expand Up @@ -67,5 +67,5 @@ class APIGatewayProxyEventV2Model(BaseModel):
pathParameters: Optional[Dict[str, str]]
stageVariables: Optional[Dict[str, str]]
requestContext: RequestContextV2
body: Optional[Union[str, Model]]
body: Optional[Union[str, Type[BaseModel]]]
isBase64Encoded: bool
Loading