From 36905b583a9de922f30f1802985d2c64ad1136ef Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 19 Feb 2024 17:23:40 +0100 Subject: [PATCH] fix(event-handler): multi-value query string and validation of scalar parameters (#3795) --- .../middlewares/openapi_validation.py | 23 +- .../utilities/data_classes/alb_event.py | 4 +- .../data_classes/api_gateway_proxy_event.py | 14 +- .../data_classes/bedrock_agent_event.py | 4 - .../utilities/data_classes/common.py | 14 +- .../utilities/data_classes/vpc_lattice.py | 26 +- tests/functional/event_handler/conftest.py | 32 ++ .../test_openapi_validation_middleware.py | 368 +++++++++++------- 8 files changed, 287 insertions(+), 198 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 25ac97ddf89..241a9972953 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -368,7 +368,10 @@ def _get_embed_body( return received_body, field_alias_omitted -def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, str]], params: Sequence[ModelField]): +def _normalize_multi_query_string_with_param( + query_string: Dict[str, List[str]], + params: Sequence[ModelField], +) -> Dict[str, Any]: """ Extract and normalize resolved_query_string_parameters @@ -383,15 +386,15 @@ def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, st ------- A dictionary containing the processed multi_query_string_parameters. """ - if query_string: - for param in filter(is_scalar_field, params): - try: - # if the target parameter is a scalar, we keep the first value of the query string - # regardless if there are more in the payload - query_string[param.alias] = query_string[param.alias][0] - except KeyError: - pass - return query_string + resolved_query_string: Dict[str, Any] = query_string + for param in filter(is_scalar_field, params): + try: + # if the target parameter is a scalar, we keep the first value of the query string + # regardless if there are more in the payload + resolved_query_string[param.alias] = query_string[param.alias][0] + except KeyError: + pass + return resolved_query_string def _normalize_multi_header_values_with_param(headers: Optional[Dict[str, str]], params: Sequence[ModelField]): diff --git a/aws_lambda_powertools/utilities/data_classes/alb_event.py b/aws_lambda_powertools/utilities/data_classes/alb_event.py index 98f37b4f415..1ec2535850b 100644 --- a/aws_lambda_powertools/utilities/data_classes/alb_event.py +++ b/aws_lambda_powertools/utilities/data_classes/alb_event.py @@ -36,11 +36,11 @@ def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: return self.get("multiValueQueryStringParameters") @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: + def resolved_query_string_parameters(self) -> Dict[str, List[str]]: if self.multi_value_query_string_parameters: return self.multi_value_query_string_parameters - return self.query_string_parameters + return super().resolved_query_string_parameters @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py index c37bd22ca53..ff24e908d1a 100644 --- a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py @@ -119,11 +119,11 @@ def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: return self.get("multiValueQueryStringParameters") @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: + def resolved_query_string_parameters(self) -> Dict[str, List[str]]: if self.multi_value_query_string_parameters: return self.multi_value_query_string_parameters - return self.query_string_parameters + return super().resolved_query_string_parameters @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: @@ -318,16 +318,6 @@ def http_method(self) -> str: def header_serializer(self): return HttpApiHeadersSerializer() - @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: - if self.query_string_parameters is not None: - query_string = { - key: value.split(",") if "," in value else value for key, value in self.query_string_parameters.items() - } - return query_string - - return {} - @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: if self.headers is not None: diff --git a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py index 0fa97036a3e..399c435b3ec 100644 --- a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py +++ b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py @@ -109,10 +109,6 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]: # together with the other parameters. So we just return all parameters here. return {x["name"]: x["value"] for x in self["parameters"]} if self.get("parameters") else None - @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: - return self.query_string_parameters - @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: return {} diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 25fb5a4c170..067706140fd 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -104,7 +104,7 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]: return self.get("queryStringParameters") @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: + def resolved_query_string_parameters(self) -> Dict[str, List[str]]: """ This property determines the appropriate query string parameter to be used as a trusted source for validating OpenAPI. @@ -112,7 +112,11 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: This is necessary because different resolvers use different formats to encode multi query string parameters. """ - return self.query_string_parameters + if self.query_string_parameters is not None: + query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()} + return query_string + + return {} @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: @@ -186,8 +190,7 @@ def get_header_value( name: str, default_value: str, case_sensitive: Optional[bool] = False, - ) -> str: - ... + ) -> str: ... @overload def get_header_value( @@ -195,8 +198,7 @@ def get_header_value( name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False, - ) -> Optional[str]: - ... + ) -> Optional[str]: ... def get_header_value( self, diff --git a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py index 15144e41d7d..f997d4b3f04 100644 --- a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py +++ b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py @@ -73,8 +73,7 @@ def get_header_value( name: str, default_value: str, case_sensitive: Optional[bool] = False, - ) -> str: - ... + ) -> str: ... @overload def get_header_value( @@ -82,8 +81,7 @@ def get_header_value( name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False, - ) -> Optional[str]: - ... + ) -> Optional[str]: ... def get_header_value( self, @@ -140,10 +138,6 @@ def query_string_parameters(self) -> Dict[str, str]: """The request query string parameters.""" return self["query_string_parameters"] - @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: - return self.query_string_parameters - @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: if self.headers is not None: @@ -255,17 +249,21 @@ def path(self) -> str: @property def request_context(self) -> vpcLatticeEventV2RequestContext: - """he VPC Lattice v2 Event request context.""" + """The VPC Lattice v2 Event request context.""" return vpcLatticeEventV2RequestContext(self["requestContext"]) @property def query_string_parameters(self) -> Optional[Dict[str, str]]: - """The request query string parameters.""" - return self.get("queryStringParameters") + """The request query string parameters. - @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: - return self.query_string_parameters + For VPC Lattice V2, the queryStringParameters will contain a Dict[str, List[str]] + so to keep compatibility with existing utilities, we merge all the values with a comma. + """ + params = self.get("queryStringParameters") + if params: + return {key: ",".join(value) for key, value in params.items()} + else: + return None @property def resolved_headers_field(self) -> Optional[Dict[str, str]]: diff --git a/tests/functional/event_handler/conftest.py b/tests/functional/event_handler/conftest.py index c7a4ac6e500..5c2bdb7729a 100644 --- a/tests/functional/event_handler/conftest.py +++ b/tests/functional/event_handler/conftest.py @@ -2,6 +2,8 @@ import pytest +from tests.functional.utils import load_event + @pytest.fixture def json_dump(): @@ -39,3 +41,33 @@ def validation_schema(): @pytest.fixture def raw_event(): return {"message": "hello hello", "username": "blah blah"} + + +@pytest.fixture +def gw_event(): + return load_event("apiGatewayProxyEvent.json") + + +@pytest.fixture +def gw_event_http(): + return load_event("apiGatewayProxyV2Event.json") + + +@pytest.fixture +def gw_event_alb(): + return load_event("albMultiValueQueryStringEvent.json") + + +@pytest.fixture +def gw_event_lambda_url(): + return load_event("lambdaFunctionUrlEventWithHeaders.json") + + +@pytest.fixture +def gw_event_vpc_lattice(): + return load_event("vpcLatticeV2EventWithHeaders.json") + + +@pytest.fixture +def gw_event_vpc_lattice_v1(): + return load_event("vpcLatticeEvent.json") diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index be3a13dd656..a9396644b98 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -15,22 +15,12 @@ Response, VPCLatticeResolver, VPCLatticeV2Resolver, - content_types, ) from aws_lambda_powertools.event_handler.openapi.params import Body, Header, Query from aws_lambda_powertools.shared.types import Annotated -from aws_lambda_powertools.utilities.data_classes import APIGatewayProxyEvent -from tests.functional.utils import load_event -LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json") -LOAD_GW_EVENT_HTTP = load_event("apiGatewayProxyV2Event.json") -LOAD_GW_EVENT_ALB = load_event("albMultiValueQueryStringEvent.json") -LOAD_GW_EVENT_LAMBDA_URL = load_event("lambdaFunctionUrlEventWithHeaders.json") -LOAD_GW_EVENT_VPC_LATTICE = load_event("vpcLatticeV2EventWithHeaders.json") -LOAD_GW_EVENT_VPC_LATTICE_V1 = load_event("vpcLatticeEvent.json") - -def test_validate_scalars(): +def test_validate_scalars(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -40,22 +30,22 @@ def handler(user_id: int): print(user_id) # sending a number - LOAD_GW_EVENT["path"] = "/users/123" + gw_event["path"] = "/users/123" # THEN the handler should be invoked and return 200 - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 # sending a string - LOAD_GW_EVENT["path"] = "/users/abc" + gw_event["path"] = "/users/abc" # THEN the handler should be invoked and return 422 - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 422 assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) -def test_validate_scalars_with_default(): +def test_validate_scalars_with_default(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -65,22 +55,22 @@ def handler(user_id: int = 123): print(user_id) # sending a number - LOAD_GW_EVENT["path"] = "/users/123" + gw_event["path"] = "/users/123" # THEN the handler should be invoked and return 200 - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 # sending a string - LOAD_GW_EVENT["path"] = "/users/abc" + gw_event["path"] = "/users/abc" # THEN the handler should be invoked and return 422 - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 422 assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) -def test_validate_scalars_with_default_and_optional(): +def test_validate_scalars_with_default_and_optional(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -90,22 +80,22 @@ def handler(user_id: int = 123, include_extra: bool = False): print(user_id) # sending a number - LOAD_GW_EVENT["path"] = "/users/123" + gw_event["path"] = "/users/123" # THEN the handler should be invoked and return 200 - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 # sending a string - LOAD_GW_EVENT["path"] = "/users/abc" + gw_event["path"] = "/users/abc" # THEN the handler should be invoked and return 422 - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 422 assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) -def test_validate_return_type(): +def test_validate_return_type(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -114,16 +104,16 @@ def test_validate_return_type(): def handler() -> int: return 123 - LOAD_GW_EVENT["path"] = "/" + gw_event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be 123 - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert result["body"] == "123" -def test_validate_return_list(): +def test_validate_return_list(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -132,16 +122,16 @@ def test_validate_return_list(): def handler() -> List[int]: return [123, 234] - LOAD_GW_EVENT["path"] = "/" + gw_event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be [123, 234] - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == [123, 234] -def test_validate_return_tuple(): +def test_validate_return_tuple(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -152,16 +142,16 @@ def test_validate_return_tuple(): def handler() -> Tuple: return sample_tuple - LOAD_GW_EVENT["path"] = "/" + gw_event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be a tuple - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == [1, 2, 3] -def test_validate_return_purepath(): +def test_validate_return_purepath(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -173,16 +163,16 @@ def test_validate_return_purepath(): def handler() -> str: return sample_path.as_posix() - LOAD_GW_EVENT["path"] = "/" + gw_event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be a string - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert result["body"] == sample_path.as_posix() -def test_validate_return_enum(): +def test_validate_return_enum(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -194,16 +184,16 @@ class Model(Enum): def handler() -> Model: return Model.name.value - LOAD_GW_EVENT["path"] = "/" + gw_event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be a string - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert result["body"] == "powertools" -def test_validate_return_dataclass(): +def test_validate_return_dataclass(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -217,16 +207,16 @@ class Model: def handler() -> Model: return Model(name="John", age=30) - LOAD_GW_EVENT["path"] = "/" + gw_event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be a JSON object - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == {"name": "John", "age": 30} -def test_validate_return_model(): +def test_validate_return_model(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -239,16 +229,16 @@ class Model(BaseModel): def handler() -> Model: return Model(name="John", age=30) - LOAD_GW_EVENT["path"] = "/" + gw_event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be a JSON object - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == {"name": "John", "age": 30} -def test_validate_invalid_return_model(): +def test_validate_invalid_return_model(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -261,16 +251,16 @@ class Model(BaseModel): def handler() -> Model: return {"name": "John"} # type: ignore - LOAD_GW_EVENT["path"] = "/" + gw_event["path"] = "/" # THEN the handler should be invoked and return 422 # THEN the body must be a dict - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 422 assert "missing" in result["body"] -def test_validate_body_param(): +def test_validate_body_param(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -283,18 +273,18 @@ class Model(BaseModel): def handler(user: Model) -> Model: return user - LOAD_GW_EVENT["httpMethod"] = "POST" - LOAD_GW_EVENT["path"] = "/" - LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30}) + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = json.dumps({"name": "John", "age": 30}) # THEN the handler should be invoked and return 200 # THEN the body must be a JSON object - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == {"name": "John", "age": 30} -def test_validate_body_param_with_stripped_headers(): +def test_validate_body_param_with_stripped_headers(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -308,19 +298,19 @@ class Model(BaseModel): def handler(user: Model) -> Model: return user - LOAD_GW_EVENT["httpMethod"] = "POST" - LOAD_GW_EVENT["headers"] = {"Content-type": " application/json "} - LOAD_GW_EVENT["path"] = "/" - LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30}) + gw_event["httpMethod"] = "POST" + gw_event["headers"] = {"Content-type": " application/json "} + gw_event["path"] = "/" + gw_event["body"] = json.dumps({"name": "John", "age": 30}) # THEN the handler should be invoked and return 200 # THEN the body must be a JSON object - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == {"name": "John", "age": 30} -def test_validate_body_param_with_invalid_date(): +def test_validate_body_param_with_invalid_date(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -333,18 +323,18 @@ class Model(BaseModel): def handler(user: Model) -> Model: return user - LOAD_GW_EVENT["httpMethod"] = "POST" - LOAD_GW_EVENT["path"] = "/" - LOAD_GW_EVENT["body"] = "{" # invalid JSON + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = "{" # invalid JSON # THEN the handler should be invoked and return 422 # THEN the body must have the "json_invalid" error message - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 422 assert "json_invalid" in result["body"] -def test_validate_embed_body_param(): +def test_validate_embed_body_param(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -357,24 +347,24 @@ class Model(BaseModel): def handler(user: Annotated[Model, Body(embed=True)]) -> Model: return user - LOAD_GW_EVENT["httpMethod"] = "POST" - LOAD_GW_EVENT["path"] = "/" - LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30}) + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = json.dumps({"name": "John", "age": 30}) # THEN the handler should be invoked and return 422 # THEN the body must be a dict - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 422 assert "missing" in result["body"] # THEN the handler should be invoked and return 200 # THEN the body must be a dict - LOAD_GW_EVENT["body"] = json.dumps({"user": {"name": "John", "age": 30}}) - result = app(LOAD_GW_EVENT, {}) + gw_event["body"] = json.dumps({"user": {"name": "John", "age": 30}}) + result = app(gw_event, {}) assert result["statusCode"] == 200 -def test_validate_response_return(): +def test_validate_response_return(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -387,18 +377,18 @@ class Model(BaseModel): def handler(user: Model) -> Response[Model]: return Response(body=user, status_code=200, content_type="application/json") - LOAD_GW_EVENT["httpMethod"] = "POST" - LOAD_GW_EVENT["path"] = "/" - LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30}) + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = json.dumps({"name": "John", "age": 30}) # THEN the handler should be invoked and return 200 # THEN the body must be a dict - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == {"name": "John", "age": 30} -def test_validate_response_invalid_return(): +def test_validate_response_invalid_return(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -411,13 +401,13 @@ class Model(BaseModel): def handler(user: Model) -> Response[Model]: return Response(body=user, status_code=200) - LOAD_GW_EVENT["httpMethod"] = "POST" - LOAD_GW_EVENT["path"] = "/" - LOAD_GW_EVENT["body"] = json.dumps({}) + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = json.dumps({}) # THEN the handler should be invoked and return 422 # THEN the body should have the word missing - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == 422 assert "missing" in result["body"] @@ -431,12 +421,17 @@ def handler(user: Model) -> Response[Model]: ("handler3_without_query_params", 200, None), ], ) -def test_validation_query_string_with_api_rest_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_query_string_with_api_rest_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event, +): # GIVEN a APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) - LOAD_GW_EVENT["httpMethod"] = "GET" - LOAD_GW_EVENT["path"] = "/users" + gw_event["httpMethod"] = "GET" + gw_event["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -455,8 +450,8 @@ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): # Define handler3 without params if handler_func == "handler3_without_query_params": - LOAD_GW_EVENT["queryStringParameters"] = None - LOAD_GW_EVENT["multiValueQueryStringParameters"] = None + gw_event["queryStringParameters"] = None + gw_event["multiValueQueryStringParameters"] = None @app.get("/users") def handler3(): @@ -464,7 +459,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -480,13 +475,18 @@ def handler3(): ("handler3_without_query_params", 200, None), ], ) -def test_validation_query_string_with_api_http_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_query_string_with_api_http_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_http, +): # GIVEN a APIGatewayHttpResolver with validation enabled app = APIGatewayHttpResolver(enable_validation=True) - LOAD_GW_EVENT_HTTP["rawPath"] = "/users" - LOAD_GW_EVENT_HTTP["requestContext"]["http"]["method"] = "GET" - LOAD_GW_EVENT_HTTP["requestContext"]["http"]["path"] = "/users" + gw_event_http["rawPath"] = "/users" + gw_event_http["requestContext"]["http"]["method"] = "GET" + gw_event_http["requestContext"]["http"]["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -505,7 +505,7 @@ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): # Define handler3 without params if handler_func == "handler3_without_query_params": - LOAD_GW_EVENT_HTTP["queryStringParameters"] = None + gw_event_http["queryStringParameters"] = None @app.get("/users") def handler3(): @@ -513,7 +513,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_HTTP, {}) + result = app(gw_event_http, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -529,13 +529,18 @@ def handler3(): ("handler3_without_query_params", 200, None), ], ) -def test_validation_query_string_with_alb_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_query_string_with_alb_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_alb, +): # GIVEN a ALBResolver with validation enabled app = ALBResolver(enable_validation=True) - LOAD_GW_EVENT_ALB["path"] = "/users" - # WHEN a handler is defined with various parameters and routes + gw_event_alb["path"] = "/users" + # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params if handler_func == "handler1_with_correct_params": @@ -552,7 +557,7 @@ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): # Define handler3 without params if handler_func == "handler3_without_query_params": - LOAD_GW_EVENT_HTTP["multiValueQueryStringParameters"] = None + gw_event_alb["multiValueQueryStringParameters"] = None @app.get("/users") def handler3(): @@ -560,7 +565,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_ALB, {}) + result = app(gw_event_alb, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -576,13 +581,18 @@ def handler3(): ("handler3_without_query_params", 200, None), ], ) -def test_validation_query_string_with_lambda_url_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_query_string_with_lambda_url_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_lambda_url, +): # GIVEN a LambdaFunctionUrlResolver with validation enabled app = LambdaFunctionUrlResolver(enable_validation=True) - LOAD_GW_EVENT_LAMBDA_URL["rawPath"] = "/users" - LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["method"] = "GET" - LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["path"] = "/users" + gw_event_lambda_url["rawPath"] = "/users" + gw_event_lambda_url["requestContext"]["http"]["method"] = "GET" + gw_event_lambda_url["requestContext"]["http"]["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -601,7 +611,7 @@ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): # Define handler3 without params if handler_func == "handler3_without_query_params": - LOAD_GW_EVENT_LAMBDA_URL["queryStringParameters"] = None + gw_event_lambda_url["queryStringParameters"] = None @app.get("/users") def handler3(): @@ -609,7 +619,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_LAMBDA_URL, {}) + result = app(gw_event_lambda_url, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -625,11 +635,16 @@ def handler3(): ("handler3_without_query_params", 200, None), ], ) -def test_validation_query_string_with_vpc_lattice_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_query_string_with_vpc_lattice_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_vpc_lattice, +): # GIVEN a VPCLatticeV2Resolver with validation enabled app = VPCLatticeV2Resolver(enable_validation=True) - LOAD_GW_EVENT_VPC_LATTICE["path"] = "/users" + gw_event_vpc_lattice["path"] = "/users" # WHEN a handler is defined with various parameters and routes @@ -649,7 +664,7 @@ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): # Define handler3 without params if handler_func == "handler3_without_query_params": - LOAD_GW_EVENT_VPC_LATTICE["queryStringParameters"] = None + gw_event_vpc_lattice["queryStringParameters"] = None @app.get("/users") def handler3(): @@ -657,7 +672,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_VPC_LATTICE, {}) + result = app(gw_event_vpc_lattice, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -675,12 +690,17 @@ def handler3(): ("handler4_without_header_params", 200, None), ], ) -def test_validation_header_with_api_rest_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_header_with_api_rest_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event, +): # GIVEN a APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) - LOAD_GW_EVENT["httpMethod"] = "GET" - LOAD_GW_EVENT["path"] = "/users" + gw_event["httpMethod"] = "GET" + gw_event["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -709,8 +729,8 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - LOAD_GW_EVENT["headers"] = None - LOAD_GW_EVENT["multiValueHeaders"] = None + gw_event["headers"] = None + gw_event["multiValueHeaders"] = None @app.get("/users") def handler4(): @@ -718,7 +738,7 @@ def handler4(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT, {}) + result = app(gw_event, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -735,13 +755,18 @@ def handler4(): ("handler4_without_header_params", 200, None), ], ) -def test_validation_header_with_http_rest_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_header_with_http_rest_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_http, +): # GIVEN a APIGatewayHttpResolver with validation enabled app = APIGatewayHttpResolver(enable_validation=True) - LOAD_GW_EVENT_HTTP["rawPath"] = "/users" - LOAD_GW_EVENT_HTTP["requestContext"]["http"]["method"] = "GET" - LOAD_GW_EVENT_HTTP["requestContext"]["http"]["path"] = "/users" + gw_event_http["rawPath"] = "/users" + gw_event_http["requestContext"]["http"]["method"] = "GET" + gw_event_http["requestContext"]["http"]["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -770,7 +795,7 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - LOAD_GW_EVENT_HTTP["headers"] = None + gw_event_http["headers"] = None @app.get("/users") def handler4(): @@ -778,7 +803,7 @@ def handler4(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_HTTP, {}) + result = app(gw_event_http, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -795,11 +820,16 @@ def handler4(): ("handler4_without_header_params", 200, None), ], ) -def test_validation_header_with_alb_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_header_with_alb_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_alb, +): # GIVEN a ALBResolver with validation enabled app = ALBResolver(enable_validation=True) - LOAD_GW_EVENT_ALB["path"] = "/users" + gw_event_alb["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -828,7 +858,7 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - LOAD_GW_EVENT_ALB["multiValueHeaders"] = None + gw_event_alb["multiValueHeaders"] = None @app.get("/users") def handler4(): @@ -836,7 +866,7 @@ def handler4(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_ALB, {}) + result = app(gw_event_alb, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -853,13 +883,18 @@ def handler4(): ("handler4_without_header_params", 200, None), ], ) -def test_validation_header_with_lambda_url_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_header_with_lambda_url_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_lambda_url, +): # GIVEN a LambdaFunctionUrlResolver with validation enabled app = LambdaFunctionUrlResolver(enable_validation=True) - LOAD_GW_EVENT_LAMBDA_URL["rawPath"] = "/users" - LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["method"] = "GET" - LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["path"] = "/users" + gw_event_lambda_url["rawPath"] = "/users" + gw_event_lambda_url["requestContext"]["http"]["method"] = "GET" + gw_event_lambda_url["requestContext"]["http"]["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -888,7 +923,7 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - LOAD_GW_EVENT_LAMBDA_URL["headers"] = None + gw_event_lambda_url["headers"] = None @app.get("/users") def handler4(): @@ -896,7 +931,7 @@ def handler4(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_LAMBDA_URL, {}) + result = app(gw_event_lambda_url, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -913,12 +948,17 @@ def handler4(): ("handler4_without_header_params", 200, None), ], ) -def test_validation_header_with_vpc_lattice_v1_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_header_with_vpc_lattice_v1_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_vpc_lattice_v1, +): # GIVEN a VPCLatticeResolver with validation enabled app = VPCLatticeResolver(enable_validation=True) - LOAD_GW_EVENT_VPC_LATTICE_V1["raw_path"] = "/users" - LOAD_GW_EVENT_VPC_LATTICE_V1["method"] = "GET" + gw_event_vpc_lattice_v1["raw_path"] = "/users" + gw_event_vpc_lattice_v1["method"] = "GET" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -947,7 +987,7 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - LOAD_GW_EVENT_VPC_LATTICE_V1["headers"] = None + gw_event_vpc_lattice_v1["headers"] = None @app.get("/users") def handler4(): @@ -955,7 +995,7 @@ def handler4(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_VPC_LATTICE_V1, {}) + result = app(gw_event_vpc_lattice_v1, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -972,12 +1012,17 @@ def handler4(): ("handler4_without_header_params", 200, None), ], ) -def test_validation_header_with_vpc_lattice_v2_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_header_with_vpc_lattice_v2_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_vpc_lattice, +): # GIVEN a VPCLatticeV2Resolver with validation enabled app = VPCLatticeV2Resolver(enable_validation=True) - LOAD_GW_EVENT_VPC_LATTICE["path"] = "/users" - LOAD_GW_EVENT_VPC_LATTICE["method"] = "GET" + gw_event_vpc_lattice["path"] = "/users" + gw_event_vpc_lattice["method"] = "GET" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -1006,7 +1051,7 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - LOAD_GW_EVENT_VPC_LATTICE["headers"] = None + gw_event_vpc_lattice["headers"] = None @app.get("/users") def handler3(): @@ -1014,7 +1059,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_VPC_LATTICE, {}) + result = app(gw_event_vpc_lattice, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -1022,21 +1067,44 @@ def handler3(): assert any(text in result["body"] for text in expected_error_text) -def test_validation_with_alias(): - # GIVEN a Http API V2 proxy type event +def test_validation_with_alias(gw_event): + # GIVEN a REST API V2 proxy type event app = APIGatewayRestResolver(enable_validation=True) - event = load_event("apiGatewayProxyEvent.json") - class FunkyTown(BaseModel): - parameter: str + # GIVEN that it has a multiple parameters called "parameter1" + gw_event["queryStringParameters"] = { + "parameter1": "value1,value2", + } @app.get("/my/path") def my_path( parameter: Annotated[Optional[str], Query(alias="parameter1")] = None, - ) -> Response[FunkyTown]: - assert isinstance(app.current_event, APIGatewayProxyEvent) + ) -> str: assert parameter == "value1" - return Response(200, content_types.APPLICATION_JSON, FunkyTown(parameter=parameter)) + return parameter + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + - result = app(event, {}) +def test_validation_with_http_single_param(gw_event_http): + # GIVEN a HTTP API V2 proxy type event + app = APIGatewayHttpResolver(enable_validation=True) + + # GIVEN that it has a single parameter called "parameter2" + gw_event_http["queryStringParameters"] = { + "parameter1": "value1,value2", + "parameter2": "value", + } + + # WHEN a handler is defined with a single parameter + @app.post("/my/path") + def my_path( + parameter2: str, + ) -> str: + assert parameter2 == "value" + return parameter2 + + # THEN the handler should be invoked and return 200 + result = app(gw_event_http, {}) assert result["statusCode"] == 200