From 14c51e3fdf60d7bb5d1998e92e6e7079df7b5f51 Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Wed, 26 Jun 2019 09:32:28 -0700 Subject: [PATCH 01/30] Start adding cors --- samcli/commands/local/lib/provider.py | 2 +- samcli/commands/local/lib/sam_api_provider.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index c8265c81e6..324b577cdd 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -229,7 +229,7 @@ def __hash__(self): return hash(self.path) * hash(self.method) * hash(self.function_name) -Cors = namedtuple("Cors", ["AllowOrigin", "AllowMethods", "AllowHeaders"]) +Cors = namedtuple("Cors", ["allow_origin", "allow_methods", "allow_headers", "max_age"]) class ApiProvider(object): diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index 19b0559a3b..fc919b7338 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -6,7 +6,7 @@ from six import string_types from samcli.commands.local.lib.swagger.parser import SwaggerParser -from samcli.commands.local.lib.provider import ApiProvider, Api +from samcli.commands.local.lib.provider import ApiProvider, Api, Cors from samcli.commands.local.lib.sam_base_provider import SamBaseProvider from samcli.commands.local.lib.swagger.reader import SamSwaggerReader from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException @@ -15,7 +15,6 @@ class SamApiProvider(ApiProvider): - _IMPLICIT_API_RESOURCE_ID = "ServerlessRestApi" _SERVERLESS_FUNCTION = "AWS::Serverless::Function" _SERVERLESS_API = "AWS::Serverless::Api" @@ -127,6 +126,7 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector): body = properties.get("DefinitionBody") uri = properties.get("DefinitionUri") binary_media = properties.get("BinaryMediaTypes", []) + cors = properties.get("Cors", {}) if not body and not uri: # Swagger is not found anywhere. @@ -145,6 +145,12 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector): collector.add_apis(logical_id, apis) collector.add_binary_media_types(logical_id, parser.get_binary_media_types()) # Binary media from swagger collector.add_binary_media_types(logical_id, binary_media) # Binary media specified on resource in template + collector.add_cors(logical_id, Cors( + allow_origin=cors.get("AllowOrigin"), + allow_methods=cors.get("AllowMethods"), + allow_headers=cors.get("AllowHeaders"), + max_age=cors.get("MaxAge") + )) @staticmethod def _merge_apis(collector): @@ -362,6 +368,9 @@ def add_apis(self, logical_id, apis): properties = self._get_properties(logical_id) properties.apis.extend(apis) + def add_cors(self, logical_id, cors): + pass + def add_binary_media_types(self, logical_id, binary_media_types): """ Stores the binary media type configuration for the API with given logical ID From d376c4f580ebd9a925e1226daabd0637fd72899c Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Wed, 26 Jun 2019 10:09:18 -0700 Subject: [PATCH 02/30] Reorganize cors header --- samcli/commands/local/lib/sam_api_provider.py | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index 057bcd6d93..3835eb4012 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -127,7 +127,23 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector): uri = properties.get("DefinitionUri") binary_media = properties.get("BinaryMediaTypes", []) - cors = properties.get("Cors", {}) + cors_prop = properties.get("Cors") + cors = None + if cors_prop and isinstance(cors_prop, dict): + cors = Cors( + allow_origin=cors_prop.get("AllowOrigin"), + allow_methods=cors_prop.get("AllowMethods", SamApiProvider._ANY_HTTP_METHODS), + allow_headers=cors.get("AllowHeaders"), + max_age=cors.get("MaxAge") + ) + elif cors_prop and isinstance(cors_prop, string_types): + cors = Cors( + allow_origin=cors_prop, + allow_methods=SamApiProvider._ANY_HTTP_METHODS, + allow_headers=None, + max_age=None + ) + stage_name = properties.get("StageName") stage_variables = properties.get("Variables") if not body and not uri: @@ -147,15 +163,10 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector): collector.add_apis(logical_id, apis) collector.add_binary_media_types(logical_id, parser.get_binary_media_types()) # Binary media from swagger collector.add_binary_media_types(logical_id, binary_media) # Binary media specified on resource in template - collector.add_cors(logical_id, Cors( - allow_origin=cors.get("AllowOrigin"), - allow_methods=cors.get("AllowMethods"), - allow_headers=cors.get("AllowHeaders"), - max_age=cors.get("MaxAge") - )) - collector.add_stage_name(logical_id, stage_name) collector.add_stage_variables(logical_id, stage_variables) + if cors: + collector.add_cors(logical_id, cors) @staticmethod def _merge_apis(collector): From 7de2882bcfde6bbcba18c736f728c56243f868ea Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Wed, 26 Jun 2019 11:21:52 -0700 Subject: [PATCH 03/30] Update cors --- samcli/commands/local/lib/sam_api_provider.py | 2 +- .../local/start_api/test_start_api.py | 63 +++++++++++++++++++ .../local/lib/test_local_api_service.py | 2 + .../local/lib/test_sam_api_provider.py | 18 ++++++ .../local/apigw/test_local_apigw_service.py | 4 +- 5 files changed, 87 insertions(+), 2 deletions(-) diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index 3835eb4012..8b5b632432 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -336,7 +336,7 @@ def _normalize_http_methods(api): else: yield http_method.upper() - if api.cors: + if api.cors and http_method.upper() != "OPTIONS": yield "OPTIONS" diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index 700491260d..e319e6021c 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -550,3 +550,66 @@ def test_swagger_stage_variable(self): response_data = response.json() self.assertEquals(response_data.get("stageVariables"), {'VarName': 'varValue'}) + + +class TestServiceCorsRequests(StartApiIntegBaseClass): + """ + Test to check that the correct headers are being added with cors + """ + template_path = "/testdata/start_api/template.yaml" + binary_data_file = "testdata/start_api/binarydata.gif" + + def setUp(self): + self.url = "http://127.0.0.1:{}".format(self.port) + + def test_binary_request(self): + """ + This tests that the service can accept and invoke a lambda when given binary data in a request + """ + input_data = self.get_binary_data(self.binary_data_file) + response = requests.post(self.url + '/echobase64eventbody', + headers={"Content-Type": "image/gif"}, + data=input_data) + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.headers.get("Content-Type"), "image/gif") + self.assertEquals(response.content, input_data) + pass + + def test_request_with_form_data(self): + """ + Form-encoded data should be put into the Event to Lambda + """ + response = requests.post(self.url + "/echoeventbody", + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data='key=value') + + self.assertEquals(response.status_code, 200) + + response_data = response.json() + + self.assertEquals(response_data.get("headers").get("Content-Type"), "application/x-www-form-urlencoded") + self.assertEquals(response_data.get("body"), "key=value") + pass + + def test_request_to_an_endpoint_with_two_different_handlers(self): + response = requests.get(self.url + "/echoeventbody") + + self.assertEquals(response.status_code, 200) + + response_data = response.json() + + self.assertEquals(response_data.get("handler"), 'echo_event_handler_2') + pass + + def test_request_with_multi_value_headers(self): + response = requests.get(self.url + "/echoeventbody", + headers={"Content-Type": "application/x-www-form-urlencoded, image/gif"}) + + self.assertEquals(response.status_code, 200) + response_data = response.json() + self.assertEquals(response_data.get("multiValueHeaders").get("Content-Type"), + ["application/x-www-form-urlencoded, image/gif"]) + self.assertEquals(response_data.get("headers").get("Content-Type"), + "application/x-www-form-urlencoded, image/gif") + pass diff --git a/tests/unit/commands/local/lib/test_local_api_service.py b/tests/unit/commands/local/lib/test_local_api_service.py index 3cc5d2c4c3..321f1e3214 100644 --- a/tests/unit/commands/local/lib/test_local_api_service.py +++ b/tests/unit/commands/local/lib/test_local_api_service.py @@ -217,3 +217,5 @@ def test_make_routing_list(self): self.assertEquals(len(routing_list), len(expected_routes)) for index, r in enumerate(routing_list): self.assertEquals(r.__dict__, expected_routes[index].__dict__) + +#TODO add test to check that \ No newline at end of file diff --git a/tests/unit/commands/local/lib/test_sam_api_provider.py b/tests/unit/commands/local/lib/test_sam_api_provider.py index 3ac01956be..a0357ba9fb 100644 --- a/tests/unit/commands/local/lib/test_sam_api_provider.py +++ b/tests/unit/commands/local/lib/test_sam_api_provider.py @@ -1147,6 +1147,24 @@ def test_multi_stage_get_all(self): self.assertIn(api3, result) +class TestSamCors(TestCase): + # TODO + def test_provider_parse_cors_string(self): + pass + + def test_provider_parse_cors_dict(self): + pass + + def test_global_cors(self): + pass + + def test_implicit_explicit_cors(self): + pass + + def test_multi_cors_get_all(self): + pass + + def make_swagger(apis, binary_media_types=None): """ Given a list of API configurations named tuples, returns a Swagger document diff --git a/tests/unit/local/apigw/test_local_apigw_service.py b/tests/unit/local/apigw/test_local_apigw_service.py index ba2d6316b5..cc655c8ae2 100644 --- a/tests/unit/local/apigw/test_local_apigw_service.py +++ b/tests/unit/local/apigw/test_local_apigw_service.py @@ -488,7 +488,7 @@ def setUp(self): '"Custom User Agent String", "caller": null, "cognitoAuthenticationType": null, "sourceIp": ' \ '"190.0.0.0", "user": null}, "accountId": "123456789012"}, "headers": {"Content-Type": ' \ '"application/json", "X-Test": "Value", "X-Forwarded-Port": "3000", "X-Forwarded-Proto": "http"}, ' \ - '"multiValueHeaders": {"Content-Type": ["application/json"], "X-Test": ["Value"], '\ + '"multiValueHeaders": {"Content-Type": ["application/json"], "X-Test": ["Value"], ' \ '"X-Forwarded-Port": ["3000"], "X-Forwarded-Proto": ["http"]}, ' \ '"stageVariables": null, "path": "path", "pathParameters": {"path": "params"}, ' \ '"isBase64Encoded": false}' @@ -590,3 +590,5 @@ def test_should_base64_encode_returns_true(self, test_case_name, binary_types, m ]) def test_should_base64_encode_returns_false(self, test_case_name, binary_types, mimetype): self.assertFalse(LocalApigwService._should_base64_encode(binary_types, mimetype)) + +# TODO add test here for cors with mock and add tests for _cors_to_headers to check the conversion From 3c58a956473d533ab0df6ce3060ee1a4e6ac212f Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Wed, 26 Jun 2019 13:52:56 -0700 Subject: [PATCH 04/30] Add cors tests --- samcli/commands/local/lib/provider.py | 22 +- samcli/commands/local/lib/sam_api_provider.py | 13 +- samcli/local/apigw/local_apigw_service.py | 29 +- .../local/start_api/test_start_api.py | 66 ++-- .../testdata/start_api/swagger-template.yaml | 7 +- .../testdata/start_api/template.yaml | 1 + .../local/lib/test_local_api_service.py | 1 - .../local/lib/test_sam_api_provider.py | 309 +++++++++++++++++- .../local/apigw/test_local_apigw_service.py | 17 +- 9 files changed, 391 insertions(+), 74 deletions(-) diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index 9ca968ad9d..895b9d3cac 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -222,10 +222,10 @@ def get_all(self): # The variables for that stage "stage_variables" ]) -_ApiTuple.__new__.__defaults__ = (None, # Cors is optional and defaults to None - [], # binary_media_types is optional and defaults to empty, - None, # Stage name is optional with default None - None # Stage variables is optional with default None +_ApiTuple.__new__.__defaults__ = (None, # Cors is optional and defaults to None + [], # binary_media_types is optional and defaults to empty, + None, # Stage name is optional with default None + None # Stage variables is optional with default None ) @@ -235,7 +235,19 @@ def __hash__(self): return hash(self.path) * hash(self.method) * hash(self.function_name) -Cors = namedtuple("Cors", ["allow_origin", "allow_methods", "allow_headers", "max_age"]) +_CorsTuple = namedtuple("Cors", ["allow_origin", "allow_methods", "allow_headers", "max_age"]) + +_CorsTuple.__new__.__defaults__ = (None, # Allow Origin defaults to None + None, # Allow Methods is optional and defaults to empty + None, # Allow Headers is optional and defaults to empty + None # MaxAge is optional and defaults to empty + ) + + +class Cors(_CorsTuple): + def __hash__(self): + # Other properties are not a part of the hash + return hash(self.allow_origin) * hash(self.allow_headers) * hash(self.allow_methods) * hash(self.max_age) class ApiProvider(object): diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index 8b5b632432..0289e909cd 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -130,11 +130,16 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector): cors_prop = properties.get("Cors") cors = None if cors_prop and isinstance(cors_prop, dict): + allow_methods = cors_prop.get("AllowMethods", SamApiProvider._ANY_HTTP_METHODS) + if isinstance(allow_methods, string_types): + allow_methods = [allow_methods.upper()] + if "OPTIONS" not in allow_methods: + allow_methods.append("OPTIONS") cors = Cors( allow_origin=cors_prop.get("AllowOrigin"), - allow_methods=cors_prop.get("AllowMethods", SamApiProvider._ANY_HTTP_METHODS), - allow_headers=cors.get("AllowHeaders"), - max_age=cors.get("MaxAge") + allow_methods=allow_methods, + allow_headers=cors_prop.get("AllowHeaders"), + max_age=cors_prop.get("MaxAge") ) elif cors_prop and isinstance(cors_prop, string_types): cors = Cors( @@ -209,7 +214,7 @@ def _merge_apis(collector): for config in all_configs: # Normalize the methods before de-duping to allow an ANY method in implicit API to override a regular HTTP # method on explicit API. - for normalized_method in SamApiProvider._normalize_http_methods(config.method): + for normalized_method in SamApiProvider._normalize_http_methods(config): key = config.path + normalized_method all_apis[key] = config diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index 35768ab6a6..98e432a6a6 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -144,13 +144,14 @@ def _request_handler(self, **kwargs): Response object """ route = self._get_current_route(request) - + cors_headers = LocalApigwService.cors_to_headers(route.cors) if request.method == 'OPTIONS': - return self.service_response('', LocalApigwService._cors_to_headers(route.cors), 200) + headers = Headers(cors_headers) + return self.service_response('', headers, 200) try: event = self._construct_event(request, self.port, route.binary_types, route.stage_name, - route.stage_variables) + route.stage_variables, cors_headers) except UnicodeDecodeError: return ServiceErrorResponses.lambda_failure_response() @@ -320,7 +321,7 @@ def _merge_response_headers(headers, multi_headers): return processed_headers @staticmethod - def _construct_event(flask_request, port, binary_types, stage_name=None, stage_variables=None): + def _construct_event(flask_request, port, binary_types, stage_name=None, stage_variables=None, cors_headers=None): """ Helper method that constructs the Event to be passed to Lambda @@ -354,7 +355,7 @@ def _construct_event(flask_request, port, binary_types, stage_name=None, stage_v identity=identity, path=endpoint) - headers_dict, multi_value_headers_dict = LocalApigwService._event_headers(flask_request, port) + headers_dict, multi_value_headers_dict = LocalApigwService._event_headers(flask_request, port, cors_headers) query_string_dict, multi_value_query_string_dict = LocalApigwService._query_string_params(flask_request) @@ -409,7 +410,7 @@ def _query_string_params(flask_request): return query_string_dict, multi_value_query_string_dict @staticmethod - def _event_headers(flask_request, port): + def _event_headers(flask_request, port, cors_headers=None): """ Constructs an APIGW equivalent headers dictionary @@ -419,6 +420,8 @@ def _event_headers(flask_request, port): Request from Flask int port Forwarded Port + cors_headers dict + Dict of the Cors properties Returns dict (str: str), dict (str: list of str) ------- @@ -439,7 +442,9 @@ def _event_headers(flask_request, port): headers_dict["X-Forwarded-Port"] = str(port) multi_value_headers_dict["X-Forwarded-Port"] = [str(port)] - + if cors_headers: + headers_dict.update(cors_headers) + multi_value_headers_dict.update(cors_headers) return headers_dict, multi_value_headers_dict @staticmethod @@ -462,7 +467,7 @@ def _should_base64_encode(binary_types, request_mimetype): return request_mimetype in binary_types or "*/*" in binary_types @staticmethod - def _cors_to_headers(cors): + def cors_to_headers(cors): """ Convert CORS object to headers dictionary Parameters @@ -475,12 +480,12 @@ def _cors_to_headers(cors): """ headers = {} if cors.allow_origin is not None: - headers['Access-Control-Allow-Origin'] = cors.allow_origin[1:-1] + headers['Access-Control-Allow-Origin'] = cors.allow_origin if cors.allow_methods is not None: - headers['Access-Control-Allow-Methods'] = cors.allow_methods[1:-1] + headers['Access-Control-Allow-Methods'] = ','.join(cors.allow_methods) if cors.allow_headers is not None: - headers['Access-Control-Allow-Headers'] = cors.allow_headers[1:-1] + headers['Access-Control-Allow-Headers'] = cors.allow_headers if cors.max_age is not None: - headers['Access-Control-Max-Age'] = cors.max_age[1:-1] + headers['Access-Control-Max-Age'] = cors.max_age return headers diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index e319e6021c..02746860c4 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -552,64 +552,48 @@ def test_swagger_stage_variable(self): self.assertEquals(response_data.get("stageVariables"), {'VarName': 'varValue'}) -class TestServiceCorsRequests(StartApiIntegBaseClass): +class TestServiceCorsSwaggerRequests(StartApiIntegBaseClass): """ Test to check that the correct headers are being added with cors """ - template_path = "/testdata/start_api/template.yaml" + template_path = "/testdata/start_api/swagger-template.yaml" binary_data_file = "testdata/start_api/binarydata.gif" def setUp(self): self.url = "http://127.0.0.1:{}".format(self.port) - def test_binary_request(self): + def test_cors_swagger_options(self): """ This tests that the service can accept and invoke a lambda when given binary data in a request """ - input_data = self.get_binary_data(self.binary_data_file) - response = requests.post(self.url + '/echobase64eventbody', - headers={"Content-Type": "image/gif"}, - data=input_data) + response = requests.options(self.url + '/echobase64eventbody') self.assertEquals(response.status_code, 200) - self.assertEquals(response.headers.get("Content-Type"), "image/gif") - self.assertEquals(response.content, input_data) - pass - - def test_request_with_form_data(self): - """ - Form-encoded data should be put into the Event to Lambda - """ - response = requests.post(self.url + "/echoeventbody", - headers={"Content-Type": "application/x-www-form-urlencoded"}, - data='key=value') - - self.assertEquals(response.status_code, 200) - - response_data = response.json() - - self.assertEquals(response_data.get("headers").get("Content-Type"), "application/x-www-form-urlencoded") - self.assertEquals(response_data.get("body"), "key=value") - pass - def test_request_to_an_endpoint_with_two_different_handlers(self): - response = requests.get(self.url + "/echoeventbody") + self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*") + self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), "origin, x-requested-with") + self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), "GET,OPTIONS") + self.assertEquals(response.headers.get("Access-Control-Max-Age"), '510') - self.assertEquals(response.status_code, 200) - response_data = response.json() +class TestServiceCorsGlobalRequests(StartApiIntegBaseClass): + """ + Test to check that the correct headers are being added with cors + """ + template_path = "/testdata/start_api/template.yaml" - self.assertEquals(response_data.get("handler"), 'echo_event_handler_2') - pass + def setUp(self): + self.url = "http://127.0.0.1:{}".format(self.port) - def test_request_with_multi_value_headers(self): - response = requests.get(self.url + "/echoeventbody", - headers={"Content-Type": "application/x-www-form-urlencoded, image/gif"}) + def test_cors_global(self): + """ + This tests that the service can accept and invoke a lambda when given binary data in a request + """ + response = requests.options(self.url + '/echobase64eventbody') self.assertEquals(response.status_code, 200) - response_data = response.json() - self.assertEquals(response_data.get("multiValueHeaders").get("Content-Type"), - ["application/x-www-form-urlencoded, image/gif"]) - self.assertEquals(response_data.get("headers").get("Content-Type"), - "application/x-www-form-urlencoded, image/gif") - pass + self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*") + self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), None) + self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), + "GET,DELETE,PUT,POST,HEAD,OPTIONS,PATCH") + self.assertEquals(response.headers.get("Access-Control-Max-Age"), None) diff --git a/tests/integration/testdata/start_api/swagger-template.yaml b/tests/integration/testdata/start_api/swagger-template.yaml index 9f987c0d8c..cff33b1f43 100644 --- a/tests/integration/testdata/start_api/swagger-template.yaml +++ b/tests/integration/testdata/start_api/swagger-template.yaml @@ -1,4 +1,4 @@ -AWSTemplateFormatVersion : '2010-09-09' +AWSTemplateFormatVersion: '2010-09-09' Transform: AWS::Serverless-2016-10-31 Globals: @@ -14,6 +14,11 @@ Resources: StageName: dev Variables: VarName: varValue + Cors: + AllowOrigin: "*" + AllowMethods: "GET" + AllowHeaders: "origin, x-requested-with" + MaxAge: 510 DefinitionBody: swagger: "2.0" info: diff --git a/tests/integration/testdata/start_api/template.yaml b/tests/integration/testdata/start_api/template.yaml index d9786fb52f..56d99abf1f 100644 --- a/tests/integration/testdata/start_api/template.yaml +++ b/tests/integration/testdata/start_api/template.yaml @@ -9,6 +9,7 @@ Globals: - image~1png Variables: VarName: varValue + Cors: "*" Resources: HelloWorldFunction: Type: AWS::Serverless::Function diff --git a/tests/unit/commands/local/lib/test_local_api_service.py b/tests/unit/commands/local/lib/test_local_api_service.py index 321f1e3214..8600c54d2a 100644 --- a/tests/unit/commands/local/lib/test_local_api_service.py +++ b/tests/unit/commands/local/lib/test_local_api_service.py @@ -218,4 +218,3 @@ def test_make_routing_list(self): for index, r in enumerate(routing_list): self.assertEquals(r.__dict__, expected_routes[index].__dict__) -#TODO add test to check that \ No newline at end of file diff --git a/tests/unit/commands/local/lib/test_sam_api_provider.py b/tests/unit/commands/local/lib/test_sam_api_provider.py index a0357ba9fb..dbbef922e0 100644 --- a/tests/unit/commands/local/lib/test_sam_api_provider.py +++ b/tests/unit/commands/local/lib/test_sam_api_provider.py @@ -8,7 +8,7 @@ from six import assertCountEqual from samcli.commands.local.lib.sam_api_provider import SamApiProvider -from samcli.commands.local.lib.provider import Api +from samcli.commands.local.lib.provider import Api, Cors from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException @@ -1148,21 +1148,312 @@ def test_multi_stage_get_all(self): class TestSamCors(TestCase): - # TODO def test_provider_parse_cors_string(self): - pass + template = { + "Resources": { + "TestApi": { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "Prod", + "Cors": "*", + "DefinitionBody": { + "paths": { + "/path2": { + "post": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + }, + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + } + } + } + } + } + } + + provider = SamApiProvider(template) + + result = [f for f in provider.get_all()] + + api1 = Api(path='/path2', method='POST', function_name='NoApiEventFunction', stage_name="Prod", + cors=Cors(allow_origin="*", allow_methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"]), + ) + api2 = Api(path='/path2', method='OPTIONS', function_name='NoApiEventFunction', stage_name="Prod", + cors=Cors(allow_origin="*", allow_methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"]), + ) + api3 = Api(path='/path', method='GET', function_name='NoApiEventFunction', stage_name="Prod", + cors=Cors(allow_origin="*", allow_methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"]), + ) + + api4 = Api(path='/path2', method='OPTIONS', function_name='NoApiEventFunction', stage_name="Prod", + cors=Cors(allow_origin="*", allow_methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"]), + ) + self.assertEquals(len(result), 4) + self.assertIn(api1, result) + self.assertIn(api2, result) + self.assertIn(api3, result) + self.assertIn(api4, result) def test_provider_parse_cors_dict(self): - pass + template = { + "Resources": { + "TestApi": { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "Prod", + "Cors": { + "AllowMethods": "POST", + "AllowOrigin": "*", + "AllowHeaders": "Upgrade-Insecure-Requests", + "MaxAge": 600 + }, + "DefinitionBody": { + "paths": { + "/path2": { + "post": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + }, + "/path": { + "post": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + } + } + } + } + } + } + + provider = SamApiProvider(template) + + result = [f for f in provider.get_all()] + + api1 = Api(path='/path2', method='POST', function_name='NoApiEventFunction', stage_name="Prod", + cors=Cors(allow_origin="*", + allow_methods=["POST", "OPTIONS"], + allow_headers="Upgrade-Insecure-Requests", + max_age=600), + ) + api2 = Api(path='/path2', method='OPTIONS', function_name='NoApiEventFunction', stage_name="Prod", + cors=Cors(allow_origin="*", + allow_methods=["POST", "OPTIONS"], + allow_headers="Upgrade-Insecure-Requests", + max_age=600), + ) + api3 = Api(path='/path', method='POST', function_name='NoApiEventFunction', stage_name="Prod", + cors=Cors(allow_origin="*", + allow_methods=["POST", "OPTIONS"], + allow_headers="Upgrade-Insecure-Requests", + max_age=600), + ) + api4 = Api(path='/path', method='OPTIONS', function_name='NoApiEventFunction', stage_name="Prod", + cors=Cors(allow_origin="*", + allow_methods=["POST", "OPTIONS"], + allow_headers="Upgrade-Insecure-Requests", + max_age=600), + ) + self.assertEquals(len(result), 4) + self.assertIn(api1, result) + self.assertIn(api2, result) + self.assertIn(api3, result) + self.assertIn(api4, result) + + def test_default_cors_dict_prop(self): + template = { + "Resources": { + "TestApi": { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "Prod", + "Cors": { + "AllowOrigin": "www.domain.com", + }, + "DefinitionBody": { + "paths": { + "/path2": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + } + } + } + } + } + } + + provider = SamApiProvider(template) + + result = [f for f in provider.get_all()] + + api1 = Api(path='/path2', method='GET', function_name='NoApiEventFunction', stage_name="Prod", + cors=Cors(allow_origin="www.domain.com", allow_methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"]), + ) + api2 = Api(path='/path2', method='OPTIONS', function_name='NoApiEventFunction', stage_name="Prod", + cors=Cors(allow_origin="www.domain.com", allow_methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"]), + ) + self.assertEquals(len(result), 2) + self.assertIn(api1, result) + self.assertIn(api2, result) def test_global_cors(self): - pass + template = { + "Global": { + "TestApi": { + "Cors": { + "AllowMethods": "POST", + "AllowOrigin": "*", + "AllowHeaders": "Upgrade-Insecure-Requests", + "MaxAge": 600 + }, + } + }, + "Resources": { + "TestApi": { + "Type": "AWS::Serverless::Api", + "Properties": { + "DefinitionBody": { + "paths": { + "/path2": { + "post": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + }, + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + } + } + } + } + } + } + + provider = SamApiProvider(template) - def test_implicit_explicit_cors(self): - pass + result = [f for f in provider.get_all()] - def test_multi_cors_get_all(self): - pass + api1 = Api(path='/path2', method='GET', function_name='NoApiEventFunction', stage_name="Prod", + cors=Cors(allow_origin="*", + allow_methods=["POST", "OPTIONS"], + max_age=600), + ) + api2 = Api(path='/path', method='GET', function_name='NoApiEventFunction', stage_name="Prod", + cors=Cors(allow_origin="*", + allow_methods=["POST", "OPTIONS"], + allow_headers="Upgrade-Insecure-Requests", + max_age=600), + ) + api3 = Api(path='/path2', method='OPTIONS', function_name='NoApiEventFunction', stage_name="Prod", + cors=Cors(allow_origin="*", + allow_methods=["POST", "OPTIONS"], + max_age=600), + ) + api4 = Api(path='/path', method='OPTIONS', function_name='NoApiEventFunction', stage_name="Prod", + cors=Cors(allow_origin="*", + allow_methods=["POST", "OPTIONS"], + allow_headers="Upgrade-Insecure-Requests", + max_age=600), + ) + self.assertEquals(len(result), 4) + self.assertIn(api1, result) + self.assertIn(api2, result) + self.assertIn(api3, result) + self.assertIn(api4, result) def make_swagger(apis, binary_media_types=None): diff --git a/tests/unit/local/apigw/test_local_apigw_service.py b/tests/unit/local/apigw/test_local_apigw_service.py index cc655c8ae2..4121d25040 100644 --- a/tests/unit/local/apigw/test_local_apigw_service.py +++ b/tests/unit/local/apigw/test_local_apigw_service.py @@ -6,6 +6,7 @@ from parameterized import parameterized, param from werkzeug.datastructures import Headers +from commands.local.lib.provider import Cors from samcli.local.apigw.local_apigw_service import LocalApigwService, Route from samcli.local.lambdafn.exceptions import FunctionNotFound @@ -591,4 +592,18 @@ def test_should_base64_encode_returns_true(self, test_case_name, binary_types, m def test_should_base64_encode_returns_false(self, test_case_name, binary_types, mimetype): self.assertFalse(LocalApigwService._should_base64_encode(binary_types, mimetype)) -# TODO add test here for cors with mock and add tests for _cors_to_headers to check the conversion + +class TestServiceCorsToHeaders(TestCase): + def test_basic_conversion(self): + cors = Cors(allow_origin="*", allow_methods=["POST", "OPTIONS"], allow_headers="UPGRADE-HEADER", max_age=6) + headers = LocalApigwService.cors_to_headers(cors) + self.assertEquals(headers, {'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Methods': 'POST,OPTIONS', + 'Access-Control-Allow-Headers': 'UPGRADE-HEADER', 'Access-Control-Max-Age': 6}) + + def test_empty_elements(self): + cors = Cors(allow_origin="www.domain.com", allow_methods=["GET", "POST", "OPTIONS"]) + headers = LocalApigwService.cors_to_headers(cors) + self.assertEquals(headers, + {'Access-Control-Allow-Origin': 'www.domain.com', + 'Access-Control-Allow-Methods': 'GET,POST,OPTIONS', + 'Access-Control-Allow-Headers': None, 'Access-Control-Max-Age': None}) From 0094f291fe3c0cc1206fab62a8618f6974d7c7fb Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Wed, 26 Jun 2019 14:34:39 -0700 Subject: [PATCH 05/30] Update tests to check make pr passes --- samcli/commands/local/lib/provider.py | 6 +- samcli/commands/local/lib/sam_api_provider.py | 57 +++++---- samcli/local/apigw/local_apigw_service.py | 5 +- .../local/lib/test_local_api_service.py | 27 ++--- .../local/lib/test_sam_api_provider.py | 111 +++++++++--------- .../local/apigw/test_local_apigw_service.py | 13 +- 6 files changed, 118 insertions(+), 101 deletions(-) diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index 895b9d3cac..b2dcc46bd3 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -223,9 +223,9 @@ def get_all(self): "stage_variables" ]) _ApiTuple.__new__.__defaults__ = (None, # Cors is optional and defaults to None - [], # binary_media_types is optional and defaults to empty, + [], # binary_media_types is optional and defaults to empty, None, # Stage name is optional with default None - None # Stage variables is optional with default None + None # Stage variables is optional with default None ) @@ -240,7 +240,7 @@ def __hash__(self): _CorsTuple.__new__.__defaults__ = (None, # Allow Origin defaults to None None, # Allow Methods is optional and defaults to empty None, # Allow Headers is optional and defaults to empty - None # MaxAge is optional and defaults to empty + None # MaxAge is optional and defaults to empty ) diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index 0289e909cd..16528ea523 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -126,29 +126,7 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector): body = properties.get("DefinitionBody") uri = properties.get("DefinitionUri") binary_media = properties.get("BinaryMediaTypes", []) - - cors_prop = properties.get("Cors") - cors = None - if cors_prop and isinstance(cors_prop, dict): - allow_methods = cors_prop.get("AllowMethods", SamApiProvider._ANY_HTTP_METHODS) - if isinstance(allow_methods, string_types): - allow_methods = [allow_methods.upper()] - if "OPTIONS" not in allow_methods: - allow_methods.append("OPTIONS") - cors = Cors( - allow_origin=cors_prop.get("AllowOrigin"), - allow_methods=allow_methods, - allow_headers=cors_prop.get("AllowHeaders"), - max_age=cors_prop.get("MaxAge") - ) - elif cors_prop and isinstance(cors_prop, string_types): - cors = Cors( - allow_origin=cors_prop, - allow_methods=SamApiProvider._ANY_HTTP_METHODS, - allow_headers=None, - max_age=None - ) - + cors = self._extract_cors(properties) stage_name = properties.get("StageName") stage_variables = properties.get("Variables") if not body and not uri: @@ -173,6 +151,39 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector): if cors: collector.add_cors(logical_id, cors) + def _extract_cors(self, properties): + """ + Extract Cors property from AWS::Serverless::Api resource by reading and parsing Swagger documents. The result + is added to the Api. + + Parameters + ---------- + properties : dict + Resource properties + """ + cors_prop = properties.get("Cors") + cors = None + if cors_prop and isinstance(cors_prop, dict): + allow_methods = cors_prop.get("AllowMethods", ','.join(SamApiProvider._ANY_HTTP_METHODS)) + + if allow_methods and "OPTIONS" not in allow_methods: + allow_methods += ",OPTIONS" + + cors = Cors( + allow_origin=cors_prop.get("AllowOrigin"), + allow_methods=allow_methods, + allow_headers=cors_prop.get("AllowHeaders"), + max_age=cors_prop.get("MaxAge") + ) + elif cors_prop and isinstance(cors_prop, string_types): + cors = Cors( + allow_origin=cors_prop, + allow_methods=','.join(SamApiProvider._ANY_HTTP_METHODS), + allow_headers=None, + max_age=None + ) + return cors + @staticmethod def _merge_apis(collector): """ diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index 98e432a6a6..a95886aa43 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -145,7 +145,8 @@ def _request_handler(self, **kwargs): """ route = self._get_current_route(request) cors_headers = LocalApigwService.cors_to_headers(route.cors) - if request.method == 'OPTIONS': + + if route.method == 'OPTIONS': headers = Headers(cors_headers) return self.service_response('', headers, 200) @@ -482,7 +483,7 @@ def cors_to_headers(cors): if cors.allow_origin is not None: headers['Access-Control-Allow-Origin'] = cors.allow_origin if cors.allow_methods is not None: - headers['Access-Control-Allow-Methods'] = ','.join(cors.allow_methods) + headers['Access-Control-Allow-Methods'] = cors.allow_methods if cors.allow_headers is not None: headers['Access-Control-Allow-Headers'] = cors.allow_headers if cors.max_age is not None: diff --git a/tests/unit/commands/local/lib/test_local_api_service.py b/tests/unit/commands/local/lib/test_local_api_service.py index 8600c54d2a..0dbce41bed 100644 --- a/tests/unit/commands/local/lib/test_local_api_service.py +++ b/tests/unit/commands/local/lib/test_local_api_service.py @@ -106,9 +106,9 @@ class TestLocalApiService_make_routing_list(TestCase): def test_must_return_routing_list_from_apis(self): api_provider = Mock() apis = [ - Api(path="/1", method="GET1", function_name="name1", cors="CORS1"), - Api(path="/2", method="GET2", function_name="name2", cors="CORS2"), - Api(path="/3", method="GET3", function_name="name3", cors="CORS3"), + Api(path="/1", method="GET1", function_name="name1", cors=None), + Api(path="/2", method="GET2", function_name="name2", cors=None), + Api(path="/3", method="GET3", function_name="name3", cors=None), ] expected = [ Route(path="/1", methods=["GET1"], function_name="name1"), @@ -132,11 +132,11 @@ def test_must_print_routes(self): api_provider = Mock() apis = [ - Api(path="/1", method="GET", function_name="name1", cors="CORS1"), - Api(path="/1", method="POST", function_name="name1", cors="CORS1"), - Api(path="/1", method="DELETE", function_name="othername1", cors="CORS1"), - Api(path="/2", method="GET2", function_name="name2", cors="CORS2"), - Api(path="/3", method="GET3", function_name="name3", cors="CORS3"), + Api(path="/1", method="GET", function_name="name1"), + Api(path="/1", method="POST", function_name="name1"), + Api(path="/1", method="DELETE", function_name="othername1"), + Api(path="/2", method="GET2", function_name="name2"), + Api(path="/3", method="GET3", function_name="name3"), ] api_provider.get_all.return_value = apis @@ -188,12 +188,12 @@ class TestRoutingList(TestCase): def setUp(self): self.function_name = "routingTest" apis = [ - provider.Api(path="/get", method="GET", function_name=self.function_name, cors="cors"), - provider.Api(path="/get", method="GET", function_name=self.function_name, cors="cors", stage_name="Dev"), - provider.Api(path="/post", method="POST", function_name=self.function_name, cors="cors", stage_name="Prod"), - provider.Api(path="/get", method="GET", function_name=self.function_name, cors="cors", + provider.Api(path="/get", method="GET", function_name=self.function_name), + provider.Api(path="/get", method="GET", function_name=self.function_name, stage_name="Dev"), + provider.Api(path="/post", method="POST", function_name=self.function_name, stage_name="Prod"), + provider.Api(path="/get", method="GET", function_name=self.function_name, stage_variables={"test": "data"}), - provider.Api(path="/post", method="POST", function_name=self.function_name, cors="cors", stage_name="Prod", + provider.Api(path="/post", method="POST", function_name=self.function_name, stage_name="Prod", stage_variables={"data": "more data"}), ] self.api_provider_mock = Mock() @@ -217,4 +217,3 @@ def test_make_routing_list(self): self.assertEquals(len(routing_list), len(expected_routes)) for index, r in enumerate(routing_list): self.assertEquals(r.__dict__, expected_routes[index].__dict__) - diff --git a/tests/unit/commands/local/lib/test_sam_api_provider.py b/tests/unit/commands/local/lib/test_sam_api_provider.py index dbbef922e0..923e7331e8 100644 --- a/tests/unit/commands/local/lib/test_sam_api_provider.py +++ b/tests/unit/commands/local/lib/test_sam_api_provider.py @@ -1194,41 +1194,41 @@ def test_provider_parse_cors_string(self): result = [f for f in provider.get_all()] api1 = Api(path='/path2', method='POST', function_name='NoApiEventFunction', stage_name="Prod", - cors=Cors(allow_origin="*", allow_methods=["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"]), + cors=Cors(allow_origin="*", allow_methods=','.join(["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"])), ) api2 = Api(path='/path2', method='OPTIONS', function_name='NoApiEventFunction', stage_name="Prod", - cors=Cors(allow_origin="*", allow_methods=["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"]), + cors=Cors(allow_origin="*", allow_methods=','.join(["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"])), ) api3 = Api(path='/path', method='GET', function_name='NoApiEventFunction', stage_name="Prod", - cors=Cors(allow_origin="*", allow_methods=["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"]), + cors=Cors(allow_origin="*", allow_methods=','.join(["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"])), ) api4 = Api(path='/path2', method='OPTIONS', function_name='NoApiEventFunction', stage_name="Prod", - cors=Cors(allow_origin="*", allow_methods=["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"]), + cors=Cors(allow_origin="*", allow_methods=','.join(["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"])), ) self.assertEquals(len(result), 4) self.assertIn(api1, result) @@ -1288,25 +1288,25 @@ def test_provider_parse_cors_dict(self): api1 = Api(path='/path2', method='POST', function_name='NoApiEventFunction', stage_name="Prod", cors=Cors(allow_origin="*", - allow_methods=["POST", "OPTIONS"], + allow_methods=','.join(["POST", "OPTIONS"]), allow_headers="Upgrade-Insecure-Requests", max_age=600), ) api2 = Api(path='/path2', method='OPTIONS', function_name='NoApiEventFunction', stage_name="Prod", cors=Cors(allow_origin="*", - allow_methods=["POST", "OPTIONS"], + allow_methods=','.join(["POST", "OPTIONS"]), allow_headers="Upgrade-Insecure-Requests", max_age=600), ) api3 = Api(path='/path', method='POST', function_name='NoApiEventFunction', stage_name="Prod", cors=Cors(allow_origin="*", - allow_methods=["POST", "OPTIONS"], + allow_methods=','.join(["POST", "OPTIONS"]), allow_headers="Upgrade-Insecure-Requests", max_age=600), ) api4 = Api(path='/path', method='OPTIONS', function_name='NoApiEventFunction', stage_name="Prod", cors=Cors(allow_origin="*", - allow_methods=["POST", "OPTIONS"], + allow_methods=','.join(["POST", "OPTIONS"]), allow_headers="Upgrade-Insecure-Requests", max_age=600), ) @@ -1353,22 +1353,22 @@ def test_default_cors_dict_prop(self): result = [f for f in provider.get_all()] api1 = Api(path='/path2', method='GET', function_name='NoApiEventFunction', stage_name="Prod", - cors=Cors(allow_origin="www.domain.com", allow_methods=["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"]), + cors=Cors(allow_origin="www.domain.com", allow_methods=','.join(["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"])), ) api2 = Api(path='/path2', method='OPTIONS', function_name='NoApiEventFunction', stage_name="Prod", - cors=Cors(allow_origin="www.domain.com", allow_methods=["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"]), + cors=Cors(allow_origin="www.domain.com", allow_methods=','.join(["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"])), ) self.assertEquals(len(result), 2) self.assertIn(api1, result) @@ -1376,10 +1376,10 @@ def test_default_cors_dict_prop(self): def test_global_cors(self): template = { - "Global": { - "TestApi": { + "Globals": { + "Api": { "Cors": { - "AllowMethods": "POST", + "AllowMethods": "GET", "AllowOrigin": "*", "AllowHeaders": "Upgrade-Insecure-Requests", "MaxAge": 600 @@ -1390,10 +1390,11 @@ def test_global_cors(self): "TestApi": { "Type": "AWS::Serverless::Api", "Properties": { + "StageName": "Prod", "DefinitionBody": { "paths": { "/path2": { - "post": { + "get": { "x-amazon-apigateway-integration": { "type": "aws_proxy", "uri": { @@ -1429,24 +1430,26 @@ def test_global_cors(self): api1 = Api(path='/path2', method='GET', function_name='NoApiEventFunction', stage_name="Prod", cors=Cors(allow_origin="*", - allow_methods=["POST", "OPTIONS"], + allow_headers="Upgrade-Insecure-Requests", + allow_methods=','.join(["GET", "OPTIONS"]), max_age=600), ) api2 = Api(path='/path', method='GET', function_name='NoApiEventFunction', stage_name="Prod", cors=Cors(allow_origin="*", - allow_methods=["POST", "OPTIONS"], + allow_methods=','.join(["GET", "OPTIONS"]), allow_headers="Upgrade-Insecure-Requests", max_age=600), ) api3 = Api(path='/path2', method='OPTIONS', function_name='NoApiEventFunction', stage_name="Prod", cors=Cors(allow_origin="*", - allow_methods=["POST", "OPTIONS"], + allow_headers="Upgrade-Insecure-Requests", + allow_methods=','.join(["GET", "OPTIONS"]), max_age=600), ) api4 = Api(path='/path', method='OPTIONS', function_name='NoApiEventFunction', stage_name="Prod", cors=Cors(allow_origin="*", - allow_methods=["POST", "OPTIONS"], allow_headers="Upgrade-Insecure-Requests", + allow_methods=','.join(["GET", "OPTIONS"]), max_age=600), ) self.assertEquals(len(result), 4) diff --git a/tests/unit/local/apigw/test_local_apigw_service.py b/tests/unit/local/apigw/test_local_apigw_service.py index 4121d25040..a259026bf6 100644 --- a/tests/unit/local/apigw/test_local_apigw_service.py +++ b/tests/unit/local/apigw/test_local_apigw_service.py @@ -6,7 +6,7 @@ from parameterized import parameterized, param from werkzeug.datastructures import Headers -from commands.local.lib.provider import Cors +from samcli.commands.local.lib.provider import Cors from samcli.local.apigw.local_apigw_service import LocalApigwService, Route from samcli.local.lambdafn.exceptions import FunctionNotFound @@ -154,6 +154,7 @@ def test_request_handles_error_when_invoke_cant_find_function(self, service_erro not_found_response_mock = Mock() self.service._construct_event = Mock() self.service._get_current_route = Mock() + service_error_responses_patch.lambda_not_found_response.return_value = not_found_response_mock self.lambda_runner.invoke.side_effect = FunctionNotFound() @@ -595,15 +596,17 @@ def test_should_base64_encode_returns_false(self, test_case_name, binary_types, class TestServiceCorsToHeaders(TestCase): def test_basic_conversion(self): - cors = Cors(allow_origin="*", allow_methods=["POST", "OPTIONS"], allow_headers="UPGRADE-HEADER", max_age=6) + cors = Cors(allow_origin="*", allow_methods=','.join(["POST", "OPTIONS"]), allow_headers="UPGRADE-HEADER", + max_age=6) headers = LocalApigwService.cors_to_headers(cors) + self.assertEquals(headers, {'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Methods': 'POST,OPTIONS', 'Access-Control-Allow-Headers': 'UPGRADE-HEADER', 'Access-Control-Max-Age': 6}) def test_empty_elements(self): - cors = Cors(allow_origin="www.domain.com", allow_methods=["GET", "POST", "OPTIONS"]) + cors = Cors(allow_origin="www.domain.com", allow_methods=','.join(["GET", "POST", "OPTIONS"])) headers = LocalApigwService.cors_to_headers(cors) + self.assertEquals(headers, {'Access-Control-Allow-Origin': 'www.domain.com', - 'Access-Control-Allow-Methods': 'GET,POST,OPTIONS', - 'Access-Control-Allow-Headers': None, 'Access-Control-Max-Age': None}) + 'Access-Control-Allow-Methods': 'GET,POST,OPTIONS'}) From d2fba5fda9c5e7391ff5ea483402684ad2ae76bc Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Thu, 27 Jun 2019 09:32:07 -0700 Subject: [PATCH 06/30] Fix headers so that it is returned with post/get requests --- samcli/local/apigw/local_apigw_service.py | 4 ++- .../local/start_api/test_start_api.py | 32 +++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index a95886aa43..0f718076e2 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -146,7 +146,7 @@ def _request_handler(self, **kwargs): route = self._get_current_route(request) cors_headers = LocalApigwService.cors_to_headers(route.cors) - if route.method == 'OPTIONS': + if 'OPTIONS' in route.methods: headers = Headers(cors_headers) return self.service_response('', headers, 200) @@ -174,6 +174,8 @@ def _request_handler(self, **kwargs): (status_code, headers, body) = self._parse_lambda_output(lambda_response, route.binary_types, request) + if cors_headers: + headers.extend(cors_headers) except (KeyError, TypeError, ValueError): LOG.error("Function returned an invalid response (must include one of: body, headers, multiValueHeaders or " "statusCode in the response object). Response received: %s", lambda_response) diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index 02746860c4..45994c0473 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -575,6 +575,23 @@ def test_cors_swagger_options(self): self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), "GET,OPTIONS") self.assertEquals(response.headers.get("Access-Control-Max-Age"), '510') + def test_cors_swagger_post(self): + """ + This tests that the service can accept and invoke a lambda when given binary data in a request + """ + input_data = self.get_binary_data(self.binary_data_file) + response = requests.post(self.url + '/echobase64eventbody', + headers={"Content-Type": "image/gif"}, + data=input_data) + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.headers.get("Content-Type"), "image/gif") + self.assertEquals(response.content, input_data) + self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*") + self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), "origin, x-requested-with") + self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), "GET,OPTIONS") + self.assertEquals(response.headers.get("Access-Control-Max-Age"), '510') + class TestServiceCorsGlobalRequests(StartApiIntegBaseClass): """ @@ -597,3 +614,18 @@ def test_cors_global(self): self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), "GET,DELETE,PUT,POST,HEAD,OPTIONS,PATCH") self.assertEquals(response.headers.get("Access-Control-Max-Age"), None) + + def test_cors_global_get(self): + """ + This tests that the service can accept and invoke a lambda when given binary data in a request + """ + response = requests.get(self.url + "/onlysetstatuscode") + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.content.decode('utf-8'), "no data") + self.assertEquals(response.headers.get("Content-Type"), "application/json") + self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*") + self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), None) + self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), + "GET,DELETE,PUT,POST,HEAD,OPTIONS,PATCH") + self.assertEquals(response.headers.get("Access-Control-Max-Age"), None) \ No newline at end of file From 7d9f270ce8c2619a03f2051b0e7b622b166ea495 Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Thu, 27 Jun 2019 12:48:13 -0700 Subject: [PATCH 07/30] Cleanup Tests and style --- .../local/start_api/test_start_api.py | 2 +- .../local/apigw/test_local_apigw_service.py | 28 ++++++++++++------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index 45994c0473..b92e362f4e 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -628,4 +628,4 @@ def test_cors_global_get(self): self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), None) self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), "GET,DELETE,PUT,POST,HEAD,OPTIONS,PATCH") - self.assertEquals(response.headers.get("Access-Control-Max-Age"), None) \ No newline at end of file + self.assertEquals(response.headers.get("Access-Control-Max-Age"), None) diff --git a/tests/unit/local/apigw/test_local_apigw_service.py b/tests/unit/local/apigw/test_local_apigw_service.py index a259026bf6..b2c97297f9 100644 --- a/tests/unit/local/apigw/test_local_apigw_service.py +++ b/tests/unit/local/apigw/test_local_apigw_service.py @@ -1,5 +1,5 @@ from unittest import TestCase -from mock import Mock, patch, ANY +from mock import Mock, patch, ANY, MagicMock import json import base64 @@ -32,11 +32,12 @@ def test_request_must_invoke_lambda(self): make_response_mock = Mock() self.service.service_response = make_response_mock - self.service._get_current_route = Mock() + self.service._get_current_route = MagicMock() + self.service._get_current_route.methods = [] self.service._construct_event = Mock() parse_output_mock = Mock() - parse_output_mock.return_value = ("status_code", "headers", "body") + parse_output_mock.return_value = ("status_code", Headers({"headers": "headers"}), "body") self.service._parse_lambda_output = parse_output_mock service_response_mock = Mock() @@ -56,11 +57,13 @@ def test_request_handler_returns_process_stdout_when_making_response(self, lambd make_response_mock = Mock() self.service.service_response = make_response_mock - self.service._get_current_route = Mock() + self.service._get_current_route = MagicMock() + self.service._get_current_route.methods = [] + self.service._construct_event = Mock() parse_output_mock = Mock() - parse_output_mock.return_value = ("status_code", "headers", "body") + parse_output_mock.return_value = ("status_code", Headers({"headers": "headers"}), "body") self.service._parse_lambda_output = parse_output_mock lambda_logs = "logs" @@ -85,11 +88,12 @@ def test_request_handler_returns_make_response(self): make_response_mock = Mock() self.service.service_response = make_response_mock - self.service._get_current_route = Mock() + self.service._get_current_route = MagicMock() self.service._construct_event = Mock() + self.service._get_current_route.methods = [] parse_output_mock = Mock() - parse_output_mock.return_value = ("status_code", "headers", "body") + parse_output_mock.return_value = ("status_code", Headers({"headers": "headers"}), "body") self.service._parse_lambda_output = parse_output_mock service_response_mock = Mock() @@ -153,7 +157,8 @@ def test_initalize_with_values(self): def test_request_handles_error_when_invoke_cant_find_function(self, service_error_responses_patch): not_found_response_mock = Mock() self.service._construct_event = Mock() - self.service._get_current_route = Mock() + self.service._get_current_route = MagicMock() + self.service._get_current_route.methods = [] service_error_responses_patch.lambda_not_found_response.return_value = not_found_response_mock @@ -183,7 +188,8 @@ def test_request_handler_errors_when_parse_lambda_output_raises_keyerror(self, s service_error_responses_patch.lambda_failure_response.return_value = failure_response_mock self.service._construct_event = Mock() - self.service._get_current_route = Mock() + self.service._get_current_route = MagicMock() + self.service._get_current_route.methods = [] result = self.service._request_handler() @@ -202,7 +208,9 @@ def test_request_handler_errors_when_get_current_route_fails(self, service_error def test_request_handler_errors_when_unable_to_read_binary_data(self, service_error_responses_patch): _construct_event = Mock() _construct_event.side_effect = UnicodeDecodeError("utf8", b"obj", 1, 2, "reason") - self.service._get_current_route = Mock() + self.service._get_current_route = MagicMock() + self.service._get_current_route.methods = [] + self.service._construct_event = _construct_event failure_mock = Mock() From 087c15208788d8071fa3e82aec6dc2b919d7fbba Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Thu, 27 Jun 2019 15:38:58 -0700 Subject: [PATCH 08/30] Update code with style comments --- samcli/local/apigw/local_apigw_service.py | 19 ++++++++----------- .../local/start_api/test_start_api.py | 12 ++++++------ 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index 0f718076e2..9cbe68fa02 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -481,14 +481,11 @@ def cors_to_headers(cors): ------- Dictionary with CORS headers """ - headers = {} - if cors.allow_origin is not None: - headers['Access-Control-Allow-Origin'] = cors.allow_origin - if cors.allow_methods is not None: - headers['Access-Control-Allow-Methods'] = cors.allow_methods - if cors.allow_headers is not None: - headers['Access-Control-Allow-Headers'] = cors.allow_headers - if cors.max_age is not None: - headers['Access-Control-Max-Age'] = cors.max_age - - return headers + headers = { + 'Access-Control-Allow-Origin': cors.allow_origin, + 'Access-Control-Allow-Methods': cors.allow_methods, + 'Access-Control-Allow-Headers': cors.allow_headers, + 'Access-Control-Max-Age': cors.max_age + } + + return {h_key: h_value for h_key, h_value in headers.items() if h_value is not None} diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index b92e362f4e..f24180d0fb 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -554,7 +554,7 @@ def test_swagger_stage_variable(self): class TestServiceCorsSwaggerRequests(StartApiIntegBaseClass): """ - Test to check that the correct headers are being added with cors + Test to check that the correct headers are being added with Cors with swagger code """ template_path = "/testdata/start_api/swagger-template.yaml" binary_data_file = "testdata/start_api/binarydata.gif" @@ -564,7 +564,7 @@ def setUp(self): def test_cors_swagger_options(self): """ - This tests that the service can accept and invoke a lambda when given binary data in a request + This tests that the Cors are added to option requests in the swagger template """ response = requests.options(self.url + '/echobase64eventbody') @@ -577,7 +577,7 @@ def test_cors_swagger_options(self): def test_cors_swagger_post(self): """ - This tests that the service can accept and invoke a lambda when given binary data in a request + This tests that the Cors are added to post requests in the swagger template """ input_data = self.get_binary_data(self.binary_data_file) response = requests.post(self.url + '/echobase64eventbody', @@ -595,7 +595,7 @@ def test_cors_swagger_post(self): class TestServiceCorsGlobalRequests(StartApiIntegBaseClass): """ - Test to check that the correct headers are being added with cors + Test to check that the correct headers are being added with Cors with the global property """ template_path = "/testdata/start_api/template.yaml" @@ -604,7 +604,7 @@ def setUp(self): def test_cors_global(self): """ - This tests that the service can accept and invoke a lambda when given binary data in a request + This tests that the Cors are added to options requests when the global property is set """ response = requests.options(self.url + '/echobase64eventbody') @@ -617,7 +617,7 @@ def test_cors_global(self): def test_cors_global_get(self): """ - This tests that the service can accept and invoke a lambda when given binary data in a request + This tests that the Cors are added to post requests when the global property is set """ response = requests.get(self.url + "/onlysetstatuscode") From 0f507516551f126a5db419d5db0eca5c26522d1c Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Tue, 16 Jul 2019 09:33:33 -0700 Subject: [PATCH 09/30] Run make pr and Fix merge errors --- samcli/commands/local/lib/api_collector.py | 5 +++ samcli/commands/local/lib/provider.py | 34 ++++++++++--------- samcli/commands/local/lib/sam_api_provider.py | 14 +++----- .../local/lib/test_sam_api_provider.py | 9 ++--- 4 files changed, 33 insertions(+), 29 deletions(-) diff --git a/samcli/commands/local/lib/api_collector.py b/samcli/commands/local/lib/api_collector.py index cbd198c6b7..275705f31c 100644 --- a/samcli/commands/local/lib/api_collector.py +++ b/samcli/commands/local/lib/api_collector.py @@ -113,6 +113,11 @@ def add_stage_variables(self, logical_id, stage_variables): properties = properties._replace(stage_variables=stage_variables) self._set_properties(logical_id, properties) + def add_cors(self, logical_id, cors): + properties = self._get_properties(logical_id) + properties = properties._replace(cors=cors) + self._set_properties(logical_id, properties) + def _get_apis_with_config(self, logical_id): """ Returns the list of APIs in this resource along with other extra configuration such as binary media types, diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index 7cc87dfa8f..74c8a7e485 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -223,9 +223,9 @@ def get_all(self): "stage_variables" ]) _ApiTuple.__new__.__defaults__ = (None, # Cors is optional and defaults to None - [], # binary_media_types is optional and defaults to empty, + [], # binary_media_types is optional and defaults to empty, None, # Stage name is optional with default None - None # Stage variables is optional with default None + None # Stage variables is optional with default None ) @@ -240,7 +240,7 @@ def __hash__(self): _CorsTuple.__new__.__defaults__ = (None, # Allow Origin defaults to None None, # Allow Methods is optional and defaults to empty None, # Allow Headers is optional and defaults to empty - None # MaxAge is optional and defaults to empty + None # MaxAge is optional and defaults to empty ) @@ -254,13 +254,13 @@ class AbstractApiProvider(object): """ Abstract base class to return APIs and the functions they route to """ - _ANY_HTTP_METHODS = ["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"] + ANY_HTTP_METHODS = ["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"] def get_all(self): """ @@ -271,21 +271,23 @@ def get_all(self): raise NotImplementedError("not implemented") @staticmethod - def normalize_http_methods(http_method): + def normalize_http_methods(api): """ Normalizes Http Methods. Api Gateway allows a Http Methods of ANY. This is a special verb to denote all supported Http Methods on Api Gateway. - - :param str http_method: Http method + :param api api: Api :yield str: Either the input http_method or one of the _ANY_HTTP_METHODS (normalized Http Methods) """ - + http_method = api.method if http_method.upper() == 'ANY': - for method in AbstractApiProvider._ANY_HTTP_METHODS: + for method in AbstractApiProvider.ANY_HTTP_METHODS: yield method.upper() else: yield http_method.upper() + if api.cors and http_method.upper() != "OPTIONS": + yield "OPTIONS" + @staticmethod def normalize_apis(apis): """ @@ -304,7 +306,7 @@ def normalize_apis(apis): result = list() for api in apis: - for normalized_method in AbstractApiProvider.normalize_http_methods(api.method): + for normalized_method in AbstractApiProvider.normalize_http_methods(api): # _replace returns a copy of the namedtuple. This is the official way of creating copies of namedtuple result.append(api._replace(method=normalized_method)) diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index a64ee332ac..495bb5ddad 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -4,14 +4,10 @@ from six import string_types -from samcli.commands.local.lib.swagger.parser import SwaggerParser -from samcli.commands.local.lib.provider import ApiProvider, Api, Cors -from samcli.commands.local.lib.sam_base_provider import SamBaseProvider -from samcli.commands.local.lib.swagger.reader import SamSwaggerReader +from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider from samcli.commands.local.lib.provider import Api, AbstractApiProvider - +from samcli.commands.local.lib.provider import Cors from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException -from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider LOG = logging.getLogger(__name__) @@ -109,7 +105,7 @@ def _extract_cors(self, properties): cors_prop = properties.get("Cors") cors = None if cors_prop and isinstance(cors_prop, dict): - allow_methods = cors_prop.get("AllowMethods", ','.join(SamApiProvider._ANY_HTTP_METHODS)) + allow_methods = cors_prop.get("AllowMethods", ','.join(AbstractApiProvider.ANY_HTTP_METHODS)) if allow_methods and "OPTIONS" not in allow_methods: allow_methods += ",OPTIONS" @@ -123,7 +119,7 @@ def _extract_cors(self, properties): elif cors_prop and isinstance(cors_prop, string_types): cors = Cors( allow_origin=cors_prop, - allow_methods=','.join(SamApiProvider._ANY_HTTP_METHODS), + allow_methods=','.join(AbstractApiProvider.ANY_HTTP_METHODS), allow_headers=None, max_age=None ) @@ -244,7 +240,7 @@ def merge_apis(collector): for config in all_configs: # Normalize the methods before de-duping to allow an ANY method in implicit API to override a regular HTTP # method on explicit API. - for normalized_method in AbstractApiProvider.normalize_http_methods(config.method): + for normalized_method in AbstractApiProvider.normalize_http_methods(config): key = config.path + normalized_method all_apis[key] = config diff --git a/tests/unit/commands/local/lib/test_sam_api_provider.py b/tests/unit/commands/local/lib/test_sam_api_provider.py index 2422261e73..f22b433a5b 100644 --- a/tests/unit/commands/local/lib/test_sam_api_provider.py +++ b/tests/unit/commands/local/lib/test_sam_api_provider.py @@ -7,6 +7,7 @@ from six import assertCountEqual +from samcli.commands.local.lib.api_provider import ApiProvider from samcli.commands.local.lib.sam_api_provider import SamApiProvider from samcli.commands.local.lib.provider import Api, Cors from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException @@ -1167,7 +1168,7 @@ def test_provider_parse_cors_string(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) result = [f for f in provider.get_all()] @@ -1260,7 +1261,7 @@ def test_provider_parse_cors_dict(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) result = [f for f in provider.get_all()] @@ -1326,7 +1327,7 @@ def test_default_cors_dict_prop(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) result = [f for f in provider.get_all()] @@ -1402,7 +1403,7 @@ def test_global_cors(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) result = [f for f in provider.get_all()] From aa04f5e0a98753d467947ade09adb380ab69dba5 Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Tue, 16 Jul 2019 14:09:20 -0700 Subject: [PATCH 10/30] Fix Merge Issue with ApiGateway RestApi ApiGateway::RestApi only supports within methods section. This requires parsing the requestParameters section of the ApiGateway resource. Since we currently don'd do this, the best functionality is to ignore cors in RestApi --- samcli/local/apigw/local_apigw_service.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index 9cbe68fa02..0822fba65c 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -481,6 +481,8 @@ def cors_to_headers(cors): ------- Dictionary with CORS headers """ + if not cors: + return {} headers = { 'Access-Control-Allow-Origin': cors.allow_origin, 'Access-Control-Allow-Methods': cors.allow_methods, From 24d16edab2f6aac72e758b74fc589e50d09a5b7e Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Thu, 18 Jul 2019 14:46:02 -0700 Subject: [PATCH 11/30] Cleanup Cors class in provider --- samcli/commands/local/lib/provider.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index 74c8a7e485..d005962c77 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -245,9 +245,7 @@ def __hash__(self): class Cors(_CorsTuple): - def __hash__(self): - # Other properties are not a part of the hash - return hash(self.allow_origin) * hash(self.allow_headers) * hash(self.allow_methods) * hash(self.max_age) + pass class AbstractApiProvider(object): @@ -285,8 +283,8 @@ def normalize_http_methods(api): else: yield http_method.upper() - if api.cors and http_method.upper() != "OPTIONS": - yield "OPTIONS" + if api.cors and http_method.upper() != "OPTIONS": + yield "OPTIONS" @staticmethod def normalize_apis(apis): From 49d16b4c0b6f8e610b7c0ebde033c94b79e2da91 Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa <51216482+viksrivat@users.noreply.github.com> Date: Tue, 23 Jul 2019 11:23:18 -0700 Subject: [PATCH 12/30] feat(start-api): CloudFormation AWS::ApiGateway::RestApi support (#1238) --- samcli/commands/local/lib/api_collector.py | 215 ++++++++ samcli/commands/local/lib/api_provider.py | 94 ++++ samcli/commands/local/lib/cfn_api_provider.py | 69 +++ .../local/lib/cfn_base_api_provider.py | 70 +++ .../commands/local/lib/local_api_service.py | 22 +- samcli/commands/local/lib/provider.py | 57 ++- samcli/commands/local/lib/sam_api_provider.py | 462 +++--------------- .../commands/local/lib/sam_base_provider.py | 3 +- .../local/start_api/test_start_api.py | 106 ++++ .../start_api/swagger-rest-api-template.yaml | 69 +++ .../commands/local/lib/test_api_provider.py | 207 ++++++++ .../local/lib/test_cfn_api_provider.py | 215 ++++++++ .../local/lib/test_local_api_service.py | 4 +- .../local/lib/test_sam_api_provider.py | 84 ++-- 14 files changed, 1218 insertions(+), 459 deletions(-) create mode 100644 samcli/commands/local/lib/api_collector.py create mode 100644 samcli/commands/local/lib/api_provider.py create mode 100644 samcli/commands/local/lib/cfn_api_provider.py create mode 100644 samcli/commands/local/lib/cfn_base_api_provider.py create mode 100644 tests/integration/testdata/start_api/swagger-rest-api-template.yaml create mode 100644 tests/unit/commands/local/lib/test_api_provider.py create mode 100644 tests/unit/commands/local/lib/test_cfn_api_provider.py diff --git a/samcli/commands/local/lib/api_collector.py b/samcli/commands/local/lib/api_collector.py new file mode 100644 index 0000000000..cbd198c6b7 --- /dev/null +++ b/samcli/commands/local/lib/api_collector.py @@ -0,0 +1,215 @@ +""" +Class to store the API configurations in the SAM Template. This class helps store both implicit and explicit +APIs in a standardized format +""" + +import logging +from collections import namedtuple + +from six import string_types + +LOG = logging.getLogger(__name__) + + +class ApiCollector(object): + # Properties of each API. The structure is quite similar to the properties of AWS::Serverless::Api resource. + # This is intentional because it allows us to easily extend this class to support future properties on the API. + # We will store properties of Implicit APIs also in this format which converges the handling of implicit & explicit + # APIs. + Properties = namedtuple("Properties", ["apis", "binary_media_types", "cors", "stage_name", "stage_variables"]) + + def __init__(self): + # API properties stored per resource. Key is the LogicalId of the AWS::Serverless::Api resource and + # value is the properties + self.by_resource = {} + + def __iter__(self): + """ + Iterator to iterate through all the APIs stored in the collector. In each iteration, this yields the + LogicalId of the API resource and a list of APIs available in this resource. + + Yields + ------- + str + LogicalID of the AWS::Serverless::Api resource + list samcli.commands.local.lib.provider.Api + List of the API available in this resource along with additional configuration like binary media types. + """ + + for logical_id, _ in self.by_resource.items(): + yield logical_id, self._get_apis_with_config(logical_id) + + def add_apis(self, logical_id, apis): + """ + Stores the given APIs tagged under the given logicalId + + Parameters + ---------- + logical_id : str + LogicalId of the AWS::Serverless::Api resource + + apis : list of samcli.commands.local.lib.provider.Api + List of APIs available in this resource + """ + properties = self._get_properties(logical_id) + properties.apis.extend(apis) + + def add_binary_media_types(self, logical_id, binary_media_types): + """ + Stores the binary media type configuration for the API with given logical ID + + Parameters + ---------- + logical_id : str + LogicalId of the AWS::Serverless::Api resource + + binary_media_types : list of str + List of binary media types supported by this resource + + """ + properties = self._get_properties(logical_id) + + binary_media_types = binary_media_types or [] + for value in binary_media_types: + normalized_value = self._normalize_binary_media_type(value) + + # If the value is not supported, then just skip it. + if normalized_value: + properties.binary_media_types.add(normalized_value) + else: + LOG.debug("Unsupported data type of binary media type value of resource '%s'", logical_id) + + def add_stage_name(self, logical_id, stage_name): + """ + Stores the stage name for the API with the given local ID + + Parameters + ---------- + logical_id : str + LogicalId of the AWS::Serverless::Api resource + + stage_name : str + The stage_name string + + """ + properties = self._get_properties(logical_id) + properties = properties._replace(stage_name=stage_name) + self._set_properties(logical_id, properties) + + def add_stage_variables(self, logical_id, stage_variables): + """ + Stores the stage variables for the API with the given local ID + + Parameters + ---------- + logical_id : str + LogicalId of the AWS::Serverless::Api resource + + stage_variables : dict + A dictionary containing stage variables. + + """ + properties = self._get_properties(logical_id) + properties = properties._replace(stage_variables=stage_variables) + self._set_properties(logical_id, properties) + + def _get_apis_with_config(self, logical_id): + """ + Returns the list of APIs in this resource along with other extra configuration such as binary media types, + cors etc. Additional configuration is merged directly into the API data because these properties, although + defined globally, actually apply to each API. + + Parameters + ---------- + logical_id : str + Logical ID of the resource to fetch data for + + Returns + ------- + list of samcli.commands.local.lib.provider.Api + List of APIs with additional configurations for the resource with given logicalId. If there are no APIs, + then it returns an empty list + """ + + properties = self._get_properties(logical_id) + + # These configs need to be applied to each API + binary_media = sorted(list(properties.binary_media_types)) # Also sort the list to keep the ordering stable + cors = properties.cors + stage_name = properties.stage_name + stage_variables = properties.stage_variables + + result = [] + for api in properties.apis: + # Create a copy of the API with updated configuration + updated_api = api._replace(binary_media_types=binary_media, + cors=cors, + stage_name=stage_name, + stage_variables=stage_variables) + result.append(updated_api) + + return result + + def _get_properties(self, logical_id): + """ + Returns the properties of resource with given logical ID. If a resource is not found, then it returns an + empty data. + + Parameters + ---------- + logical_id : str + Logical ID of the resource + + Returns + ------- + samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties + Properties object for this resource. + """ + + if logical_id not in self.by_resource: + self.by_resource[logical_id] = self.Properties(apis=[], + # Use a set() to be able to easily de-dupe + binary_media_types=set(), + cors=None, + stage_name=None, + stage_variables=None) + + return self.by_resource[logical_id] + + def _set_properties(self, logical_id, properties): + """ + Sets the properties of resource with given logical ID. If a resource is not found, it does nothing + + Parameters + ---------- + logical_id : str + Logical ID of the resource + properties : samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties + Properties object for this resource. + """ + + if logical_id in self.by_resource: + self.by_resource[logical_id] = properties + + @staticmethod + def _normalize_binary_media_type(value): + """ + Converts binary media types values to the canonical format. Ex: image~1gif -> image/gif. If the value is not + a string, then this method just returns None + + Parameters + ---------- + value : str + Value to be normalized + + Returns + ------- + str or None + Normalized value. If the input was not a string, then None is returned + """ + + if not isinstance(value, string_types): + # It is possible that user specified a dict value for one of the binary media types. We just skip them + return None + + return value.replace("~1", "/") diff --git a/samcli/commands/local/lib/api_provider.py b/samcli/commands/local/lib/api_provider.py new file mode 100644 index 0000000000..afc686e166 --- /dev/null +++ b/samcli/commands/local/lib/api_provider.py @@ -0,0 +1,94 @@ +"""Class that provides Apis from a SAM Template""" + +import logging + +from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider +from samcli.commands.local.lib.api_collector import ApiCollector +from samcli.commands.local.lib.provider import AbstractApiProvider +from samcli.commands.local.lib.sam_base_provider import SamBaseProvider +from samcli.commands.local.lib.sam_api_provider import SamApiProvider +from samcli.commands.local.lib.cfn_api_provider import CfnApiProvider + +LOG = logging.getLogger(__name__) + + +class ApiProvider(AbstractApiProvider): + + def __init__(self, template_dict, parameter_overrides=None, cwd=None): + """ + Initialize the class with SAM template data. The template_dict (SAM Templated) is assumed + to be valid, normalized and a dictionary. template_dict should be normalized by running any and all + pre-processing before passing to this class. + This class does not perform any syntactic validation of the template. + + After the class is initialized, changes to ``template_dict`` will not be reflected in here. + You will need to explicitly update the class with new template, if necessary. + + Parameters + ---------- + template_dict : dict + SAM Template as a dictionary + + cwd : str + Optional working directory with respect to which we will resolve relative path to Swagger file + """ + self.template_dict = SamBaseProvider.get_template(template_dict, parameter_overrides) + self.resources = self.template_dict.get("Resources", {}) + + LOG.debug("%d resources found in the template", len(self.resources)) + + # Store a set of apis + self.cwd = cwd + self.apis = self._extract_apis(self.resources) + + LOG.debug("%d APIs found in the template", len(self.apis)) + + def get_all(self): + """ + Yields all the Lambda functions with Api Events available in the SAM Template. + + :yields Api: namedtuple containing the Api information + """ + + for api in self.apis: + yield api + + def _extract_apis(self, resources): + """ + Extracts all the Apis by running through the one providers. The provider that has the first type matched + will be run across all the resources + + Parameters + ---------- + resources: dict + The dictionary containing the different resources within the template + Returns + --------- + list of Apis extracted from the resources + """ + collector = ApiCollector() + provider = self.find_api_provider(resources) + apis = provider.extract_resource_api(resources, collector, cwd=self.cwd) + return self.normalize_apis(apis) + + @staticmethod + def find_api_provider(resources): + """ + Finds the ApiProvider given the first api type of the resource + + Parameters + ----------- + resources: dict + The dictionary containing the different resources within the template + + Return + ---------- + Instance of the ApiProvider that will be run on the template with a default of SamApiProvider + """ + for _, resource in resources.items(): + if resource.get(CfnBaseApiProvider.RESOURCE_TYPE) in SamApiProvider.TYPES: + return SamApiProvider() + elif resource.get(CfnBaseApiProvider.RESOURCE_TYPE) in CfnApiProvider.TYPES: + return CfnApiProvider() + + return SamApiProvider() diff --git a/samcli/commands/local/lib/cfn_api_provider.py b/samcli/commands/local/lib/cfn_api_provider.py new file mode 100644 index 0000000000..0e3919611c --- /dev/null +++ b/samcli/commands/local/lib/cfn_api_provider.py @@ -0,0 +1,69 @@ +"""Parses SAM given a template""" +import logging + +from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider + +LOG = logging.getLogger(__name__) + + +class CfnApiProvider(CfnBaseApiProvider): + APIGATEWAY_RESTAPI = "AWS::ApiGateway::RestApi" + TYPES = [ + APIGATEWAY_RESTAPI + ] + + def extract_resource_api(self, resources, collector, cwd=None): + """ + Extract the Api Object from a given resource and adds it to the ApiCollector. + + Parameters + ---------- + resources: dict + The dictionary containing the different resources within the template + + collector: ApiCollector + Instance of the API collector that where we will save the API information + + cwd : str + Optional working directory with respect to which we will resolve relative path to Swagger file + + Return + ------- + Returns a list of Apis + """ + for logical_id, resource in resources.items(): + resource_type = resource.get(CfnBaseApiProvider.RESOURCE_TYPE) + if resource_type == CfnApiProvider.APIGATEWAY_RESTAPI: + self._extract_cloud_formation_api(logical_id, resource, collector, cwd) + all_apis = [] + for _, apis in collector: + all_apis.extend(apis) + return all_apis + + def _extract_cloud_formation_api(self, logical_id, api_resource, collector, cwd=None): + """ + Extract APIs from AWS::ApiGateway::RestApi resource by reading and parsing Swagger documents. The result is + added to the collector. + + Parameters + ---------- + logical_id : str + Logical ID of the resource + + api_resource : dict + Resource definition, including its properties + + collector : ApiCollector + Instance of the API collector that where we will save the API information + """ + properties = api_resource.get("Properties", {}) + body = properties.get("Body") + body_s3_location = properties.get("BodyS3Location") + binary_media = properties.get("BinaryMediaTypes", []) + + if not body and not body_s3_location: + # Swagger is not found anywhere. + LOG.debug("Skipping resource '%s'. Swagger document not found in Body and BodyS3Location", + logical_id) + return + self.extract_swagger_api(logical_id, body, body_s3_location, binary_media, collector, cwd) diff --git a/samcli/commands/local/lib/cfn_base_api_provider.py b/samcli/commands/local/lib/cfn_base_api_provider.py new file mode 100644 index 0000000000..79bc6d8f1d --- /dev/null +++ b/samcli/commands/local/lib/cfn_base_api_provider.py @@ -0,0 +1,70 @@ +"""Class that parses the CloudFormation Api Template""" + +import logging + +from samcli.commands.local.lib.swagger.parser import SwaggerParser +from samcli.commands.local.lib.swagger.reader import SamSwaggerReader + +LOG = logging.getLogger(__name__) + + +class CfnBaseApiProvider(object): + RESOURCE_TYPE = "Type" + + def extract_resource_api(self, resources, collector, cwd=None): + """ + Extract the Api Object from a given resource and adds it to the ApiCollector. + + Parameters + ---------- + resources: dict + The dictionary containing the different resources within the template + + collector: ApiCollector + Instance of the API collector that where we will save the API information + + cwd : str + Optional working directory with respect to which we will resolve relative path to Swagger file + + Return + ------- + Returns a list of Apis + """ + raise NotImplementedError("not implemented") + + @staticmethod + def extract_swagger_api(logical_id, body, uri, binary_media, collector, cwd=None): + """ + Parse the Swagger documents and adds it to the ApiCollector. + + Parameters + ---------- + logical_id : str + Logical ID of the resource + + body : dict + The body of the RestApi + + uri : str or dict + The url to location of the RestApi + + binary_media: list + The link to the binary media + + collector: ApiCollector + Instance of the API collector that where we will save the API information + + cwd : str + Optional working directory with respect to which we will resolve relative path to Swagger file + """ + reader = SamSwaggerReader(definition_body=body, + definition_uri=uri, + working_dir=cwd) + swagger = reader.read() + parser = SwaggerParser(swagger) + apis = parser.get_apis() + LOG.debug("Found '%s' APIs in resource '%s'", len(apis), logical_id) + + collector.add_apis(logical_id, apis) + collector.add_binary_media_types(logical_id, parser.get_binary_media_types()) # Binary media from swagger + collector.add_binary_media_types(logical_id, binary_media) # Binary media specified on resource in template diff --git a/samcli/commands/local/lib/local_api_service.py b/samcli/commands/local/lib/local_api_service.py index d0ebbbd975..d456e67a83 100644 --- a/samcli/commands/local/lib/local_api_service.py +++ b/samcli/commands/local/lib/local_api_service.py @@ -6,7 +6,7 @@ import logging from samcli.local.apigw.local_apigw_service import LocalApigwService, Route -from samcli.commands.local.lib.sam_api_provider import SamApiProvider +from samcli.commands.local.lib.api_provider import ApiProvider from samcli.commands.local.lib.exceptions import NoApisDefined LOG = logging.getLogger(__name__) @@ -38,9 +38,9 @@ def __init__(self, self.static_dir = static_dir self.cwd = lambda_invoke_context.get_cwd() - self.api_provider = SamApiProvider(lambda_invoke_context.template, - parameter_overrides=lambda_invoke_context.parameter_overrides, - cwd=self.cwd) + self.api_provider = ApiProvider(lambda_invoke_context.template, + parameter_overrides=lambda_invoke_context.parameter_overrides, + cwd=self.cwd) self.lambda_runner = lambda_invoke_context.local_lambda_runner self.stderr_stream = lambda_invoke_context.stderr @@ -89,7 +89,7 @@ def _make_routing_list(api_provider): Parameters ---------- - api_provider : samcli.commands.local.lib.sam_api_provider.SamApiProvider + api_provider : samcli.commands.local.lib.api_provider.ApiProvider Returns ------- @@ -116,10 +116,14 @@ def _print_routes(api_provider, host, port): Mounting Product at http://127.0.0.1:3000/path1/bar [GET, POST, DELETE] Mounting Product at http://127.0.0.1:3000/path2/bar [HEAD] - :param samcli.commands.local.lib.provider.ApiProvider api_provider: API Provider that can return a list of APIs - :param string host: Host name where the service is running - :param int port: Port number where the service is running - :returns list(string): List of lines that were printed to the console. Helps with testing + :param samcli.commands.local.lib.provider.AbstractApiProvider api_provider: + API Provider that can return a list of APIs + :param string host: + Host name where the service is running + :param int port: + Port number where the service is running + :returns list(string): + List of lines that were printed to the console. Helps with testing """ grouped_api_configs = {} diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index eead981089..959166e814 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -222,10 +222,10 @@ def get_all(self): # The variables for that stage "stage_variables" ]) -_ApiTuple.__new__.__defaults__ = (None, # Cors is optional and defaults to None - [], # binary_media_types is optional and defaults to empty, - None, # Stage name is optional with default None - None # Stage variables is optional with default None +_ApiTuple.__new__.__defaults__ = (None, # Cors is optional and defaults to None + [], # binary_media_types is optional and defaults to empty, + None, # Stage name is optional with default None + None # Stage variables is optional with default None ) @@ -238,10 +238,17 @@ def __hash__(self): Cors = namedtuple("Cors", ["AllowOrigin", "AllowMethods", "AllowHeaders"]) -class ApiProvider(object): +class AbstractApiProvider(object): """ Abstract base class to return APIs and the functions they route to """ + _ANY_HTTP_METHODS = ["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"] def get_all(self): """ @@ -250,3 +257,43 @@ def get_all(self): :yields Api: namedtuple containing the API information """ raise NotImplementedError("not implemented") + + @staticmethod + def normalize_http_methods(http_method): + """ + Normalizes Http Methods. Api Gateway allows a Http Methods of ANY. This is a special verb to denote all + supported Http Methods on Api Gateway. + + :param str http_method: Http method + :yield str: Either the input http_method or one of the _ANY_HTTP_METHODS (normalized Http Methods) + """ + + if http_method.upper() == 'ANY': + for method in AbstractApiProvider._ANY_HTTP_METHODS: + yield method.upper() + else: + yield http_method.upper() + + @staticmethod + def normalize_apis(apis): + """ + Normalize the APIs to use standard method name + + Parameters + ---------- + apis : list of samcli.commands.local.lib.provider.Api + List of APIs to replace normalize + + Returns + ------- + list of samcli.commands.local.lib.provider.Api + List of normalized APIs + """ + + result = list() + for api in apis: + for normalized_method in AbstractApiProvider.normalize_http_methods(api.method): + # _replace returns a copy of the namedtuple. This is the official way of creating copies of namedtuple + result.append(api._replace(method=normalized_method)) + + return result diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index 84336a8d2d..f0ec57b823 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -1,111 +1,60 @@ -"""Class that provides Apis from a SAM Template""" +"""Parses SAM given the template""" import logging -from collections import namedtuple -from six import string_types - -from samcli.commands.local.lib.swagger.parser import SwaggerParser -from samcli.commands.local.lib.provider import ApiProvider, Api -from samcli.commands.local.lib.sam_base_provider import SamBaseProvider -from samcli.commands.local.lib.swagger.reader import SamSwaggerReader +from samcli.commands.local.lib.provider import Api, AbstractApiProvider from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException +from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider LOG = logging.getLogger(__name__) -class SamApiProvider(ApiProvider): - _IMPLICIT_API_RESOURCE_ID = "ServerlessRestApi" - _SERVERLESS_FUNCTION = "AWS::Serverless::Function" - _SERVERLESS_API = "AWS::Serverless::Api" - _TYPE = "Type" - +class SamApiProvider(CfnBaseApiProvider): + SERVERLESS_FUNCTION = "AWS::Serverless::Function" + SERVERLESS_API = "AWS::Serverless::Api" + TYPES = [ + SERVERLESS_FUNCTION, + SERVERLESS_API + ] _FUNCTION_EVENT_TYPE_API = "Api" _FUNCTION_EVENT = "Events" _EVENT_PATH = "Path" _EVENT_METHOD = "Method" + _EVENT_TYPE = "Type" + IMPLICIT_API_RESOURCE_ID = "ServerlessRestApi" - _ANY_HTTP_METHODS = ["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"] - - def __init__(self, template_dict, parameter_overrides=None, cwd=None): + def extract_resource_api(self, resources, collector, cwd=None): """ - Initialize the class with SAM template data. The template_dict (SAM Templated) is assumed - to be valid, normalized and a dictionary. template_dict should be normalized by running any and all - pre-processing before passing to this class. - This class does not perform any syntactic validation of the template. - - After the class is initialized, changes to ``template_dict`` will not be reflected in here. - You will need to explicitly update the class with new template, if necessary. + Extract the Api Object from a given resource and adds it to the ApiCollector. Parameters ---------- - template_dict : dict - SAM Template as a dictionary - cwd : str - Optional working directory with respect to which we will resolve relative path to Swagger file - """ - - self.template_dict = SamBaseProvider.get_template(template_dict, parameter_overrides) - self.resources = self.template_dict.get("Resources", {}) - - LOG.debug("%d resources found in the template", len(self.resources)) - - # Store a set of apis - self.cwd = cwd - self.apis = self._extract_apis(self.resources) - - LOG.debug("%d APIs found in the template", len(self.apis)) - - def get_all(self): - """ - Yields all the Lambda functions with Api Events available in the SAM Template. + resources: dict + The dictionary containing the different resources within the template - :yields Api: namedtuple containing the Api information - """ - - for api in self.apis: - yield api + collector: ApiCollector + Instance of the API collector that where we will save the API information - def _extract_apis(self, resources): - """ - Extract all Implicit Apis (Apis defined through Serverless Function with an Api Event + cwd : str + Optional working directory with respect to which we will resolve relative path to Swagger file - :param dict resources: Dictionary of SAM/CloudFormation resources - :return: List of nametuple Api + Return + ------- + Returns a list of Apis """ - - # Some properties like BinaryMediaTypes, Cors are set once on the resource but need to be applied to each API. - # For Implicit APIs, which are defined on the Function resource, these properties - # are defined on a AWS::Serverless::Api resource with logical ID "ServerlessRestApi". Therefore, no matter - # if it is an implicit API or an explicit API, there is a corresponding resource of type AWS::Serverless::Api - # that contains these additional configurations. - # - # We use this assumption in the following loop to collect information from resources of type - # AWS::Serverless::Api. We also extract API from Serverless::Function resource and add them to the - # corresponding Serverless::Api resource. This is all done using the ``collector``. - - collector = ApiCollector() - + # AWS::Serverless::Function is currently included when parsing of Apis because when SamBaseProvider is run on + # the template we are creating the implicit apis due to plugins that translate it in the SAM repo, + # which we later merge with the explicit ones in SamApiProvider.merge_apis. This requires the code to be + # parsed here and in InvokeContext. for logical_id, resource in resources.items(): - - resource_type = resource.get(SamApiProvider._TYPE) - - if resource_type == SamApiProvider._SERVERLESS_FUNCTION: + resource_type = resource.get(CfnBaseApiProvider.RESOURCE_TYPE) + if resource_type == SamApiProvider.SERVERLESS_FUNCTION: self._extract_apis_from_function(logical_id, resource, collector) + if resource_type == SamApiProvider.SERVERLESS_API: + self._extract_from_serverless_api(logical_id, resource, collector, cwd) + return self.merge_apis(collector) - if resource_type == SamApiProvider._SERVERLESS_API: - self._extract_from_serverless_api(logical_id, resource, collector) - - apis = SamApiProvider._merge_apis(collector) - return self._normalize_apis(apis) - - def _extract_from_serverless_api(self, logical_id, api_resource, collector): + def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd=None): """ Extract APIs from AWS::Serverless::Api resource by reading and parsing Swagger documents. The result is added to the collector. @@ -134,99 +83,11 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector): LOG.debug("Skipping resource '%s'. Swagger document not found in DefinitionBody and DefinitionUri", logical_id) return - - reader = SamSwaggerReader(definition_body=body, - definition_uri=uri, - working_dir=self.cwd) - swagger = reader.read() - parser = SwaggerParser(swagger) - apis = parser.get_apis() - LOG.debug("Found '%s' APIs in resource '%s'", len(apis), logical_id) - - collector.add_apis(logical_id, apis) - collector.add_binary_media_types(logical_id, parser.get_binary_media_types()) # Binary media from swagger - collector.add_binary_media_types(logical_id, binary_media) # Binary media specified on resource in template - + self.extract_swagger_api(logical_id, body, uri, binary_media, collector, cwd) collector.add_stage_name(logical_id, stage_name) collector.add_stage_variables(logical_id, stage_variables) - @staticmethod - def _merge_apis(collector): - """ - Quite often, an API is defined both in Implicit and Explicit API definitions. In such cases, Implicit API - definition wins because that conveys clear intent that the API is backed by a function. This method will - merge two such list of Apis with the right order of precedence. If a Path+Method combination is defined - in both the places, only one wins. - - Parameters - ---------- - collector : ApiCollector - Collector object that holds all the APIs specified in the template - - Returns - ------- - list of samcli.commands.local.lib.provider.Api - List of APIs obtained by combining both the input lists. - """ - - implicit_apis = [] - explicit_apis = [] - - # Store implicit and explicit APIs separately in order to merge them later in the correct order - # Implicit APIs are defined on a resource with logicalID ServerlessRestApi - for logical_id, apis in collector: - if logical_id == SamApiProvider._IMPLICIT_API_RESOURCE_ID: - implicit_apis.extend(apis) - else: - explicit_apis.extend(apis) - - # We will use "path+method" combination as key to this dictionary and store the Api config for this combination. - # If an path+method combo already exists, then overwrite it if and only if this is an implicit API - all_apis = {} - - # By adding implicit APIs to the end of the list, they will be iterated last. If a configuration was already - # written by explicit API, it will be overriden by implicit API, just by virtue of order of iteration. - all_configs = explicit_apis + implicit_apis - - for config in all_configs: - # Normalize the methods before de-duping to allow an ANY method in implicit API to override a regular HTTP - # method on explicit API. - for normalized_method in SamApiProvider._normalize_http_methods(config.method): - key = config.path + normalized_method - all_apis[key] = config - - result = set(all_apis.values()) # Assign to a set() to de-dupe - LOG.debug("Removed duplicates from '%d' Explicit APIs and '%d' Implicit APIs to produce '%d' APIs", - len(explicit_apis), len(implicit_apis), len(result)) - - return list(result) - - @staticmethod - def _normalize_apis(apis): - """ - Normalize the APIs to use standard method name - - Parameters - ---------- - apis : list of samcli.commands.local.lib.provider.Api - List of APIs to replace normalize - - Returns - ------- - list of samcli.commands.local.lib.provider.Api - List of normalized APIs - """ - - result = list() - for api in apis: - for normalized_method in SamApiProvider._normalize_http_methods(api.method): - # _replace returns a copy of the namedtuple. This is the official way of creating copies of namedtuple - result.append(api._replace(method=normalized_method)) - - return result - - @staticmethod - def _extract_apis_from_function(logical_id, function_resource, collector): + def _extract_apis_from_function(self, logical_id, function_resource, collector): """ Fetches a list of APIs configured for this SAM Function resource. @@ -243,11 +104,10 @@ def _extract_apis_from_function(logical_id, function_resource, collector): """ resource_properties = function_resource.get("Properties", {}) - serverless_function_events = resource_properties.get(SamApiProvider._FUNCTION_EVENT, {}) - SamApiProvider._extract_apis_from_events(logical_id, serverless_function_events, collector) + serverless_function_events = resource_properties.get(self._FUNCTION_EVENT, {}) + self.extract_apis_from_events(logical_id, serverless_function_events, collector) - @staticmethod - def _extract_apis_from_events(function_logical_id, serverless_function_events, collector): + def extract_apis_from_events(self, function_logical_id, serverless_function_events, collector): """ Given an AWS::Serverless::Function Event Dictionary, extract out all 'Api' events and store within the collector @@ -266,8 +126,8 @@ def _extract_apis_from_events(function_logical_id, serverless_function_events, c count = 0 for _, event in serverless_function_events.items(): - if SamApiProvider._FUNCTION_EVENT_TYPE_API == event.get(SamApiProvider._TYPE): - api_resource_id, api = SamApiProvider._convert_event_api(function_logical_id, event.get("Properties")) + if self._FUNCTION_EVENT_TYPE_API == event.get(self._EVENT_TYPE): + api_resource_id, api = self._convert_event_api(function_logical_id, event.get("Properties")) collector.add_apis(api_resource_id, [api]) count += 1 @@ -288,7 +148,7 @@ def _convert_event_api(lambda_logical_id, event_properties): # An API Event, can have RestApiId property which designates the resource that owns this API. If omitted, # the API is owned by Implicit API resource. This could either be a direct resource logical ID or a # "Ref" of the logicalID - api_resource_id = event_properties.get("RestApiId", SamApiProvider._IMPLICIT_API_RESOURCE_ID) + api_resource_id = event_properties.get("RestApiId", SamApiProvider.IMPLICIT_API_RESOURCE_ID) if isinstance(api_resource_id, dict) and "Ref" in api_resource_id: api_resource_id = api_resource_id["Ref"] @@ -302,226 +162,52 @@ def _convert_event_api(lambda_logical_id, event_properties): return api_resource_id, Api(path=path, method=method, function_name=lambda_logical_id) @staticmethod - def _normalize_http_methods(http_method): - """ - Normalizes Http Methods. Api Gateway allows a Http Methods of ANY. This is a special verb to denote all - supported Http Methods on Api Gateway. - - :param str http_method: Http method - :yield str: Either the input http_method or one of the _ANY_HTTP_METHODS (normalized Http Methods) - """ - - if http_method.upper() == 'ANY': - for method in SamApiProvider._ANY_HTTP_METHODS: - yield method.upper() - else: - yield http_method.upper() - - -class ApiCollector(object): - """ - Class to store the API configurations in the SAM Template. This class helps store both implicit and explicit - APIs in a standardized format - """ - - # Properties of each API. The structure is quite similar to the properties of AWS::Serverless::Api resource. - # This is intentional because it allows us to easily extend this class to support future properties on the API. - # We will store properties of Implicit APIs also in this format which converges the handling of implicit & explicit - # APIs. - Properties = namedtuple("Properties", ["apis", "binary_media_types", "cors", "stage_name", "stage_variables"]) - - def __init__(self): - # API properties stored per resource. Key is the LogicalId of the AWS::Serverless::Api resource and - # value is the properties - self.by_resource = {} - - def __iter__(self): - """ - Iterator to iterate through all the APIs stored in the collector. In each iteration, this yields the - LogicalId of the API resource and a list of APIs available in this resource. - - Yields - ------- - str - LogicalID of the AWS::Serverless::Api resource - list samcli.commands.local.lib.provider.Api - List of the API available in this resource along with additional configuration like binary media types. - """ - - for logical_id, _ in self.by_resource.items(): - yield logical_id, self._get_apis_with_config(logical_id) - - def add_apis(self, logical_id, apis): - """ - Stores the given APIs tagged under the given logicalId - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - apis : list of samcli.commands.local.lib.provider.Api - List of APIs available in this resource - """ - properties = self._get_properties(logical_id) - properties.apis.extend(apis) - - def add_binary_media_types(self, logical_id, binary_media_types): - """ - Stores the binary media type configuration for the API with given logical ID - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - binary_media_types : list of str - List of binary media types supported by this resource - - """ - properties = self._get_properties(logical_id) - - binary_media_types = binary_media_types or [] - for value in binary_media_types: - normalized_value = self._normalize_binary_media_type(value) - - # If the value is not supported, then just skip it. - if normalized_value: - properties.binary_media_types.add(normalized_value) - else: - LOG.debug("Unsupported data type of binary media type value of resource '%s'", logical_id) - - def add_stage_name(self, logical_id, stage_name): - """ - Stores the stage name for the API with the given local ID - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - stage_name : str - The stage_name string - - """ - properties = self._get_properties(logical_id) - properties = properties._replace(stage_name=stage_name) - self._set_properties(logical_id, properties) - - def add_stage_variables(self, logical_id, stage_variables): - """ - Stores the stage variables for the API with the given local ID - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - stage_variables : dict - A dictionary containing stage variables. - - """ - properties = self._get_properties(logical_id) - properties = properties._replace(stage_variables=stage_variables) - self._set_properties(logical_id, properties) - - def _get_apis_with_config(self, logical_id): + def merge_apis(collector): """ - Returns the list of APIs in this resource along with other extra configuration such as binary media types, - cors etc. Additional configuration is merged directly into the API data because these properties, although - defined globally, actually apply to each API. + Quite often, an API is defined both in Implicit and Explicit API definitions. In such cases, Implicit API + definition wins because that conveys clear intent that the API is backed by a function. This method will + merge two such list of Apis with the right order of precedence. If a Path+Method combination is defined + in both the places, only one wins. Parameters ---------- - logical_id : str - Logical ID of the resource to fetch data for + collector : ApiCollector + Collector object that holds all the APIs specified in the template Returns ------- list of samcli.commands.local.lib.provider.Api - List of APIs with additional configurations for the resource with given logicalId. If there are no APIs, - then it returns an empty list - """ - - properties = self._get_properties(logical_id) - - # These configs need to be applied to each API - binary_media = sorted(list(properties.binary_media_types)) # Also sort the list to keep the ordering stable - cors = properties.cors - stage_name = properties.stage_name - stage_variables = properties.stage_variables - - result = [] - for api in properties.apis: - # Create a copy of the API with updated configuration - updated_api = api._replace(binary_media_types=binary_media, - cors=cors, - stage_name=stage_name, - stage_variables=stage_variables) - result.append(updated_api) - - return result - - def _get_properties(self, logical_id): - """ - Returns the properties of resource with given logical ID. If a resource is not found, then it returns an - empty data. - - Parameters - ---------- - logical_id : str - Logical ID of the resource - - Returns - ------- - samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties - Properties object for this resource. - """ - - if logical_id not in self.by_resource: - self.by_resource[logical_id] = self.Properties(apis=[], - # Use a set() to be able to easily de-dupe - binary_media_types=set(), - cors=None, - stage_name=None, - stage_variables=None) - - return self.by_resource[logical_id] - - def _set_properties(self, logical_id, properties): + List of APIs obtained by combining both the input lists. """ - Sets the properties of resource with given logical ID. If a resource is not found, it does nothing - Parameters - ---------- - logical_id : str - Logical ID of the resource - properties : samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties - Properties object for this resource. - """ + implicit_apis = [] + explicit_apis = [] - if logical_id in self.by_resource: - self.by_resource[logical_id] = properties + # Store implicit and explicit APIs separately in order to merge them later in the correct order + # Implicit APIs are defined on a resource with logicalID ServerlessRestApi + for logical_id, apis in collector: + if logical_id == SamApiProvider.IMPLICIT_API_RESOURCE_ID: + implicit_apis.extend(apis) + else: + explicit_apis.extend(apis) - @staticmethod - def _normalize_binary_media_type(value): - """ - Converts binary media types values to the canonical format. Ex: image~1gif -> image/gif. If the value is not - a string, then this method just returns None + # We will use "path+method" combination as key to this dictionary and store the Api config for this combination. + # If an path+method combo already exists, then overwrite it if and only if this is an implicit API + all_apis = {} - Parameters - ---------- - value : str - Value to be normalized + # By adding implicit APIs to the end of the list, they will be iterated last. If a configuration was already + # written by explicit API, it will be overriden by implicit API, just by virtue of order of iteration. + all_configs = explicit_apis + implicit_apis - Returns - ------- - str or None - Normalized value. If the input was not a string, then None is returned - """ + for config in all_configs: + # Normalize the methods before de-duping to allow an ANY method in implicit API to override a regular HTTP + # method on explicit API. + for normalized_method in AbstractApiProvider.normalize_http_methods(config.method): + key = config.path + normalized_method + all_apis[key] = config - if not isinstance(value, string_types): - # It is possible that user specified a dict value for one of the binary media types. We just skip them - return None + result = set(all_apis.values()) # Assign to a set() to de-dupe + LOG.debug("Removed duplicates from '%d' Explicit APIs and '%d' Implicit APIs to produce '%d' APIs", + len(explicit_apis), len(implicit_apis), len(result)) - return value.replace("~1", "/") + return list(result) diff --git a/samcli/commands/local/lib/sam_base_provider.py b/samcli/commands/local/lib/sam_base_provider.py index bbf4d6381b..861e1fd47a 100644 --- a/samcli/commands/local/lib/sam_base_provider.py +++ b/samcli/commands/local/lib/sam_base_provider.py @@ -10,7 +10,6 @@ from samcli.lib.samlib.wrapper import SamTranslatorWrapper from samcli.lib.samlib.resource_metadata_normalizer import ResourceMetadataNormalizer - LOG = logging.getLogger(__name__) @@ -89,7 +88,7 @@ def _resolve_parameters(template_dict, parameter_overrides): supported_intrinsics = {action.intrinsic_name: action() for action in SamBaseProvider._SUPPORTED_INTRINSICS} # Intrinsics resolver will mutate the original template - return IntrinsicsResolver(parameters=parameter_values, supported_intrinsics=supported_intrinsics)\ + return IntrinsicsResolver(parameters=parameter_values, supported_intrinsics=supported_intrinsics) \ .resolve_parameter_refs(template_dict) @staticmethod diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index 700491260d..321741e0bf 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -288,6 +288,112 @@ def test_binary_response(self): self.assertEquals(response.content, expected) +class TestStartApiWithSwaggerRestApis(StartApiIntegBaseClass): + template_path = "/testdata/start_api/swagger-rest-api-template.yaml" + binary_data_file = "testdata/start_api/binarydata.gif" + + def setUp(self): + self.url = "http://127.0.0.1:{}".format(self.port) + + def test_get_call_with_path_setup_with_any_swagger(self): + """ + Get Request to a path that was defined as ANY in SAM through Swagger + """ + response = requests.get(self.url + "/anyandall") + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.json(), {'hello': 'world'}) + + def test_post_call_with_path_setup_with_any_swagger(self): + """ + Post Request to a path that was defined as ANY in SAM through Swagger + """ + response = requests.post(self.url + "/anyandall", json={}) + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.json(), {'hello': 'world'}) + + def test_put_call_with_path_setup_with_any_swagger(self): + """ + Put Request to a path that was defined as ANY in SAM through Swagger + """ + response = requests.put(self.url + "/anyandall", json={}) + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.json(), {'hello': 'world'}) + + def test_head_call_with_path_setup_with_any_swagger(self): + """ + Head Request to a path that was defined as ANY in SAM through Swagger + """ + response = requests.head(self.url + "/anyandall") + + self.assertEquals(response.status_code, 200) + + def test_delete_call_with_path_setup_with_any_swagger(self): + """ + Delete Request to a path that was defined as ANY in SAM through Swagger + """ + response = requests.delete(self.url + "/anyandall") + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.json(), {'hello': 'world'}) + + def test_options_call_with_path_setup_with_any_swagger(self): + """ + Options Request to a path that was defined as ANY in SAM through Swagger + """ + response = requests.options(self.url + "/anyandall") + + self.assertEquals(response.status_code, 200) + + def test_patch_call_with_path_setup_with_any_swagger(self): + """ + Patch Request to a path that was defined as ANY in SAM through Swagger + """ + response = requests.patch(self.url + "/anyandall") + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.json(), {'hello': 'world'}) + + def test_function_not_defined_in_template(self): + response = requests.get(self.url + "/nofunctionfound") + + self.assertEquals(response.status_code, 502) + self.assertEquals(response.json(), {"message": "No function defined for resource method"}) + + def test_lambda_function_resource_is_reachable(self): + response = requests.get(self.url + "/nonserverlessfunction") + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.json(), {'hello': 'world'}) + + def test_binary_request(self): + """ + This tests that the service can accept and invoke a lambda when given binary data in a request + """ + input_data = self.get_binary_data(self.binary_data_file) + response = requests.post(self.url + '/echobase64eventbody', + headers={"Content-Type": "image/gif"}, + data=input_data) + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.headers.get("Content-Type"), "image/gif") + self.assertEquals(response.content, input_data) + + def test_binary_response(self): + """ + Binary data is returned correctly + """ + expected = self.get_binary_data(self.binary_data_file) + + response = requests.get(self.url + '/base64response') + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.headers.get("Content-Type"), "image/gif") + self.assertEquals(response.content, expected) + + class TestServiceResponses(StartApiIntegBaseClass): """ Test Class centered around the different responses that can happen in Lambda and pass through start-api diff --git a/tests/integration/testdata/start_api/swagger-rest-api-template.yaml b/tests/integration/testdata/start_api/swagger-rest-api-template.yaml new file mode 100644 index 0000000000..5edeb8717f --- /dev/null +++ b/tests/integration/testdata/start_api/swagger-rest-api-template.yaml @@ -0,0 +1,69 @@ +AWSTemplateFormatVersion: '2010-09-09' + +Resources: + Base64ResponseFunction: + Properties: + Code: "." + Handler: main.base64_response + Runtime: python3.6 + Type: AWS::Lambda::Function + EchoBase64EventBodyFunction: + Properties: + Code: "." + Handler: main.echo_base64_event_body + Runtime: python3.6 + Type: AWS::Lambda::Function + MyApi: + Properties: + Body: + info: + title: + Ref: AWS::StackName + paths: + "/anyandall": + x-amazon-apigateway-any-method: + x-amazon-apigateway-integration: + httpMethod: POST + responses: {} + type: aws_proxy + uri: + Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${MyNonServerlessLambdaFunction.Arn}/invocations + "/base64response": + get: + x-amazon-apigateway-integration: + httpMethod: POST + type: aws_proxy + uri: + Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${Base64ResponseFunction.Arn}/invocations + "/echobase64eventbody": + post: + x-amazon-apigateway-integration: + httpMethod: POST + type: aws_proxy + uri: + Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${EchoBase64EventBodyFunction.Arn}/invocations + "/nofunctionfound": + get: + x-amazon-apigateway-integration: + httpMethod: POST + type: aws_proxy + uri: + Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${WhatFunction.Arn}/invocations + "/nonserverlessfunction": + get: + x-amazon-apigateway-integration: + httpMethod: POST + type: aws_proxy + uri: + Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${MyNonServerlessLambdaFunction.Arn}/invocations + swagger: '2.0' + x-amazon-apigateway-binary-media-types: + - image/gif + StageName: prod + Type: AWS::ApiGateway::RestApi + MyNonServerlessLambdaFunction: + Properties: + Code: "." + Handler: main.handler + Runtime: python3.6 + Type: AWS::Lambda::Function diff --git a/tests/unit/commands/local/lib/test_api_provider.py b/tests/unit/commands/local/lib/test_api_provider.py new file mode 100644 index 0000000000..50b8d073d4 --- /dev/null +++ b/tests/unit/commands/local/lib/test_api_provider.py @@ -0,0 +1,207 @@ +from collections import OrderedDict +from unittest import TestCase + +from mock import patch + +from samcli.commands.local.lib.api_provider import ApiProvider +from samcli.commands.local.lib.sam_api_provider import SamApiProvider +from samcli.commands.local.lib.cfn_api_provider import CfnApiProvider + + +class TestApiProvider_init(TestCase): + + @patch.object(ApiProvider, "_extract_apis") + @patch("samcli.commands.local.lib.api_provider.SamBaseProvider") + def test_provider_with_valid_template(self, SamBaseProviderMock, extract_api_mock): + extract_api_mock.return_value = {"set", "of", "values"} + + template = {"Resources": {"a": "b"}} + SamBaseProviderMock.get_template.return_value = template + + provider = ApiProvider(template) + + self.assertEquals(len(provider.apis), 3) + self.assertEquals(provider.apis, set(["set", "of", "values"])) + self.assertEquals(provider.template_dict, {"Resources": {"a": "b"}}) + self.assertEquals(provider.resources, {"a": "b"}) + + +class TestApiProviderSelection(TestCase): + def test_default_provider(self): + resources = { + "TestApi": { + "Type": "AWS::UNKNOWN_TYPE", + "Properties": { + "StageName": "dev", + "DefinitionBody": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + } + + provider = ApiProvider.find_api_provider(resources) + self.assertTrue(isinstance(provider, SamApiProvider)) + + def test_api_provider_sam_api(self): + resources = { + "TestApi": { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "dev", + "DefinitionBody": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + } + + provider = ApiProvider.find_api_provider(resources) + self.assertTrue(isinstance(provider, SamApiProvider)) + + def test_api_provider_sam_function(self): + resources = { + "TestApi": { + "Type": "AWS::Serverless::Function", + "Properties": { + "StageName": "dev", + "DefinitionBody": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + } + + provider = ApiProvider.find_api_provider(resources) + + self.assertTrue(isinstance(provider, SamApiProvider)) + + def test_api_provider_cloud_formation(self): + resources = { + "TestApi": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "StageName": "dev", + "Body": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + } + + provider = ApiProvider.find_api_provider(resources) + self.assertTrue(isinstance(provider, CfnApiProvider)) + + def test_multiple_api_provider_cloud_formation(self): + resources = OrderedDict() + resources["TestApi"] = { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "StageName": "dev", + "Body": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + resources["OtherApi"] = { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "dev", + "DefinitionBody": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + + provider = ApiProvider.find_api_provider(resources) + self.assertTrue(isinstance(provider, CfnApiProvider)) diff --git a/tests/unit/commands/local/lib/test_cfn_api_provider.py b/tests/unit/commands/local/lib/test_cfn_api_provider.py new file mode 100644 index 0000000000..723951eb11 --- /dev/null +++ b/tests/unit/commands/local/lib/test_cfn_api_provider.py @@ -0,0 +1,215 @@ +import json +import tempfile +from unittest import TestCase + +from mock import patch +from six import assertCountEqual + +from samcli.commands.local.lib.api_provider import ApiProvider +from samcli.commands.local.lib.provider import Api +from tests.unit.commands.local.lib.test_sam_api_provider import make_swagger + + +class TestApiProviderWithApiGatewayRestApi(TestCase): + + def setUp(self): + self.binary_types = ["image/png", "image/jpg"] + self.input_apis = [ + Api(path="/path1", method="GET", function_name="SamFunc1", cors=None), + Api(path="/path1", method="POST", function_name="SamFunc1", cors=None), + + Api(path="/path2", method="PUT", function_name="SamFunc1", cors=None), + Api(path="/path2", method="GET", function_name="SamFunc1", cors=None), + + Api(path="/path3", method="DELETE", function_name="SamFunc1", cors=None) + ] + + def test_with_no_apis(self): + template = { + "Resources": { + + "Api1": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + }, + + } + } + } + + provider = ApiProvider(template) + + self.assertEquals(provider.apis, []) + + def test_with_inline_swagger_apis(self): + template = { + "Resources": { + + "Api1": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "Body": make_swagger(self.input_apis) + } + } + } + } + + provider = ApiProvider(template) + assertCountEqual(self, self.input_apis, provider.apis) + + def test_with_swagger_as_local_file(self): + with tempfile.NamedTemporaryFile(mode='w') as fp: + filename = fp.name + + swagger = make_swagger(self.input_apis) + json.dump(swagger, fp) + fp.flush() + + template = { + "Resources": { + + "Api1": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "BodyS3Location": filename + } + } + } + } + + provider = ApiProvider(template) + assertCountEqual(self, self.input_apis, provider.apis) + + def test_body_with_swagger_as_local_file_expect_fail(self): + with tempfile.NamedTemporaryFile(mode='w') as fp: + filename = fp.name + + swagger = make_swagger(self.input_apis) + json.dump(swagger, fp) + fp.flush() + + template = { + "Resources": { + + "Api1": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "Body": filename + } + } + } + } + self.assertRaises(Exception, ApiProvider, template) + + @patch("samcli.commands.local.lib.cfn_base_api_provider.SamSwaggerReader") + def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): + body = {"some": "body"} + filename = "somefile.txt" + + template = { + "Resources": { + + "Api1": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "BodyS3Location": filename, + "Body": body + } + } + } + } + + SamSwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_apis) + + cwd = "foo" + provider = ApiProvider(template, cwd=cwd) + assertCountEqual(self, self.input_apis, provider.apis) + SamSwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) + + def test_swagger_with_any_method(self): + apis = [ + Api(path="/path", method="any", function_name="SamFunc1", cors=None) + ] + + expected_apis = [ + Api(path="/path", method="GET", function_name="SamFunc1", cors=None), + Api(path="/path", method="POST", function_name="SamFunc1", cors=None), + Api(path="/path", method="PUT", function_name="SamFunc1", cors=None), + Api(path="/path", method="DELETE", function_name="SamFunc1", cors=None), + Api(path="/path", method="HEAD", function_name="SamFunc1", cors=None), + Api(path="/path", method="OPTIONS", function_name="SamFunc1", cors=None), + Api(path="/path", method="PATCH", function_name="SamFunc1", cors=None) + ] + + template = { + "Resources": { + "Api1": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "Body": make_swagger(apis) + } + } + } + } + + provider = ApiProvider(template) + assertCountEqual(self, expected_apis, provider.apis) + + def test_with_binary_media_types(self): + template = { + "Resources": { + + "Api1": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "Body": make_swagger(self.input_apis, binary_media_types=self.binary_types) + } + } + } + } + + expected_binary_types = sorted(self.binary_types) + expected_apis = [ + Api(path="/path1", method="GET", function_name="SamFunc1", cors=None, + binary_media_types=expected_binary_types), + Api(path="/path1", method="POST", function_name="SamFunc1", cors=None, + binary_media_types=expected_binary_types), + + Api(path="/path2", method="PUT", function_name="SamFunc1", cors=None, + binary_media_types=expected_binary_types), + Api(path="/path2", method="GET", function_name="SamFunc1", cors=None, + binary_media_types=expected_binary_types), + + Api(path="/path3", method="DELETE", function_name="SamFunc1", cors=None, + binary_media_types=expected_binary_types) + ] + + provider = ApiProvider(template) + assertCountEqual(self, expected_apis, provider.apis) + + def test_with_binary_media_types_in_swagger_and_on_resource(self): + input_apis = [ + Api(path="/path", method="OPTIONS", function_name="SamFunc1"), + ] + extra_binary_types = ["text/html"] + + template = { + "Resources": { + + "Api1": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "BinaryMediaTypes": extra_binary_types, + "Body": make_swagger(input_apis, binary_media_types=self.binary_types) + } + } + } + } + + expected_binary_types = sorted(self.binary_types + extra_binary_types) + expected_apis = [ + Api(path="/path", method="OPTIONS", function_name="SamFunc1", binary_media_types=expected_binary_types), + ] + + provider = ApiProvider(template) + assertCountEqual(self, expected_apis, provider.apis) diff --git a/tests/unit/commands/local/lib/test_local_api_service.py b/tests/unit/commands/local/lib/test_local_api_service.py index 3cc5d2c4c3..cfa35af954 100644 --- a/tests/unit/commands/local/lib/test_local_api_service.py +++ b/tests/unit/commands/local/lib/test_local_api_service.py @@ -35,7 +35,7 @@ def setUp(self): self.lambda_invoke_context_mock.stderr = self.stderr_mock @patch("samcli.commands.local.lib.local_api_service.LocalApigwService") - @patch("samcli.commands.local.lib.local_api_service.SamApiProvider") + @patch("samcli.commands.local.lib.local_api_service.ApiProvider") @patch.object(LocalApiService, "_make_static_dir_path") @patch.object(LocalApiService, "_print_routes") @patch.object(LocalApiService, "_make_routing_list") @@ -77,7 +77,7 @@ def test_must_start_service(self, self.apigw_service.run.assert_called_with() @patch("samcli.commands.local.lib.local_api_service.LocalApigwService") - @patch("samcli.commands.local.lib.local_api_service.SamApiProvider") + @patch("samcli.commands.local.lib.local_api_service.ApiProvider") @patch.object(LocalApiService, "_make_static_dir_path") @patch.object(LocalApiService, "_print_routes") @patch.object(LocalApiService, "_make_routing_list") diff --git a/tests/unit/commands/local/lib/test_sam_api_provider.py b/tests/unit/commands/local/lib/test_sam_api_provider.py index 3ac01956be..fa5f342e49 100644 --- a/tests/unit/commands/local/lib/test_sam_api_provider.py +++ b/tests/unit/commands/local/lib/test_sam_api_provider.py @@ -7,29 +7,11 @@ from six import assertCountEqual -from samcli.commands.local.lib.sam_api_provider import SamApiProvider +from samcli.commands.local.lib.api_provider import ApiProvider, SamApiProvider from samcli.commands.local.lib.provider import Api from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException -class TestSamApiProvider_init(TestCase): - - @patch.object(SamApiProvider, "_extract_apis") - @patch("samcli.commands.local.lib.sam_api_provider.SamBaseProvider") - def test_provider_with_valid_template(self, SamBaseProviderMock, extract_api_mock): - extract_api_mock.return_value = {"set", "of", "values"} - - template = {"Resources": {"a": "b"}} - SamBaseProviderMock.get_template.return_value = template - - provider = SamApiProvider(template) - - self.assertEquals(len(provider.apis), 3) - self.assertEquals(provider.apis, set(["set", "of", "values"])) - self.assertEquals(provider.template_dict, {"Resources": {"a": "b"}}) - self.assertEquals(provider.resources, {"a": "b"}) - - class TestSamApiProviderWithImplicitApis(TestCase): def test_provider_with_no_resource_properties(self): @@ -42,9 +24,8 @@ def test_provider_with_no_resource_properties(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) - self.assertEquals(len(provider.apis), 0) self.assertEquals(provider.apis, []) @parameterized.expand([("GET"), ("get")]) @@ -72,7 +53,7 @@ def test_provider_has_correct_api(self, method): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) self.assertEquals(len(provider.apis), 1) self.assertEquals(list(provider.apis)[0], Api(path="/path", method="GET", function_name="SamFunc1", cors=None, @@ -109,7 +90,7 @@ def test_provider_creates_api_for_all_events(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) api_event1 = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") api_event2 = Api(path="/path", method="POST", function_name="SamFunc1", cors=None, stage_name="Prod") @@ -159,7 +140,7 @@ def test_provider_has_correct_template(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) api1 = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") api2 = Api(path="/path", method="POST", function_name="SamFunc2", cors=None, stage_name="Prod") @@ -190,7 +171,7 @@ def test_provider_with_no_api_events(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) self.assertEquals(provider.apis, []) @@ -209,7 +190,7 @@ def test_provider_with_no_serverless_function(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) self.assertEquals(provider.apis, []) @@ -254,7 +235,7 @@ def test_provider_get_all(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) result = [f for f in provider.get_all()] @@ -267,7 +248,7 @@ def test_provider_get_all(self): def test_provider_get_all_with_no_apis(self): template = {} - provider = SamApiProvider(template) + provider = ApiProvider(template) result = [f for f in provider.get_all()] @@ -298,7 +279,7 @@ def test_provider_with_any_method(self, method): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) api_get = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") api_post = Api(path="/path", method="POST", function_name="SamFunc1", cors=None, stage_name="Prod") @@ -351,7 +332,7 @@ def test_provider_must_support_binary_media_types(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) self.assertEquals(len(provider.apis), 1) self.assertEquals(list(provider.apis)[0], Api(path="/path", method="GET", function_name="SamFunc1", @@ -403,7 +384,7 @@ def test_provider_must_support_binary_media_types_with_any_method(self): Api(path="/path", method="PATCH", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod") ] - provider = SamApiProvider(template) + provider = ApiProvider(template) assertCountEqual(self, provider.apis, expected_apis) @@ -448,9 +429,8 @@ def test_with_no_apis(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) - self.assertEquals(len(provider.apis), 0) self.assertEquals(provider.apis, []) def test_with_inline_swagger_apis(self): @@ -467,7 +447,7 @@ def test_with_inline_swagger_apis(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) assertCountEqual(self, self.input_apis, provider.apis) def test_with_swagger_as_local_file(self): @@ -491,11 +471,11 @@ def test_with_swagger_as_local_file(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) assertCountEqual(self, self.input_apis, provider.apis) - @patch("samcli.commands.local.lib.sam_api_provider.SamSwaggerReader") - def test_with_swagger_as_both_body_and_uri(self, SamSwaggerReaderMock): + @patch("samcli.commands.local.lib.cfn_base_api_provider.SamSwaggerReader") + def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): body = {"some": "body"} filename = "somefile.txt" @@ -516,7 +496,7 @@ def test_with_swagger_as_both_body_and_uri(self, SamSwaggerReaderMock): SamSwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_apis) cwd = "foo" - provider = SamApiProvider(template, cwd=cwd) + provider = ApiProvider(template, cwd=cwd) assertCountEqual(self, self.input_apis, provider.apis) SamSwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) @@ -547,7 +527,7 @@ def test_swagger_with_any_method(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) assertCountEqual(self, expected_apis, provider.apis) def test_with_binary_media_types(self): @@ -580,7 +560,7 @@ def test_with_binary_media_types(self): binary_media_types=expected_binary_types, stage_name="Prod") ] - provider = SamApiProvider(template) + provider = ApiProvider(template) assertCountEqual(self, expected_apis, provider.apis) def test_with_binary_media_types_in_swagger_and_on_resource(self): @@ -609,7 +589,7 @@ def test_with_binary_media_types_in_swagger_and_on_resource(self): stage_name="Prod"), ] - provider = SamApiProvider(template) + provider = ApiProvider(template) assertCountEqual(self, expected_apis, provider.apis) @@ -686,7 +666,7 @@ def test_must_union_implicit_and_explicit(self): Api(path="/path3", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod") ] - provider = SamApiProvider(self.template) + provider = ApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_must_prefer_implicit_api_over_explicit(self): @@ -723,7 +703,7 @@ def test_must_prefer_implicit_api_over_explicit(self): Api(path="/path3", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), ] - provider = SamApiProvider(self.template) + provider = ApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_must_prefer_implicit_with_any_method(self): @@ -757,7 +737,7 @@ def test_must_prefer_implicit_with_any_method(self): Api(path="/path", method="PATCH", function_name="ImplicitFunc", cors=None, stage_name="Prod") ] - provider = SamApiProvider(self.template) + provider = ApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_with_any_method_on_both(self): @@ -802,8 +782,7 @@ def test_with_any_method_on_both(self): Api(path="/path2", method="POST", function_name="explicitfunction", cors=None, stage_name="Prod") ] - provider = SamApiProvider(self.template) - print(provider.apis) + provider = ApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_must_add_explicit_api_when_ref_with_rest_api_id(self): @@ -840,7 +819,7 @@ def test_must_add_explicit_api_when_ref_with_rest_api_id(self): Api(path="/newpath2", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod") ] - provider = SamApiProvider(self.template) + provider = ApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_both_apis_must_get_binary_media_types(self): @@ -895,7 +874,7 @@ def test_both_apis_must_get_binary_media_types(self): stage_name="Prod") ] - provider = SamApiProvider(self.template) + provider = ApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_binary_media_types_with_rest_api_id_reference(self): @@ -955,7 +934,7 @@ def test_binary_media_types_with_rest_api_id_reference(self): stage_name="Prod") ] - provider = SamApiProvider(self.template) + provider = ApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) @@ -991,7 +970,7 @@ def test_provider_parse_stage_name(self): } } } - provider = SamApiProvider(template) + provider = ApiProvider(template) api1 = Api(path='/path', method='GET', function_name='NoApiEventFunction', cors=None, binary_media_types=[], stage_name='dev', stage_variables=None) @@ -1033,7 +1012,7 @@ def test_provider_stage_variables(self): } } } - provider = SamApiProvider(template) + provider = ApiProvider(template) api1 = Api(path='/path', method='GET', function_name='NoApiEventFunction', cors=None, binary_media_types=[], stage_name='dev', stage_variables={ @@ -1119,8 +1098,7 @@ def test_multi_stage_get_all(self): } } } - - provider = SamApiProvider(template) + provider = ApiProvider(template) result = [f for f in provider.get_all()] From aeba546806a9beda938d1fdd141c9f1558da67a2 Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa <51216482+viksrivat@users.noreply.github.com> Date: Fri, 26 Jul 2019 09:54:57 -0700 Subject: [PATCH 13/30] feat(start-api): CloudFormation AWS::ApiGateway::Stage Support (#1239) --- samcli/commands/local/lib/api_collector.py | 230 +++---- samcli/commands/local/lib/api_provider.py | 35 +- samcli/commands/local/lib/cfn_api_provider.py | 60 +- .../local/lib/cfn_base_api_provider.py | 31 +- .../commands/local/lib/local_api_service.py | 70 +- samcli/commands/local/lib/provider.py | 93 +-- samcli/commands/local/lib/sam_api_provider.py | 95 +-- samcli/commands/local/lib/swagger/parser.py | 15 +- samcli/commands/local/lib/swagger/reader.py | 2 +- samcli/local/apigw/local_apigw_service.py | 59 +- .../local/lib/test_local_api_service.py | 24 +- .../local/start_api/test_start_api.py | 28 + .../start_api/swagger-rest-api-template.yaml | 20 + .../commands/local/lib/swagger/test_parser.py | 30 +- .../commands/local/lib/swagger/test_reader.py | 44 +- .../commands/local/lib/test_api_provider.py | 10 +- .../local/lib/test_cfn_api_provider.py | 283 +++++++-- .../local/lib/test_local_api_service.py | 101 +-- .../local/lib/test_sam_api_provider.py | 600 ++++++++---------- .../local/apigw/test_local_apigw_service.py | 88 ++- 20 files changed, 1011 insertions(+), 907 deletions(-) diff --git a/samcli/commands/local/lib/api_collector.py b/samcli/commands/local/lib/api_collector.py index cbd198c6b7..be18cea8c8 100644 --- a/samcli/commands/local/lib/api_collector.py +++ b/samcli/commands/local/lib/api_collector.py @@ -1,207 +1,173 @@ """ Class to store the API configurations in the SAM Template. This class helps store both implicit and explicit -APIs in a standardized format +routes in a standardized format """ import logging -from collections import namedtuple +from collections import defaultdict from six import string_types +from samcli.local.apigw.local_apigw_service import Route +from samcli.commands.local.lib.provider import Api + LOG = logging.getLogger(__name__) class ApiCollector(object): - # Properties of each API. The structure is quite similar to the properties of AWS::Serverless::Api resource. - # This is intentional because it allows us to easily extend this class to support future properties on the API. - # We will store properties of Implicit APIs also in this format which converges the handling of implicit & explicit - # APIs. - Properties = namedtuple("Properties", ["apis", "binary_media_types", "cors", "stage_name", "stage_variables"]) def __init__(self): - # API properties stored per resource. Key is the LogicalId of the AWS::Serverless::Api resource and - # value is the properties - self.by_resource = {} + # Route properties stored per resource. + self._route_per_resource = defaultdict(list) + + # processed values to be set before creating the api + self._routes = [] + self.binary_media_types_set = set() + self.stage_name = None + self.stage_variables = None def __iter__(self): """ - Iterator to iterate through all the APIs stored in the collector. In each iteration, this yields the - LogicalId of the API resource and a list of APIs available in this resource. - + Iterator to iterate through all the routes stored in the collector. In each iteration, this yields the + LogicalId of the route resource and a list of routes available in this resource. Yields ------- str - LogicalID of the AWS::Serverless::Api resource + LogicalID of the AWS::Serverless::Api or AWS::ApiGateway::RestApi resource list samcli.commands.local.lib.provider.Api List of the API available in this resource along with additional configuration like binary media types. """ - for logical_id, _ in self.by_resource.items(): - yield logical_id, self._get_apis_with_config(logical_id) + for logical_id, _ in self._route_per_resource.items(): + yield logical_id, self._get_routes(logical_id) - def add_apis(self, logical_id, apis): + def add_routes(self, logical_id, routes): """ - Stores the given APIs tagged under the given logicalId - + Stores the given routes tagged under the given logicalId Parameters ---------- logical_id : str - LogicalId of the AWS::Serverless::Api resource - - apis : list of samcli.commands.local.lib.provider.Api - List of APIs available in this resource + LogicalId of the AWS::Serverless::Api or AWS::ApiGateway::RestApi resource + routes : list of samcli.commands.local.agiw.local_apigw_service.Route + List of routes available in this resource """ - properties = self._get_properties(logical_id) - properties.apis.extend(apis) + self._get_routes(logical_id).extend(routes) - def add_binary_media_types(self, logical_id, binary_media_types): + def _get_routes(self, logical_id): """ - Stores the binary media type configuration for the API with given logical ID - + Returns the properties of resource with given logical ID. If a resource is not found, then it returns an + empty data. Parameters ---------- logical_id : str - LogicalId of the AWS::Serverless::Api resource - - binary_media_types : list of str - List of binary media types supported by this resource - - """ - properties = self._get_properties(logical_id) - - binary_media_types = binary_media_types or [] - for value in binary_media_types: - normalized_value = self._normalize_binary_media_type(value) - - # If the value is not supported, then just skip it. - if normalized_value: - properties.binary_media_types.add(normalized_value) - else: - LOG.debug("Unsupported data type of binary media type value of resource '%s'", logical_id) - - def add_stage_name(self, logical_id, stage_name): + Logical ID of the resource + Returns + ------- + samcli.commands.local.lib.Routes + Properties object for this resource. """ - Stores the stage name for the API with the given local ID - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource + return self._route_per_resource[logical_id] - stage_name : str - The stage_name string + @property + def routes(self): + return self._routes if self._routes else self.all_routes() - """ - properties = self._get_properties(logical_id) - properties = properties._replace(stage_name=stage_name) - self._set_properties(logical_id, properties) + @routes.setter + def routes(self, routes): + self._routes = routes - def add_stage_variables(self, logical_id, stage_variables): + def all_routes(self): """ - Stores the stage variables for the API with the given local ID - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - stage_variables : dict - A dictionary containing stage variables. + Gets all the routes within the _route_per_resource + Return + ------- + All the routes within the _route_per_resource """ - properties = self._get_properties(logical_id) - properties = properties._replace(stage_variables=stage_variables) - self._set_properties(logical_id, properties) + routes = [] + for logical_id in self._route_per_resource.keys(): + routes.extend(self._get_routes(logical_id)) + return routes - def _get_apis_with_config(self, logical_id): + def get_api(self): """ - Returns the list of APIs in this resource along with other extra configuration such as binary media types, - cors etc. Additional configuration is merged directly into the API data because these properties, although - defined globally, actually apply to each API. + Creates the api using the parts from the ApiCollector. The routes are also deduped so that there is no + duplicate routes with the same function name, path, but different method. - Parameters - ---------- - logical_id : str - Logical ID of the resource to fetch data for + The normalised_routes are the routes that have been processed. By default, this will get all the routes. + However, it can be changed to override the default value of normalised routes such as in SamApiProvider - Returns + Return ------- - list of samcli.commands.local.lib.provider.Api - List of APIs with additional configurations for the resource with given logicalId. If there are no APIs, - then it returns an empty list + An Api object with all the properties """ + api = Api() + api.routes = self.dedupe_function_routes(self.routes) + api.binary_media_types_set = self.binary_media_types_set + api.stage_name = self.stage_name + api.stage_variables = self.stage_variables + return api - properties = self._get_properties(logical_id) + @staticmethod + def dedupe_function_routes(routes): + """ + Remove duplicate routes that have the same function_name and method - # These configs need to be applied to each API - binary_media = sorted(list(properties.binary_media_types)) # Also sort the list to keep the ordering stable - cors = properties.cors - stage_name = properties.stage_name - stage_variables = properties.stage_variables + route: list(Route) + List of Routes - result = [] - for api in properties.apis: - # Create a copy of the API with updated configuration - updated_api = api._replace(binary_media_types=binary_media, - cors=cors, - stage_name=stage_name, - stage_variables=stage_variables) - result.append(updated_api) + Return + ------- + A list of routes without duplicate routes with the same function_name and method + """ + grouped_routes = {} - return result + for route in routes: + key = "{}-{}".format(route.function_name, route.path) + config = grouped_routes.get(key, None) + methods = route.methods + if config: + methods += config.methods + sorted_methods = sorted(methods) + grouped_routes[key] = Route(function_name=route.function_name, path=route.path, methods=sorted_methods) + return list(grouped_routes.values()) - def _get_properties(self, logical_id): + def add_binary_media_types(self, logical_id, binary_media_types): """ - Returns the properties of resource with given logical ID. If a resource is not found, then it returns an - empty data. - + Stores the binary media type configuration for the API with given logical ID Parameters ---------- - logical_id : str - Logical ID of the resource - Returns - ------- - samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties - Properties object for this resource. - """ - - if logical_id not in self.by_resource: - self.by_resource[logical_id] = self.Properties(apis=[], - # Use a set() to be able to easily de-dupe - binary_media_types=set(), - cors=None, - stage_name=None, - stage_variables=None) + logical_id : str + LogicalId of the AWS::Serverless::Api resource - return self.by_resource[logical_id] + api: samcli.commands.local.lib.provider.Api + Instance of the Api which will save all the api configurations - def _set_properties(self, logical_id, properties): + binary_media_types : list of str + List of binary media types supported by this resource """ - Sets the properties of resource with given logical ID. If a resource is not found, it does nothing - Parameters - ---------- - logical_id : str - Logical ID of the resource - properties : samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties - Properties object for this resource. - """ + binary_media_types = binary_media_types or [] + for value in binary_media_types: + normalized_value = self.normalize_binary_media_type(value) - if logical_id in self.by_resource: - self.by_resource[logical_id] = properties + # If the value is not supported, then just skip it. + if normalized_value: + self.binary_media_types_set.add(normalized_value) + else: + LOG.debug("Unsupported data type of binary media type value of resource '%s'", logical_id) @staticmethod - def _normalize_binary_media_type(value): + def normalize_binary_media_type(value): """ Converts binary media types values to the canonical format. Ex: image~1gif -> image/gif. If the value is not a string, then this method just returns None - Parameters ---------- value : str Value to be normalized - Returns ------- str or None diff --git a/samcli/commands/local/lib/api_provider.py b/samcli/commands/local/lib/api_provider.py index afc686e166..20d31039f7 100644 --- a/samcli/commands/local/lib/api_provider.py +++ b/samcli/commands/local/lib/api_provider.py @@ -1,13 +1,13 @@ -"""Class that provides Apis from a SAM Template""" +"""Class that provides the Api with a list of routes from a Template""" import logging -from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider from samcli.commands.local.lib.api_collector import ApiCollector +from samcli.commands.local.lib.cfn_api_provider import CfnApiProvider +from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider from samcli.commands.local.lib.provider import AbstractApiProvider -from samcli.commands.local.lib.sam_base_provider import SamBaseProvider from samcli.commands.local.lib.sam_api_provider import SamApiProvider -from samcli.commands.local.lib.cfn_api_provider import CfnApiProvider +from samcli.commands.local.lib.sam_base_provider import SamBaseProvider LOG = logging.getLogger(__name__) @@ -16,7 +16,7 @@ class ApiProvider(AbstractApiProvider): def __init__(self, template_dict, parameter_overrides=None, cwd=None): """ - Initialize the class with SAM template data. The template_dict (SAM Templated) is assumed + Initialize the class with template data. The template_dict is assumed to be valid, normalized and a dictionary. template_dict should be normalized by running any and all pre-processing before passing to this class. This class does not perform any syntactic validation of the template. @@ -27,7 +27,7 @@ def __init__(self, template_dict, parameter_overrides=None, cwd=None): Parameters ---------- template_dict : dict - SAM Template as a dictionary + Template as a dictionary cwd : str Optional working directory with respect to which we will resolve relative path to Swagger file @@ -39,23 +39,22 @@ def __init__(self, template_dict, parameter_overrides=None, cwd=None): # Store a set of apis self.cwd = cwd - self.apis = self._extract_apis(self.resources) - - LOG.debug("%d APIs found in the template", len(self.apis)) + self.api = self._extract_api(self.resources) + self.routes = self.api.routes + LOG.debug("%d APIs found in the template", len(self.routes)) def get_all(self): """ - Yields all the Lambda functions with Api Events available in the SAM Template. + Yields all the Apis in the current Provider - :yields Api: namedtuple containing the Api information + :yields api: an Api object with routes and properties """ - for api in self.apis: - yield api + yield self.api - def _extract_apis(self, resources): + def _extract_api(self, resources): """ - Extracts all the Apis by running through the one providers. The provider that has the first type matched + Extracts all the routes by running through the one providers. The provider that has the first type matched will be run across all the resources Parameters @@ -64,12 +63,12 @@ def _extract_apis(self, resources): The dictionary containing the different resources within the template Returns --------- - list of Apis extracted from the resources + An Api from the parsed template """ collector = ApiCollector() provider = self.find_api_provider(resources) - apis = provider.extract_resource_api(resources, collector, cwd=self.cwd) - return self.normalize_apis(apis) + provider.extract_resources(resources, collector, cwd=self.cwd) + return collector.get_api() @staticmethod def find_api_provider(resources): diff --git a/samcli/commands/local/lib/cfn_api_provider.py b/samcli/commands/local/lib/cfn_api_provider.py index 0e3919611c..dc1c16848f 100644 --- a/samcli/commands/local/lib/cfn_api_provider.py +++ b/samcli/commands/local/lib/cfn_api_provider.py @@ -1,6 +1,7 @@ """Parses SAM given a template""" import logging +from samcli.commands.local.cli_common.user_exceptions import InvalidSamTemplateException from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider LOG = logging.getLogger(__name__) @@ -8,20 +9,22 @@ class CfnApiProvider(CfnBaseApiProvider): APIGATEWAY_RESTAPI = "AWS::ApiGateway::RestApi" + APIGATEWAY_STAGE = "AWS::ApiGateway::Stage" TYPES = [ - APIGATEWAY_RESTAPI + APIGATEWAY_RESTAPI, + APIGATEWAY_STAGE ] - def extract_resource_api(self, resources, collector, cwd=None): + def extract_resources(self, resources, collector, cwd=None): """ - Extract the Api Object from a given resource and adds it to the ApiCollector. + Extract the Route Object from a given resource and adds it to the RouteCollector. Parameters ---------- resources: dict The dictionary containing the different resources within the template - collector: ApiCollector + collector: samcli.commands.local.lib.route_collector.RouteCollector Instance of the API collector that where we will save the API information cwd : str @@ -29,18 +32,17 @@ def extract_resource_api(self, resources, collector, cwd=None): Return ------- - Returns a list of Apis + Returns a list of routes """ for logical_id, resource in resources.items(): resource_type = resource.get(CfnBaseApiProvider.RESOURCE_TYPE) if resource_type == CfnApiProvider.APIGATEWAY_RESTAPI: - self._extract_cloud_formation_api(logical_id, resource, collector, cwd) - all_apis = [] - for _, apis in collector: - all_apis.extend(apis) - return all_apis + self._extract_cloud_formation_route(logical_id, resource, collector, cwd=cwd) - def _extract_cloud_formation_api(self, logical_id, api_resource, collector, cwd=None): + if resource_type == CfnApiProvider.APIGATEWAY_STAGE: + self._extract_cloud_formation_stage(resources, resource, collector) + + def _extract_cloud_formation_route(self, logical_id, api_resource, collector, cwd=None): """ Extract APIs from AWS::ApiGateway::RestApi resource by reading and parsing Swagger documents. The result is added to the collector. @@ -66,4 +68,38 @@ def _extract_cloud_formation_api(self, logical_id, api_resource, collector, cwd= LOG.debug("Skipping resource '%s'. Swagger document not found in Body and BodyS3Location", logical_id) return - self.extract_swagger_api(logical_id, body, body_s3_location, binary_media, collector, cwd) + self.extract_swagger_route(logical_id, body, body_s3_location, binary_media, collector, cwd) + + @staticmethod + def _extract_cloud_formation_stage(resources, stage_resource, collector): + """ + Extract the stage from AWS::ApiGateway::Stage resource by reading and adds it to the collector. + Parameters + ---------- + resources: dict + All Resource definition, including its properties + + stage_resource : dict + Stage Resource definition, including its properties + + collector : ApiCollector + Instance of the API collector that where we will save the API information + """ + properties = stage_resource.get("Properties", {}) + stage_name = properties.get("StageName") + stage_variables = properties.get("Variables") + + # Currently, we aren't resolving any Refs or other intrinsic properties that come with it + # A separate pr will need to fully resolve intrinsics + logical_id = properties.get("RestApiId") + if not logical_id: + raise InvalidSamTemplateException("The AWS::ApiGateway::Stage must have a RestApiId property") + + rest_api_resource_type = resources.get(logical_id, {}).get("Type") + if rest_api_resource_type != CfnApiProvider.APIGATEWAY_RESTAPI: + raise InvalidSamTemplateException( + "The AWS::ApiGateway::Stage must have a valid RestApiId that points to RestApi resource {}".format( + logical_id)) + + collector.stage_name = stage_name + collector.stage_variables = stage_variables diff --git a/samcli/commands/local/lib/cfn_base_api_provider.py b/samcli/commands/local/lib/cfn_base_api_provider.py index 79bc6d8f1d..8d0d4c3774 100644 --- a/samcli/commands/local/lib/cfn_base_api_provider.py +++ b/samcli/commands/local/lib/cfn_base_api_provider.py @@ -1,9 +1,8 @@ """Class that parses the CloudFormation Api Template""" - import logging from samcli.commands.local.lib.swagger.parser import SwaggerParser -from samcli.commands.local.lib.swagger.reader import SamSwaggerReader +from samcli.commands.local.lib.swagger.reader import SwaggerReader LOG = logging.getLogger(__name__) @@ -11,16 +10,16 @@ class CfnBaseApiProvider(object): RESOURCE_TYPE = "Type" - def extract_resource_api(self, resources, collector, cwd=None): + def extract_resources(self, resources, collector, cwd=None): """ - Extract the Api Object from a given resource and adds it to the ApiCollector. + Extract the Route Object from a given resource and adds it to the RouteCollector. Parameters ---------- resources: dict The dictionary containing the different resources within the template - collector: ApiCollector + collector: samcli.commands.local.lib.route_collector.RouteCollector Instance of the API collector that where we will save the API information cwd : str @@ -28,12 +27,11 @@ def extract_resource_api(self, resources, collector, cwd=None): Return ------- - Returns a list of Apis + Returns a list of routes """ raise NotImplementedError("not implemented") - @staticmethod - def extract_swagger_api(logical_id, body, uri, binary_media, collector, cwd=None): + def extract_swagger_route(self, logical_id, body, uri, binary_media, collector, cwd=None): """ Parse the Swagger documents and adds it to the ApiCollector. @@ -51,20 +49,21 @@ def extract_swagger_api(logical_id, body, uri, binary_media, collector, cwd=None binary_media: list The link to the binary media - collector: ApiCollector - Instance of the API collector that where we will save the API information + collector: samcli.commands.local.lib.route_collector.RouteCollector + Instance of the Route collector that where we will save the route information cwd : str Optional working directory with respect to which we will resolve relative path to Swagger file """ - reader = SamSwaggerReader(definition_body=body, - definition_uri=uri, - working_dir=cwd) + reader = SwaggerReader(definition_body=body, + definition_uri=uri, + working_dir=cwd) swagger = reader.read() parser = SwaggerParser(swagger) - apis = parser.get_apis() - LOG.debug("Found '%s' APIs in resource '%s'", len(apis), logical_id) + routes = parser.get_routes() + LOG.debug("Found '%s' APIs in resource '%s'", len(routes), logical_id) + + collector.add_routes(logical_id, routes) - collector.add_apis(logical_id, apis) collector.add_binary_media_types(logical_id, parser.get_binary_media_types()) # Binary media from swagger collector.add_binary_media_types(logical_id, binary_media) # Binary media specified on resource in template diff --git a/samcli/commands/local/lib/local_api_service.py b/samcli/commands/local/lib/local_api_service.py index d456e67a83..441d6c3cbc 100644 --- a/samcli/commands/local/lib/local_api_service.py +++ b/samcli/commands/local/lib/local_api_service.py @@ -2,20 +2,20 @@ Connects the CLI with Local API Gateway service. """ -import os import logging +import os -from samcli.local.apigw.local_apigw_service import LocalApigwService, Route -from samcli.commands.local.lib.api_provider import ApiProvider from samcli.commands.local.lib.exceptions import NoApisDefined +from samcli.local.apigw.local_apigw_service import LocalApigwService +from samcli.commands.local.lib.api_provider import ApiProvider LOG = logging.getLogger(__name__) class LocalApiService(object): """ - Implementation of Local API service that is capable of serving APIs defined in a SAM file that invoke a Lambda - function. + Implementation of Local API service that is capable of serving API defined in a configuration file that invoke a + Lambda function. """ def __init__(self, @@ -53,10 +53,8 @@ def start(self): NOTE: This is a blocking call that will not return until the thread is interrupted with SIGINT/SIGTERM """ - routing_list = self._make_routing_list(self.api_provider) - - if not routing_list: - raise NoApisDefined("No APIs available in SAM template") + if not self.api_provider.api.routes: + raise NoApisDefined("No APIs available in template") static_dir_path = self._make_static_dir_path(self.cwd, self.static_dir) @@ -64,7 +62,7 @@ def start(self): # contains the response to the API which is sent out as HTTP response. Only stderr needs to be printed # to the console or a log file. stderr from Docker container contains runtime logs and output of print # statements from the Lambda function - service = LocalApigwService(routing_list=routing_list, + service = LocalApigwService(api=self.api_provider.api, lambda_runner=self.lambda_runner, static_dir=static_dir_path, port=self.port, @@ -74,7 +72,7 @@ def start(self): service.create() # Print out the list of routes that will be mounted - self._print_routes(self.api_provider, self.host, self.port) + self._print_routes(self.api_provider.api.routes, self.host, self.port) LOG.info("You can now browse to the above endpoints to invoke your functions. " "You do not need to restart/reload SAM CLI while working on your functions, " "changes will be reflected instantly/automatically. You only need to restart " @@ -83,30 +81,7 @@ def start(self): service.run() @staticmethod - def _make_routing_list(api_provider): - """ - Returns a list of routes to configure the Local API Service based on the APIs configured in the template. - - Parameters - ---------- - api_provider : samcli.commands.local.lib.api_provider.ApiProvider - - Returns - ------- - list(samcli.local.apigw.service.Route) - List of Routes to pass to the service - """ - - routes = [] - for api in api_provider.get_all(): - route = Route(methods=[api.method], function_name=api.function_name, path=api.path, - binary_types=api.binary_media_types, stage_name=api.stage_name, - stage_variables=api.stage_variables) - routes.append(route) - return routes - - @staticmethod - def _print_routes(api_provider, host, port): + def _print_routes(routes, host, port): """ Helper method to print the APIs that will be mounted. This method is purely for printing purposes. This method takes in a list of Route Configurations and prints out the Routes grouped by path. @@ -116,8 +91,8 @@ def _print_routes(api_provider, host, port): Mounting Product at http://127.0.0.1:3000/path1/bar [GET, POST, DELETE] Mounting Product at http://127.0.0.1:3000/path2/bar [HEAD] - :param samcli.commands.local.lib.provider.AbstractApiProvider api_provider: - API Provider that can return a list of APIs + :param list(Route) routes: + List of routes grouped by the same function_name and path :param string host: Host name where the service is running :param int port: @@ -125,28 +100,15 @@ def _print_routes(api_provider, host, port): :returns list(string): List of lines that were printed to the console. Helps with testing """ - grouped_api_configs = {} - - for api in api_provider.get_all(): - key = "{}-{}".format(api.function_name, api.path) - - config = grouped_api_configs.get(key, {}) - config.setdefault("methods", []) - - config["function_name"] = api.function_name - config["path"] = api.path - config["methods"].append(api.method) - - grouped_api_configs[key] = config print_lines = [] - for _, config in grouped_api_configs.items(): - methods_str = "[{}]".format(', '.join(config["methods"])) + for route in routes: + methods_str = "[{}]".format(', '.join(route.methods)) output = "Mounting {} at http://{}:{}{} {}".format( - config["function_name"], + route.function_name, host, port, - config["path"], + route.path, methods_str) print_lines.append(output) diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index 959166e814..94789ba799 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -199,40 +199,32 @@ def get_all(self): raise NotImplementedError("not implemented") -_ApiTuple = namedtuple("Api", [ +class Api(object): + def __init__(self, routes=None): + if routes is None: + routes = [] + self.routes = routes - # String. Path that this API serves. Ex: /foo, /bar/baz - "path", + # Optional Dictionary containing CORS configuration on this path+method If this configuration is set, + # then API server will automatically respond to OPTIONS HTTP method on this path and respond with appropriate + # CORS headers based on configuration. - # String. HTTP Method this API responds with - "method", + self.cors = None + # If this configuration is set, then API server will automatically respond to OPTIONS HTTP method on this + # path and - # String. Name of the Function this API connects to - "function_name", + self.binary_media_types_set = set() - # Optional Dictionary containing CORS configuration on this path+method - # If this configuration is set, then API server will automatically respond to OPTIONS HTTP method on this path and - # respond with appropriate CORS headers based on configuration. - "cors", + self.stage_name = None + self.stage_variables = None - # List(Str). List of the binary media types the API - "binary_media_types", - # The Api stage name - "stage_name", - # The variables for that stage - "stage_variables" -]) -_ApiTuple.__new__.__defaults__ = (None, # Cors is optional and defaults to None - [], # binary_media_types is optional and defaults to empty, - None, # Stage name is optional with default None - None # Stage variables is optional with default None - ) - - -class Api(_ApiTuple): def __hash__(self): # Other properties are not a part of the hash - return hash(self.path) * hash(self.method) * hash(self.function_name) + return hash(self.routes) * hash(self.cors) * hash(self.binary_media_types_set) + + @property + def binary_media_types(self): + return list(self.binary_media_types_set) Cors = namedtuple("Cors", ["AllowOrigin", "AllowMethods", "AllowHeaders"]) @@ -242,13 +234,6 @@ class AbstractApiProvider(object): """ Abstract base class to return APIs and the functions they route to """ - _ANY_HTTP_METHODS = ["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"] def get_all(self): """ @@ -257,43 +242,3 @@ def get_all(self): :yields Api: namedtuple containing the API information """ raise NotImplementedError("not implemented") - - @staticmethod - def normalize_http_methods(http_method): - """ - Normalizes Http Methods. Api Gateway allows a Http Methods of ANY. This is a special verb to denote all - supported Http Methods on Api Gateway. - - :param str http_method: Http method - :yield str: Either the input http_method or one of the _ANY_HTTP_METHODS (normalized Http Methods) - """ - - if http_method.upper() == 'ANY': - for method in AbstractApiProvider._ANY_HTTP_METHODS: - yield method.upper() - else: - yield http_method.upper() - - @staticmethod - def normalize_apis(apis): - """ - Normalize the APIs to use standard method name - - Parameters - ---------- - apis : list of samcli.commands.local.lib.provider.Api - List of APIs to replace normalize - - Returns - ------- - list of samcli.commands.local.lib.provider.Api - List of normalized APIs - """ - - result = list() - for api in apis: - for normalized_method in AbstractApiProvider.normalize_http_methods(api.method): - # _replace returns a copy of the namedtuple. This is the official way of creating copies of namedtuple - result.append(api._replace(method=normalized_method)) - - return result diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index f0ec57b823..1710edbf2d 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -2,9 +2,9 @@ import logging -from samcli.commands.local.lib.provider import Api, AbstractApiProvider -from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider +from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException +from samcli.local.apigw.local_apigw_service import Route LOG = logging.getLogger(__name__) @@ -23,24 +23,21 @@ class SamApiProvider(CfnBaseApiProvider): _EVENT_TYPE = "Type" IMPLICIT_API_RESOURCE_ID = "ServerlessRestApi" - def extract_resource_api(self, resources, collector, cwd=None): + def extract_resources(self, resources, collector, cwd=None): """ - Extract the Api Object from a given resource and adds it to the ApiCollector. + Extract the Route Object from a given resource and adds it to the RouteCollector. Parameters ---------- resources: dict The dictionary containing the different resources within the template - collector: ApiCollector + collector: samcli.commands.local.lib.route_collector.ApiCollector Instance of the API collector that where we will save the API information cwd : str Optional working directory with respect to which we will resolve relative path to Swagger file - Return - ------- - Returns a list of Apis """ # AWS::Serverless::Function is currently included when parsing of Apis because when SamBaseProvider is run on # the template we are creating the implicit apis due to plugins that translate it in the SAM repo, @@ -49,10 +46,11 @@ def extract_resource_api(self, resources, collector, cwd=None): for logical_id, resource in resources.items(): resource_type = resource.get(CfnBaseApiProvider.RESOURCE_TYPE) if resource_type == SamApiProvider.SERVERLESS_FUNCTION: - self._extract_apis_from_function(logical_id, resource, collector) + self._extract_routes_from_function(logical_id, resource, collector) if resource_type == SamApiProvider.SERVERLESS_API: - self._extract_from_serverless_api(logical_id, resource, collector, cwd) - return self.merge_apis(collector) + self._extract_from_serverless_api(logical_id, resource, collector, cwd=cwd) + + collector.routes = self.merge_routes(collector) def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd=None): """ @@ -67,8 +65,12 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd= api_resource : dict Resource definition, including its properties - collector : ApiCollector + collector: samcli.commands.local.lib.route_collector.RouteCollector Instance of the API collector that where we will save the API information + + cwd : str + Optional working directory with respect to which we will resolve relative path to Swagger file + """ properties = api_resource.get("Properties", {}) @@ -83,13 +85,13 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd= LOG.debug("Skipping resource '%s'. Swagger document not found in DefinitionBody and DefinitionUri", logical_id) return - self.extract_swagger_api(logical_id, body, uri, binary_media, collector, cwd) - collector.add_stage_name(logical_id, stage_name) - collector.add_stage_variables(logical_id, stage_variables) + self.extract_swagger_route(logical_id, body, uri, binary_media, collector, cwd=cwd) + collector.stage_name = stage_name + collector.stage_variables = stage_variables - def _extract_apis_from_function(self, logical_id, function_resource, collector): + def _extract_routes_from_function(self, logical_id, function_resource, collector): """ - Fetches a list of APIs configured for this SAM Function resource. + Fetches a list of routes configured for this SAM Function resource. Parameters ---------- @@ -99,17 +101,17 @@ def _extract_apis_from_function(self, logical_id, function_resource, collector): function_resource : dict Contents of the function resource including its properties - collector : ApiCollector + collector: samcli.commands.local.lib.route_collector.RouteCollector Instance of the API collector that where we will save the API information """ resource_properties = function_resource.get("Properties", {}) serverless_function_events = resource_properties.get(self._FUNCTION_EVENT, {}) - self.extract_apis_from_events(logical_id, serverless_function_events, collector) + self.extract_routes_from_events(logical_id, serverless_function_events, collector) - def extract_apis_from_events(self, function_logical_id, serverless_function_events, collector): + def extract_routes_from_events(self, function_logical_id, serverless_function_events, collector): """ - Given an AWS::Serverless::Function Event Dictionary, extract out all 'Api' events and store within the + Given an AWS::Serverless::Function Event Dictionary, extract out all 'route' events and store within the collector Parameters @@ -120,27 +122,27 @@ def extract_apis_from_events(self, function_logical_id, serverless_function_even serverless_function_events : dict Event Dictionary of a AWS::Serverless::Function - collector : ApiCollector - Instance of the API collector that where we will save the API information + collector: samcli.commands.local.lib.route_collector.RouteCollector + Instance of the Route collector that where we will save the route information """ count = 0 for _, event in serverless_function_events.items(): if self._FUNCTION_EVENT_TYPE_API == event.get(self._EVENT_TYPE): - api_resource_id, api = self._convert_event_api(function_logical_id, event.get("Properties")) - collector.add_apis(api_resource_id, [api]) + route_resource_id, route = self._convert_event_route(function_logical_id, event.get("Properties")) + collector.add_routes(route_resource_id, [route]) count += 1 LOG.debug("Found '%d' API Events in Serverless function with name '%s'", count, function_logical_id) @staticmethod - def _convert_event_api(lambda_logical_id, event_properties): + def _convert_event_route(lambda_logical_id, event_properties): """ - Converts a AWS::Serverless::Function's Event Property to an Api configuration usable by the provider. + Converts a AWS::Serverless::Function's Event Property to an Route configuration usable by the provider. :param str lambda_logical_id: Logical Id of the AWS::Serverless::Function :param dict event_properties: Dictionary of the Event's Property - :return tuple: tuple of API resource name and Api namedTuple + :return tuple: tuple of route resource name and route """ path = event_properties.get(SamApiProvider._EVENT_PATH) method = event_properties.get(SamApiProvider._EVENT_METHOD) @@ -159,55 +161,54 @@ def _convert_event_api(lambda_logical_id, event_properties): "It should either be a LogicalId string or a Ref of a Logical Id string" .format(lambda_logical_id)) - return api_resource_id, Api(path=path, method=method, function_name=lambda_logical_id) + return api_resource_id, Route(path=path, methods=[method], function_name=lambda_logical_id) @staticmethod - def merge_apis(collector): + def merge_routes(collector): """ - Quite often, an API is defined both in Implicit and Explicit API definitions. In such cases, Implicit API + Quite often, an API is defined both in Implicit and Explicit Route definitions. In such cases, Implicit API definition wins because that conveys clear intent that the API is backed by a function. This method will - merge two such list of Apis with the right order of precedence. If a Path+Method combination is defined + merge two such list of routes with the right order of precedence. If a Path+Method combination is defined in both the places, only one wins. Parameters ---------- - collector : ApiCollector + collector: samcli.commands.local.lib.route_collector.RouteCollector Collector object that holds all the APIs specified in the template Returns ------- - list of samcli.commands.local.lib.provider.Api - List of APIs obtained by combining both the input lists. + list of samcli.local.apigw.local_apigw_service.Route + List of routes obtained by combining both the input lists. """ - implicit_apis = [] - explicit_apis = [] + implicit_routes = [] + explicit_routes = [] # Store implicit and explicit APIs separately in order to merge them later in the correct order # Implicit APIs are defined on a resource with logicalID ServerlessRestApi for logical_id, apis in collector: if logical_id == SamApiProvider.IMPLICIT_API_RESOURCE_ID: - implicit_apis.extend(apis) + implicit_routes.extend(apis) else: - explicit_apis.extend(apis) + explicit_routes.extend(apis) # We will use "path+method" combination as key to this dictionary and store the Api config for this combination. # If an path+method combo already exists, then overwrite it if and only if this is an implicit API - all_apis = {} + all_routes = {} # By adding implicit APIs to the end of the list, they will be iterated last. If a configuration was already # written by explicit API, it will be overriden by implicit API, just by virtue of order of iteration. - all_configs = explicit_apis + implicit_apis + all_configs = explicit_routes + implicit_routes for config in all_configs: # Normalize the methods before de-duping to allow an ANY method in implicit API to override a regular HTTP - # method on explicit API. - for normalized_method in AbstractApiProvider.normalize_http_methods(config.method): + # method on explicit route. + for normalized_method in config.methods: key = config.path + normalized_method - all_apis[key] = config + all_routes[key] = config - result = set(all_apis.values()) # Assign to a set() to de-dupe + result = set(all_routes.values()) # Assign to a set() to de-dupe LOG.debug("Removed duplicates from '%d' Explicit APIs and '%d' Implicit APIs to produce '%d' APIs", - len(explicit_apis), len(implicit_apis), len(result)) - + len(explicit_routes), len(implicit_routes), len(result)) return list(result) diff --git a/samcli/commands/local/lib/swagger/parser.py b/samcli/commands/local/lib/swagger/parser.py index 076161993c..072e71c378 100644 --- a/samcli/commands/local/lib/swagger/parser.py +++ b/samcli/commands/local/lib/swagger/parser.py @@ -2,8 +2,8 @@ import logging -from samcli.commands.local.lib.provider import Api from samcli.commands.local.lib.swagger.integration_uri import LambdaUri, IntegrationType +from samcli.local.apigw.local_apigw_service import Route LOG = logging.getLogger(__name__) @@ -34,7 +34,7 @@ def get_binary_media_types(self): """ return self.swagger.get(self._BINARY_MEDIA_TYPES_EXTENSION_KEY) or [] - def get_apis(self): + def get_routes(self): """ Parses a swagger document and returns a list of APIs configured in the document. @@ -62,15 +62,13 @@ def get_apis(self): Returns ------- - list of samcli.commands.local.lib.provider.Api + list of list of samcli.commands.local.apigw.local_apigw_service.Route List of APIs that are configured in the Swagger document """ result = [] paths_dict = self.swagger.get("paths", {}) - binary_media_types = self.get_binary_media_types() - for full_path, path_config in paths_dict.items(): for method, method_config in path_config.items(): @@ -83,11 +81,8 @@ def get_apis(self): if method.lower() == self._ANY_METHOD_EXTENSION_KEY: # Convert to a more commonly used method notation method = self._ANY_METHOD - - api = Api(path=full_path, method=method, function_name=function_name, cors=None, - binary_media_types=binary_media_types) - result.append(api) - + route = Route(function_name, full_path, methods=[method]) + result.append(route) return result def _get_integration_function_name(self, method_config): diff --git a/samcli/commands/local/lib/swagger/reader.py b/samcli/commands/local/lib/swagger/reader.py index d3235170c6..02c2c1edb7 100644 --- a/samcli/commands/local/lib/swagger/reader.py +++ b/samcli/commands/local/lib/swagger/reader.py @@ -57,7 +57,7 @@ def parse_aws_include_transform(data): return location -class SamSwaggerReader(object): +class SwaggerReader(object): """ Class to read and parse Swagger document from a variety of sources. This class accepts the same data formats as available in Serverless::Api SAM resource diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index 7aef82dc06..c18304064a 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -18,35 +18,63 @@ class Route(object): - - def __init__(self, methods, function_name, path, binary_types=None, stage_name=None, stage_variables=None): + _ANY_HTTP_METHODS = ["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"] + + def __init__(self, function_name, path, methods): """ Creates an ApiGatewayRoute - :param list(str) methods: List of HTTP Methods + :param list(str) methods: http method :param function_name: Name of the Lambda function this API is connected to :param str path: Path off the base url """ - self.methods = methods + self.methods = self.normalize_method(methods) self.function_name = function_name self.path = path - self.binary_types = binary_types or [] - self.stage_name = stage_name - self.stage_variables = stage_variables + + def __eq__(self, other): + return isinstance(other, Route) and \ + sorted(self.methods) == sorted( + other.methods) and self.function_name == other.function_name and self.path == other.path + + def __hash__(self): + route_hash = hash(self.function_name) * hash(self.path) + for method in sorted(self.methods): + route_hash *= hash(method) + return route_hash + + def normalize_method(self, methods): + """ + Normalizes Http Methods. Api Gateway allows a Http Methods of ANY. This is a special verb to denote all + supported Http Methods on Api Gateway. + + :param list methods: Http methods + :return list: Either the input http_method or one of the _ANY_HTTP_METHODS (normalized Http Methods) + """ + methods = [method.upper() for method in methods] + if "ANY" in methods: + return self._ANY_HTTP_METHODS + return methods class LocalApigwService(BaseLocalService): _DEFAULT_PORT = 3000 _DEFAULT_HOST = '127.0.0.1' - def __init__(self, routing_list, lambda_runner, static_dir=None, port=None, host=None, stderr=None): + def __init__(self, api, lambda_runner, static_dir=None, port=None, host=None, stderr=None): """ Creates an ApiGatewayService Parameters ---------- - routing_list list(ApiGatewayCallModel) - A list of the Model that represent the service paths to create. + api: Api + an Api object that contains the list of routes and properties lambda_runner samcli.commands.local.lib.local_lambda.LocalLambdaRunner The Lambda runner class capable of invoking the function static_dir str @@ -61,7 +89,7 @@ def __init__(self, routing_list, lambda_runner, static_dir=None, port=None, host Optional stream writer where the stderr from Docker container should be written to """ super(LocalApigwService, self).__init__(lambda_runner.is_debugging(), port=port, host=host) - self.routing_list = routing_list + self.api = api self.lambda_runner = lambda_runner self.static_dir = static_dir self._dict_of_routes = {} @@ -77,12 +105,11 @@ def create(self): static_folder=self.static_dir # Serve static files from this directory ) - for api_gateway_route in self.routing_list: + for api_gateway_route in self.api.routes: path = PathConverter.convert_path_to_flask(api_gateway_route.path) for route_key in self._generate_route_keys(api_gateway_route.methods, path): self._dict_of_routes[route_key] = api_gateway_route - self._app.add_url_rule(path, endpoint=path, view_func=self._request_handler, @@ -144,8 +171,8 @@ def _request_handler(self, **kwargs): route = self._get_current_route(request) try: - event = self._construct_event(request, self.port, route.binary_types, route.stage_name, - route.stage_variables) + event = self._construct_event(request, self.port, self.api.binary_media_types, self.api.stage_name, + self.api.stage_variables) except UnicodeDecodeError: return ServiceErrorResponses.lambda_failure_response() @@ -165,7 +192,7 @@ def _request_handler(self, **kwargs): try: (status_code, headers, body) = self._parse_lambda_output(lambda_response, - route.binary_types, + self.api.binary_media_types, request) except (KeyError, TypeError, ValueError): LOG.error("Function returned an invalid response (must include one of: body, headers, multiValueHeaders or " diff --git a/tests/functional/commands/local/lib/test_local_api_service.py b/tests/functional/commands/local/lib/test_local_api_service.py index a507304bae..23df3e9025 100644 --- a/tests/functional/commands/local/lib/test_local_api_service.py +++ b/tests/functional/commands/local/lib/test_local_api_service.py @@ -10,6 +10,8 @@ import time import logging +from samcli.commands.local.lib.provider import Api +from samcli.local.apigw.local_apigw_service import Route from samcli.commands.local.lib import provider from samcli.commands.local.lib.local_lambda import LocalLambdaRunner from samcli.local.lambdafn.runtime import LambdaRuntime @@ -42,7 +44,7 @@ def setUp(self): self.static_dir = "mystaticdir" self.static_file_name = "myfile.txt" self.static_file_content = "This is a static file" - self._setup_static_file(os.path.join(self.cwd, self.static_dir), # Create static directory with in cwd + self._setup_static_file(os.path.join(self.cwd, self.static_dir), # Create static directory with in cwd self.static_file_name, self.static_file_content) @@ -56,12 +58,14 @@ def setUp(self): self.mock_function_provider.get.return_value = self.function # Setup two APIs pointing to the same function - apis = [ - provider.Api(path="/get", method="GET", function_name=self.function_name, cors="cors"), - provider.Api(path="/post", method="POST", function_name=self.function_name, cors="cors"), + routes = [ + Route(path="/get", methods=["GET"], function_name=self.function_name), + Route(path="/post", methods=["POST"], function_name=self.function_name), ] + api = Api(routes=routes) + self.api_provider_mock = Mock() - self.api_provider_mock.get_all.return_value = apis + self.api_provider_mock.get_all.return_value = api # Now wire up the Lambda invoker and pass it through the context self.lambda_invoke_context_mock = Mock() @@ -69,7 +73,9 @@ def setUp(self): layer_downloader = LayerDownloader("./", "./") lambda_image = LambdaImage(layer_downloader, False, False) local_runtime = LambdaRuntime(manager, lambda_image) - lambda_runner = LocalLambdaRunner(local_runtime, self.mock_function_provider, self.cwd, env_vars_values=None, + lambda_runner = LocalLambdaRunner(local_runtime, + self.mock_function_provider, + self.cwd, debug_context=None) self.lambda_invoke_context_mock.local_lambda_runner = lambda_runner self.lambda_invoke_context_mock.get_cwd.return_value = self.cwd @@ -77,7 +83,7 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.code_abs_path) - @patch("samcli.commands.local.lib.local_api_service.SamApiProvider") + @patch("samcli.commands.local.lib.sam_api_provider.SamApiProvider") def test_must_start_service_and_serve_endpoints(self, sam_api_provider_mock): sam_api_provider_mock.return_value = self.api_provider_mock @@ -97,7 +103,7 @@ def test_must_start_service_and_serve_endpoints(self, sam_api_provider_mock): response = requests.get(self.url + '/post') self.assertEquals(response.status_code, 403) # "HTTP GET /post" must not exist - @patch("samcli.commands.local.lib.local_api_service.SamApiProvider") + @patch("samcli.commands.local.lib.sam_api_provider.SamApiProvider") def test_must_serve_static_files(self, sam_api_provider_mock): sam_api_provider_mock.return_value = self.api_provider_mock @@ -123,10 +129,8 @@ def _start_service_thread(service): @staticmethod def _setup_static_file(directory, filename, contents): - if not os.path.isdir(directory): os.mkdir(directory) with open(os.path.join(directory, filename), "w") as fp: fp.write(contents) - diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index 321741e0bf..84fbf2da79 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -656,3 +656,31 @@ def test_swagger_stage_variable(self): response_data = response.json() self.assertEquals(response_data.get("stageVariables"), {'VarName': 'varValue'}) + + +class TestStartApiWithCloudFormationStage(StartApiIntegBaseClass): + """ + Test Class centered around the different responses that can happen in Lambda and pass through start-api + """ + template_path = "/testdata/start_api/swagger-rest-api-template.yaml" + + def setUp(self): + self.url = "http://127.0.0.1:{}".format(self.port) + + def test_default_stage_name(self): + response = requests.get(self.url + "/echoeventbody") + + self.assertEquals(response.status_code, 200) + + response_data = response.json() + print(response_data) + self.assertEquals(response_data.get("requestContext", {}).get("stage"), "Dev") + + def test_global_stage_variables(self): + response = requests.get(self.url + "/echoeventbody") + + self.assertEquals(response.status_code, 200) + + response_data = response.json() + + self.assertEquals(response_data.get("stageVariables"), {"Stack": "Dev"}) diff --git a/tests/integration/testdata/start_api/swagger-rest-api-template.yaml b/tests/integration/testdata/start_api/swagger-rest-api-template.yaml index 5edeb8717f..5e7be3a95e 100644 --- a/tests/integration/testdata/start_api/swagger-rest-api-template.yaml +++ b/tests/integration/testdata/start_api/swagger-rest-api-template.yaml @@ -13,6 +13,12 @@ Resources: Handler: main.echo_base64_event_body Runtime: python3.6 Type: AWS::Lambda::Function + EchoEventBodyFunction: + Properties: + Code: "." + Handler: main.echo_event_handler + Runtime: python3.6 + Type: AWS::Lambda::Function MyApi: Properties: Body: @@ -35,6 +41,13 @@ Resources: type: aws_proxy uri: Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${Base64ResponseFunction.Arn}/invocations + "/echoeventbody": + get: + x-amazon-apigateway-integration: + httpMethod: POST + type: aws_proxy + uri: + Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${EchoEventBodyFunction.Arn}/invocations "/echobase64eventbody": post: x-amazon-apigateway-integration: @@ -61,6 +74,13 @@ Resources: - image/gif StageName: prod Type: AWS::ApiGateway::RestApi + Dev: + Type: AWS::ApiGateway::Stage + Properties: + StageName: Dev + RestApiId: MyApi + Variables: + Stack: Dev MyNonServerlessLambdaFunction: Properties: Code: "." diff --git a/tests/unit/commands/local/lib/swagger/test_parser.py b/tests/unit/commands/local/lib/swagger/test_parser.py index 59db1ea969..827f49be1c 100644 --- a/tests/unit/commands/local/lib/swagger/test_parser.py +++ b/tests/unit/commands/local/lib/swagger/test_parser.py @@ -1,14 +1,14 @@ """ Test the swagger parser """ - -from samcli.commands.local.lib.swagger.parser import SwaggerParser -from samcli.commands.local.lib.provider import Api - from unittest import TestCase + from mock import patch, Mock from parameterized import parameterized, param +from samcli.commands.local.lib.swagger.parser import SwaggerParser +from samcli.local.apigw.local_apigw_service import Route + class TestSwaggerParser_get_apis(TestCase): @@ -31,8 +31,8 @@ def test_with_one_path_method(self): parser._get_integration_function_name = Mock() parser._get_integration_function_name.return_value = function_name - expected = [Api(path="/path1", method="get", function_name=function_name, cors=None)] - result = parser.get_apis() + expected = [Route(path="/path1", methods=["get"], function_name=function_name)] + result = parser.get_routes() self.assertEquals(expected, result) parser._get_integration_function_name.assert_called_with({ @@ -77,11 +77,11 @@ def test_with_combination_of_paths_methods(self): parser._get_integration_function_name.return_value = function_name expected = { - Api(path="/path1", method="get", function_name=function_name, cors=None), - Api(path="/path1", method="delete", function_name=function_name, cors=None), - Api(path="/path2", method="post", function_name=function_name, cors=None), + Route(path="/path1", methods=["get"], function_name=function_name), + Route(path="/path1", methods=["delete"], function_name=function_name), + Route(path="/path2", methods=["post"], function_name=function_name), } - result = parser.get_apis() + result = parser.get_routes() self.assertEquals(expected, set(result)) @@ -104,8 +104,9 @@ def test_with_any_method(self): parser._get_integration_function_name = Mock() parser._get_integration_function_name.return_value = function_name - expected = [Api(path="/path1", method="ANY", function_name=function_name, cors=None)] - result = parser.get_apis() + expected = [Route(methods=["ANY"], path="/path1", + function_name=function_name)] + result = parser.get_routes() self.assertEquals(expected, result) @@ -128,7 +129,7 @@ def test_does_not_have_function_name(self): parser._get_integration_function_name.return_value = None # Function Name could not be resolved expected = [] - result = parser.get_apis() + result = parser.get_routes() self.assertEquals(expected, result) @@ -146,9 +147,8 @@ def test_does_not_have_function_name(self): }}) ]) def test_invalid_swagger(self, test_case_name, swagger): - parser = SwaggerParser(swagger) - result = parser.get_apis() + result = parser.get_routes() expected = [] self.assertEquals(expected, result) diff --git a/tests/unit/commands/local/lib/swagger/test_reader.py b/tests/unit/commands/local/lib/swagger/test_reader.py index 8112b2f21c..9ecb4d276d 100644 --- a/tests/unit/commands/local/lib/swagger/test_reader.py +++ b/tests/unit/commands/local/lib/swagger/test_reader.py @@ -8,7 +8,7 @@ from parameterized import parameterized, param from mock import Mock, patch -from samcli.commands.local.lib.swagger.reader import parse_aws_include_transform, SamSwaggerReader +from samcli.commands.local.lib.swagger.reader import parse_aws_include_transform, SwaggerReader class TestParseAwsIncludeTransform(TestCase): @@ -57,7 +57,7 @@ class TestSamSwaggerReader_init(TestCase): def test_definition_body_and_uri_required(self): with self.assertRaises(ValueError): - SamSwaggerReader() + SwaggerReader() class TestSamSwaggerReader_read(TestCase): @@ -67,7 +67,7 @@ def test_must_read_first_from_definition_body(self): uri = "./file.txt" expected = {"some": "value"} - reader = SamSwaggerReader(definition_body=body, definition_uri=uri) + reader = SwaggerReader(definition_body=body, definition_uri=uri) reader._download_swagger = Mock() reader._read_from_definition_body = Mock() reader._read_from_definition_body.return_value = expected @@ -82,7 +82,7 @@ def test_read_from_definition_uri(self): uri = "./file.txt" expected = {"some": "value"} - reader = SamSwaggerReader(definition_uri=uri) + reader = SwaggerReader(definition_uri=uri) reader._download_swagger = Mock() reader._download_swagger.return_value = expected @@ -96,7 +96,7 @@ def test_must_use_definition_uri_if_body_does_not_exist(self): uri = "./file.txt" expected = {"some": "value"} - reader = SamSwaggerReader(definition_body=body, definition_uri=uri) + reader = SwaggerReader(definition_body=body, definition_uri=uri) reader._download_swagger = Mock() reader._download_swagger.return_value = expected @@ -119,7 +119,7 @@ def test_must_work_with_include_transform(self, parse_mock): expected = {'k': 'v'} location = "some location" - reader = SamSwaggerReader(definition_body=body) + reader = SwaggerReader(definition_body=body) reader._download_swagger = Mock() reader._download_swagger.return_value = expected parse_mock.return_value = location @@ -132,7 +132,7 @@ def test_must_work_with_include_transform(self, parse_mock): def test_must_get_body_directly(self, parse_mock): body = {'this': 'swagger'} - reader = SamSwaggerReader(definition_body=body) + reader = SwaggerReader(definition_body=body) parse_mock.return_value = None # No location is returned from aws_include parser actual = reader._read_from_definition_body() @@ -151,7 +151,7 @@ def test_must_download_from_s3_for_s3_locations(self, yaml_parse_mock): swagger_str = "some swagger str" expected = "some data" - reader = SamSwaggerReader(definition_uri=location) + reader = SwaggerReader(definition_uri=location) reader._download_from_s3 = Mock() reader._download_from_s3.return_value = swagger_str yaml_parse_mock.return_value = expected @@ -169,7 +169,7 @@ def test_must_skip_non_s3_dictionaries(self, yaml_parse_mock): location = {"some": "value"} - reader = SamSwaggerReader(definition_uri=location) + reader = SwaggerReader(definition_uri=location) reader._download_from_s3 = Mock() actual = reader._download_swagger(location) @@ -193,7 +193,7 @@ def test_must_read_from_local_file(self, yaml_parse_mock): cwd = os.path.dirname(filepath) filename = os.path.basename(filepath) - reader = SamSwaggerReader(definition_uri=filename, working_dir=cwd) + reader = SwaggerReader(definition_uri=filename, working_dir=cwd) actual = reader._download_swagger(filename) self.assertEquals(actual, expected) @@ -211,7 +211,7 @@ def test_must_read_from_local_file_without_working_directory(self, yaml_parse_mo json.dump(data, fp) fp.flush() - reader = SamSwaggerReader(definition_uri=filepath) + reader = SwaggerReader(definition_uri=filepath) actual = reader._download_swagger(filepath) self.assertEquals(actual, expected) @@ -222,7 +222,7 @@ def test_must_return_none_if_file_not_found(self, yaml_parse_mock): expected = "parsed result" yaml_parse_mock.return_value = expected - reader = SamSwaggerReader(definition_uri="somepath") + reader = SwaggerReader(definition_uri="somepath") actual = reader._download_swagger("abcdefgh.txt") self.assertIsNone(actual) @@ -230,7 +230,7 @@ def test_must_return_none_if_file_not_found(self, yaml_parse_mock): def test_with_invalid_location(self): - reader = SamSwaggerReader(definition_uri="something") + reader = SwaggerReader(definition_uri="something") actual = reader._download_swagger({}) self.assertIsNone(actual) @@ -256,7 +256,7 @@ def test_must_download_file_from_s3(self, tempfilemock, botomock): expected = "data from file" fp_mock.read.return_value = expected - actual = SamSwaggerReader._download_from_s3(self.bucket, self.key, self.version) + actual = SwaggerReader._download_from_s3(self.bucket, self.key, self.version) self.assertEquals(actual, expected) s3_mock.download_fileobj.assert_called_with(self.bucket, self.key, fp_mock, @@ -277,7 +277,7 @@ def test_must_fail_on_download_from_s3(self, tempfilemock, botomock): "download_file") with self.assertRaises(Exception) as cm: - SamSwaggerReader._download_from_s3(self.bucket, self.key) + SwaggerReader._download_from_s3(self.bucket, self.key) self.assertIn(cm.exception.__class__, (botocore.exceptions.NoCredentialsError, botocore.exceptions.ClientError)) @@ -294,7 +294,7 @@ def test_must_work_without_object_version_id(self, tempfilemock, botomock): expected = "data from file" fp_mock.read.return_value = expected - actual = SamSwaggerReader._download_from_s3(self.bucket, self.key) + actual = SwaggerReader._download_from_s3(self.bucket, self.key) self.assertEquals(actual, expected) s3_mock.download_fileobj.assert_called_with(self.bucket, self.key, fp_mock, @@ -313,7 +313,7 @@ def test_must_log_on_download_exception(self, tempfilemock, botomock): "download_file") with self.assertRaises(botocore.exceptions.ClientError): - SamSwaggerReader._download_from_s3(self.bucket, self.key) + SwaggerReader._download_from_s3(self.bucket, self.key) fp_mock.read.assert_not_called() @@ -332,7 +332,7 @@ def test_must_parse_valid_dict(self): "Version": self.version } - result = SamSwaggerReader._parse_s3_location(location) + result = SwaggerReader._parse_s3_location(location) self.assertEquals(result, (self.bucket, self.key, self.version)) def test_must_parse_dict_without_version(self): @@ -341,19 +341,19 @@ def test_must_parse_dict_without_version(self): "Key": self.key } - result = SamSwaggerReader._parse_s3_location(location) + result = SwaggerReader._parse_s3_location(location) self.assertEquals(result, (self.bucket, self.key, None)) def test_must_parse_s3_uri_string(self): location = "s3://{}/{}?versionId={}".format(self.bucket, self.key, self.version) - result = SamSwaggerReader._parse_s3_location(location) + result = SwaggerReader._parse_s3_location(location) self.assertEquals(result, (self.bucket, self.key, self.version)) def test_must_parse_s3_uri_string_without_version_id(self): location = "s3://{}/{}".format(self.bucket, self.key) - result = SamSwaggerReader._parse_s3_location(location) + result = SwaggerReader._parse_s3_location(location) self.assertEquals(result, (self.bucket, self.key, None)) @parameterized.expand([ @@ -364,5 +364,5 @@ def test_must_parse_s3_uri_string_without_version_id(self): ]) def test_must_parse_invalid_location(self, location): - result = SamSwaggerReader._parse_s3_location(location) + result = SwaggerReader._parse_s3_location(location) self.assertEquals(result, (None, None, None)) diff --git a/tests/unit/commands/local/lib/test_api_provider.py b/tests/unit/commands/local/lib/test_api_provider.py index 50b8d073d4..013405429a 100644 --- a/tests/unit/commands/local/lib/test_api_provider.py +++ b/tests/unit/commands/local/lib/test_api_provider.py @@ -3,6 +3,7 @@ from mock import patch +from samcli.commands.local.lib.provider import Api from samcli.commands.local.lib.api_provider import ApiProvider from samcli.commands.local.lib.sam_api_provider import SamApiProvider from samcli.commands.local.lib.cfn_api_provider import CfnApiProvider @@ -10,18 +11,17 @@ class TestApiProvider_init(TestCase): - @patch.object(ApiProvider, "_extract_apis") + @patch.object(ApiProvider, "_extract_api") @patch("samcli.commands.local.lib.api_provider.SamBaseProvider") def test_provider_with_valid_template(self, SamBaseProviderMock, extract_api_mock): - extract_api_mock.return_value = {"set", "of", "values"} - + extract_api_mock.return_value = Api(routes={"set", "of", "values"}) template = {"Resources": {"a": "b"}} SamBaseProviderMock.get_template.return_value = template provider = ApiProvider(template) + self.assertEquals(len(provider.routes), 3) + self.assertEquals(provider.routes, set(["set", "of", "values"])) - self.assertEquals(len(provider.apis), 3) - self.assertEquals(provider.apis, set(["set", "of", "values"])) self.assertEquals(provider.template_dict, {"Resources": {"a": "b"}}) self.assertEquals(provider.resources, {"a": "b"}) diff --git a/tests/unit/commands/local/lib/test_cfn_api_provider.py b/tests/unit/commands/local/lib/test_cfn_api_provider.py index 723951eb11..d4f45171e5 100644 --- a/tests/unit/commands/local/lib/test_cfn_api_provider.py +++ b/tests/unit/commands/local/lib/test_cfn_api_provider.py @@ -1,27 +1,24 @@ import json import tempfile +from collections import OrderedDict from unittest import TestCase from mock import patch from six import assertCountEqual from samcli.commands.local.lib.api_provider import ApiProvider -from samcli.commands.local.lib.provider import Api +from samcli.local.apigw.local_apigw_service import Route from tests.unit.commands.local.lib.test_sam_api_provider import make_swagger -class TestApiProviderWithApiGatewayRestApi(TestCase): +class TestApiProviderWithApiGatewayRestRoute(TestCase): def setUp(self): self.binary_types = ["image/png", "image/jpg"] - self.input_apis = [ - Api(path="/path1", method="GET", function_name="SamFunc1", cors=None), - Api(path="/path1", method="POST", function_name="SamFunc1", cors=None), - - Api(path="/path2", method="PUT", function_name="SamFunc1", cors=None), - Api(path="/path2", method="GET", function_name="SamFunc1", cors=None), - - Api(path="/path3", method="DELETE", function_name="SamFunc1", cors=None) + self.input_routes = [ + Route(path="/path1", methods=["GET", "POST"], function_name="SamFunc1"), + Route(path="/path2", methods=["PUT", "GET"], function_name="SamFunc1"), + Route(path="/path3", methods=["DELETE"], function_name="SamFunc1") ] def test_with_no_apis(self): @@ -39,7 +36,7 @@ def test_with_no_apis(self): provider = ApiProvider(template) - self.assertEquals(provider.apis, []) + self.assertEquals(provider.routes, []) def test_with_inline_swagger_apis(self): template = { @@ -48,20 +45,21 @@ def test_with_inline_swagger_apis(self): "Api1": { "Type": "AWS::ApiGateway::RestApi", "Properties": { - "Body": make_swagger(self.input_apis) + "Body": make_swagger(self.input_routes) } } } } provider = ApiProvider(template) - assertCountEqual(self, self.input_apis, provider.apis) + assertCountEqual(self, self.input_routes, provider.routes) def test_with_swagger_as_local_file(self): with tempfile.NamedTemporaryFile(mode='w') as fp: filename = fp.name - swagger = make_swagger(self.input_apis) + swagger = make_swagger(self.input_routes) + json.dump(swagger, fp) fp.flush() @@ -78,13 +76,14 @@ def test_with_swagger_as_local_file(self): } provider = ApiProvider(template) - assertCountEqual(self, self.input_apis, provider.apis) + assertCountEqual(self, self.input_routes, provider.routes) def test_body_with_swagger_as_local_file_expect_fail(self): with tempfile.NamedTemporaryFile(mode='w') as fp: filename = fp.name - swagger = make_swagger(self.input_apis) + swagger = make_swagger(self.input_routes) + json.dump(swagger, fp) fp.flush() @@ -101,8 +100,8 @@ def test_body_with_swagger_as_local_file_expect_fail(self): } self.assertRaises(Exception, ApiProvider, template) - @patch("samcli.commands.local.lib.cfn_base_api_provider.SamSwaggerReader") - def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): + @patch("samcli.commands.local.lib.cfn_base_api_provider.SwaggerReader") + def test_with_swagger_as_both_body_and_uri_called(self, SwaggerReaderMock): body = {"some": "body"} filename = "somefile.txt" @@ -119,26 +118,26 @@ def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): } } - SamSwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_apis) + SwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_routes) cwd = "foo" provider = ApiProvider(template, cwd=cwd) - assertCountEqual(self, self.input_apis, provider.apis) - SamSwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) + assertCountEqual(self, self.input_routes, provider.routes) + SwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) def test_swagger_with_any_method(self): - apis = [ - Api(path="/path", method="any", function_name="SamFunc1", cors=None) + routes = [ + Route(path="/path", methods=["any"], function_name="SamFunc1") ] - expected_apis = [ - Api(path="/path", method="GET", function_name="SamFunc1", cors=None), - Api(path="/path", method="POST", function_name="SamFunc1", cors=None), - Api(path="/path", method="PUT", function_name="SamFunc1", cors=None), - Api(path="/path", method="DELETE", function_name="SamFunc1", cors=None), - Api(path="/path", method="HEAD", function_name="SamFunc1", cors=None), - Api(path="/path", method="OPTIONS", function_name="SamFunc1", cors=None), - Api(path="/path", method="PATCH", function_name="SamFunc1", cors=None) + expected_routes = [ + Route(path="/path", methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"], function_name="SamFunc1") ] template = { @@ -146,14 +145,14 @@ def test_swagger_with_any_method(self): "Api1": { "Type": "AWS::ApiGateway::RestApi", "Properties": { - "Body": make_swagger(apis) + "Body": make_swagger(routes) } } } } provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) def test_with_binary_media_types(self): template = { @@ -162,7 +161,7 @@ def test_with_binary_media_types(self): "Api1": { "Type": "AWS::ApiGateway::RestApi", "Properties": { - "Body": make_swagger(self.input_apis, binary_media_types=self.binary_types) + "Body": make_swagger(self.input_routes, binary_media_types=self.binary_types) } } } @@ -170,26 +169,19 @@ def test_with_binary_media_types(self): expected_binary_types = sorted(self.binary_types) expected_apis = [ - Api(path="/path1", method="GET", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types), - Api(path="/path1", method="POST", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types), - - Api(path="/path2", method="PUT", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types), - Api(path="/path2", method="GET", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types), - - Api(path="/path3", method="DELETE", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types) + Route(path="/path1", methods=["GET", "POST"], function_name="SamFunc1"), + Route(path="/path2", methods=["PUT", "GET"], function_name="SamFunc1"), + Route(path="/path3", methods=["DELETE"], function_name="SamFunc1") ] provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_apis, provider.routes) + assertCountEqual(self, provider.api.binary_media_types, expected_binary_types) def test_with_binary_media_types_in_swagger_and_on_resource(self): - input_apis = [ - Api(path="/path", method="OPTIONS", function_name="SamFunc1"), + input_routes = [ + Route(path="/path", methods=["OPTIONS"], function_name="SamFunc1"), + ] extra_binary_types = ["text/html"] @@ -200,16 +192,201 @@ def test_with_binary_media_types_in_swagger_and_on_resource(self): "Type": "AWS::ApiGateway::RestApi", "Properties": { "BinaryMediaTypes": extra_binary_types, - "Body": make_swagger(input_apis, binary_media_types=self.binary_types) + "Body": make_swagger(input_routes, binary_media_types=self.binary_types) } } } } expected_binary_types = sorted(self.binary_types + extra_binary_types) - expected_apis = [ - Api(path="/path", method="OPTIONS", function_name="SamFunc1", binary_media_types=expected_binary_types), + expected_routes = [ + Route(path="/path", methods=["OPTIONS"], function_name="SamFunc1"), ] provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) + assertCountEqual(self, provider.api.binary_media_types, expected_binary_types) + + +class TestCloudFormationStageValues(TestCase): + def setUp(self): + self.binary_types = ["image/png", "image/jpg"] + self.input_routes = [ + Route(path="/path1", methods=["GET", "POST"], function_name="SamFunc1"), + Route(path="/path2", methods=["PUT", "GET"], function_name="SamFunc1"), + Route(path="/path3", methods=["DELETE"], function_name="SamFunc1") + ] + + def test_provider_parse_stage_name(self): + template = { + "Resources": { + "Stage": { + "Type": "AWS::ApiGateway::Stage", + "Properties": { + "StageName": "dev", + "RestApiId": "TestApi" + } + }, + "TestApi": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "Body": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + } + } + provider = ApiProvider(template) + route1 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') + + self.assertIn(route1, provider.routes) + self.assertEquals(provider.api.stage_name, "dev") + self.assertEquals(provider.api.stage_variables, None) + + def test_provider_stage_variables(self): + template = { + "Resources": { + "Stage": { + "Type": "AWS::ApiGateway::Stage", + "Properties": { + "StageName": "dev", + "Variables": { + "vis": "data", + "random": "test", + "foo": "bar" + }, + "RestApiId": "TestApi" + } + }, + "TestApi": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "Body": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + } + } + provider = ApiProvider(template) + route1 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') + self.assertIn(route1, provider.routes) + self.assertEquals(provider.api.stage_name, "dev") + self.assertEquals(provider.api.stage_variables, { + "vis": "data", + "random": "test", + "foo": "bar" + }) + + def test_multi_stage_get_all(self): + resources = OrderedDict({ + "ProductionApi": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "Body": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + }, + "/anotherpath": { + "post": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + }) + resources["StageDev"] = { + "Type": "AWS::ApiGateway::Stage", + "Properties": { + "StageName": "dev", + "Variables": { + "vis": "data", + "random": "test", + "foo": "bar" + }, + "RestApiId": "ProductionApi" + } + } + resources["StageProd"] = { + "Type": "AWS::ApiGateway::Stage", + "Properties": { + "StageName": "Production", + "Variables": { + "vis": "prod data", + "random": "test", + "foo": "bar" + }, + "RestApiId": "ProductionApi" + }, + } + template = {"Resources": resources} + provider = ApiProvider(template) + + result = [f for f in provider.get_all()] + routes = result[0].routes + + route1 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') + route2 = Route(path='/anotherpath', methods=['POST'], function_name='NoApiEventFunction') + self.assertEquals(len(routes), 2) + self.assertIn(route1, routes) + self.assertIn(route2, routes) + + self.assertEquals(provider.api.stage_name, "Production") + self.assertEquals(provider.api.stage_variables, { + "vis": "prod data", + "random": "test", + "foo": "bar" + }) diff --git a/tests/unit/commands/local/lib/test_local_api_service.py b/tests/unit/commands/local/lib/test_local_api_service.py index cfa35af954..f43f93713e 100644 --- a/tests/unit/commands/local/lib/test_local_api_service.py +++ b/tests/unit/commands/local/lib/test_local_api_service.py @@ -6,10 +6,11 @@ from mock import Mock, patch -from samcli.commands.local.lib import provider +from samcli.commands.local.lib.provider import Api +from samcli.commands.local.lib.api_collector import ApiCollector +from samcli.commands.local.lib.api_provider import ApiProvider from samcli.commands.local.lib.exceptions import NoApisDefined from samcli.commands.local.lib.local_api_service import LocalApiService -from samcli.commands.local.lib.provider import Api from samcli.local.apigw.local_apigw_service import Route @@ -38,9 +39,7 @@ def setUp(self): @patch("samcli.commands.local.lib.local_api_service.ApiProvider") @patch.object(LocalApiService, "_make_static_dir_path") @patch.object(LocalApiService, "_print_routes") - @patch.object(LocalApiService, "_make_routing_list") def test_must_start_service(self, - make_routing_list_mock, log_routes_mock, make_static_dir_mock, SamApiProviderMock, @@ -48,7 +47,6 @@ def test_must_start_service(self, routing_list = [1, 2, 3] # something static_dir_path = "/foo/bar" - make_routing_list_mock.return_value = routing_list make_static_dir_mock.return_value = static_dir_path SamApiProviderMock.return_value = self.api_provider_mock @@ -56,6 +54,7 @@ def test_must_start_service(self, # Now start the service local_service = LocalApiService(self.lambda_invoke_context_mock, self.port, self.host, self.static_dir) + local_service.api_provider.api.routes = routing_list local_service.start() # Make sure the right methods are called @@ -63,10 +62,9 @@ def test_must_start_service(self, cwd=self.cwd, parameter_overrides=self.lambda_invoke_context_mock.parameter_overrides) - make_routing_list_mock.assert_called_with(self.api_provider_mock) - log_routes_mock.assert_called_with(self.api_provider_mock, self.host, self.port) + log_routes_mock.assert_called_with(routing_list, self.host, self.port) make_static_dir_mock.assert_called_with(self.cwd, self.static_dir) - ApiGwServiceMock.assert_called_with(routing_list=routing_list, + ApiGwServiceMock.assert_called_with(api=self.api_provider_mock.api, lambda_runner=self.lambda_runner_mock, static_dir=static_dir_path, port=self.port, @@ -80,72 +78,47 @@ def test_must_start_service(self, @patch("samcli.commands.local.lib.local_api_service.ApiProvider") @patch.object(LocalApiService, "_make_static_dir_path") @patch.object(LocalApiService, "_print_routes") - @patch.object(LocalApiService, "_make_routing_list") + @patch.object(ApiProvider, "_extract_api") def test_must_raise_if_route_not_available(self, - make_routing_list_mock, + extract_api, log_routes_mock, make_static_dir_mock, SamApiProviderMock, ApiGwServiceMock): routing_list = [] # Empty - - make_routing_list_mock.return_value = routing_list - + api = Api() + extract_api.return_value = api + SamApiProviderMock.extract_api.return_value = api SamApiProviderMock.return_value = self.api_provider_mock ApiGwServiceMock.return_value = self.apigw_service # Now start the service local_service = LocalApiService(self.lambda_invoke_context_mock, self.port, self.host, self.static_dir) - + local_service.api_provider.api.routes = routing_list with self.assertRaises(NoApisDefined): local_service.start() -class TestLocalApiService_make_routing_list(TestCase): - - def test_must_return_routing_list_from_apis(self): - api_provider = Mock() - apis = [ - Api(path="/1", method="GET1", function_name="name1", cors="CORS1"), - Api(path="/2", method="GET2", function_name="name2", cors="CORS2"), - Api(path="/3", method="GET3", function_name="name3", cors="CORS3"), - ] - expected = [ - Route(path="/1", methods=["GET1"], function_name="name1"), - Route(path="/2", methods=["GET2"], function_name="name2"), - Route(path="/3", methods=["GET3"], function_name="name3") - ] - - api_provider.get_all.return_value = apis - - result = LocalApiService._make_routing_list(api_provider) - self.assertEquals(len(result), len(expected)) - for index, r in enumerate(result): - self.assertEquals(r.__dict__, expected[index].__dict__) - - class TestLocalApiService_print_routes(TestCase): def test_must_print_routes(self): host = "host" port = 123 - api_provider = Mock() apis = [ - Api(path="/1", method="GET", function_name="name1", cors="CORS1"), - Api(path="/1", method="POST", function_name="name1", cors="CORS1"), - Api(path="/1", method="DELETE", function_name="othername1", cors="CORS1"), - Api(path="/2", method="GET2", function_name="name2", cors="CORS2"), - Api(path="/3", method="GET3", function_name="name3", cors="CORS3"), + Route(path="/1", methods=["GET"], function_name="name1"), + Route(path="/1", methods=["POST"], function_name="name1"), + Route(path="/1", methods=["DELETE"], function_name="othername1"), + Route(path="/2", methods=["GET2"], function_name="name2"), + Route(path="/3", methods=["GET3"], function_name="name3"), ] - api_provider.get_all.return_value = apis - + apis = ApiCollector.dedupe_function_routes(apis) expected = {"Mounting name1 at http://host:123/1 [GET, POST]", "Mounting othername1 at http://host:123/1 [DELETE]", "Mounting name2 at http://host:123/2 [GET2]", "Mounting name3 at http://host:123/3 [GET3]"} - actual = LocalApiService._print_routes(api_provider, host, port) + actual = LocalApiService._print_routes(apis, host, port) self.assertEquals(expected, set(actual)) @@ -181,39 +154,3 @@ def test_must_return_none_if_path_not_exists(self, os_mock): result = LocalApiService._make_static_dir_path(cwd, static_dir) self.assertIsNone(result) - - -class TestRoutingList(TestCase): - - def setUp(self): - self.function_name = "routingTest" - apis = [ - provider.Api(path="/get", method="GET", function_name=self.function_name, cors="cors"), - provider.Api(path="/get", method="GET", function_name=self.function_name, cors="cors", stage_name="Dev"), - provider.Api(path="/post", method="POST", function_name=self.function_name, cors="cors", stage_name="Prod"), - provider.Api(path="/get", method="GET", function_name=self.function_name, cors="cors", - stage_variables={"test": "data"}), - provider.Api(path="/post", method="POST", function_name=self.function_name, cors="cors", stage_name="Prod", - stage_variables={"data": "more data"}), - ] - self.api_provider_mock = Mock() - self.api_provider_mock.get_all.return_value = apis - - def test_make_routing_list(self): - routing_list = LocalApiService._make_routing_list(self.api_provider_mock) - - expected_routes = [ - Route(function_name=self.function_name, methods=['GET'], path='/get', stage_name=None, - stage_variables=None), - Route(function_name=self.function_name, methods=['GET'], path='/get', stage_name='Dev', - stage_variables=None), - Route(function_name=self.function_name, methods=['POST'], path='/post', stage_name='Prod', - stage_variables=None), - Route(function_name=self.function_name, methods=['GET'], path='/get', stage_name=None, - stage_variables={'test': 'data'}), - Route(function_name=self.function_name, methods=['POST'], path='/post', stage_name='Prod', - stage_variables={'data': 'more data'}), - ] - self.assertEquals(len(routing_list), len(expected_routes)) - for index, r in enumerate(routing_list): - self.assertEquals(r.__dict__, expected_routes[index].__dict__) diff --git a/tests/unit/commands/local/lib/test_sam_api_provider.py b/tests/unit/commands/local/lib/test_sam_api_provider.py index fa5f342e49..a210b95eb3 100644 --- a/tests/unit/commands/local/lib/test_sam_api_provider.py +++ b/tests/unit/commands/local/lib/test_sam_api_provider.py @@ -1,15 +1,14 @@ -import tempfile import json - +import tempfile +from collections import OrderedDict from unittest import TestCase + from mock import patch from nose_parameterized import parameterized - from six import assertCountEqual -from samcli.commands.local.lib.api_provider import ApiProvider, SamApiProvider -from samcli.commands.local.lib.provider import Api -from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException +from samcli.commands.local.lib.api_provider import ApiProvider +from samcli.local.apigw.local_apigw_service import Route class TestSamApiProviderWithImplicitApis(TestCase): @@ -26,7 +25,7 @@ def test_provider_with_no_resource_properties(self): provider = ApiProvider(template) - self.assertEquals(provider.apis, []) + self.assertEquals(provider.routes, []) @parameterized.expand([("GET"), ("get")]) def test_provider_has_correct_api(self, method): @@ -55,9 +54,8 @@ def test_provider_has_correct_api(self, method): provider = ApiProvider(template) - self.assertEquals(len(provider.apis), 1) - self.assertEquals(list(provider.apis)[0], Api(path="/path", method="GET", function_name="SamFunc1", cors=None, - stage_name="Prod")) + self.assertEquals(len(provider.routes), 1) + self.assertEquals(list(provider.routes)[0], Route(path="/path", methods=["GET"], function_name="SamFunc1")) def test_provider_creates_api_for_all_events(self): template = { @@ -92,12 +90,10 @@ def test_provider_creates_api_for_all_events(self): provider = ApiProvider(template) - api_event1 = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") - api_event2 = Api(path="/path", method="POST", function_name="SamFunc1", cors=None, stage_name="Prod") + api = Route(path="/path", methods=["GET", "POST"], function_name="SamFunc1") - self.assertIn(api_event1, provider.apis) - self.assertIn(api_event2, provider.apis) - self.assertEquals(len(provider.apis), 2) + self.assertIn(api, provider.routes) + self.assertEquals(len(provider.routes), 1) def test_provider_has_correct_template(self): template = { @@ -142,11 +138,11 @@ def test_provider_has_correct_template(self): provider = ApiProvider(template) - api1 = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") - api2 = Api(path="/path", method="POST", function_name="SamFunc2", cors=None, stage_name="Prod") + api1 = Route(path="/path", methods=["GET"], function_name="SamFunc1") + api2 = Route(path="/path", methods=["POST"], function_name="SamFunc2") - self.assertIn(api1, provider.apis) - self.assertIn(api2, provider.apis) + self.assertIn(api1, provider.routes) + self.assertIn(api2, provider.routes) def test_provider_with_no_api_events(self): template = { @@ -173,7 +169,7 @@ def test_provider_with_no_api_events(self): provider = ApiProvider(template) - self.assertEquals(provider.apis, []) + self.assertEquals(provider.routes, []) def test_provider_with_no_serverless_function(self): template = { @@ -192,7 +188,7 @@ def test_provider_with_no_serverless_function(self): provider = ApiProvider(template) - self.assertEquals(provider.apis, []) + self.assertEquals(provider.routes, []) def test_provider_get_all(self): template = { @@ -238,21 +234,22 @@ def test_provider_get_all(self): provider = ApiProvider(template) result = [f for f in provider.get_all()] + routes = result[0].routes + route1 = Route(path="/path", methods=["GET"], function_name="SamFunc1") + route2 = Route(path="/path", methods=["POST"], function_name="SamFunc2") - api1 = Api(path="/path", method="GET", function_name="SamFunc1", stage_name="Prod") - api2 = Api(path="/path", method="POST", function_name="SamFunc2", stage_name="Prod") - - self.assertIn(api1, result) - self.assertIn(api2, result) + self.assertIn(route1, routes) + self.assertIn(route2, routes) - def test_provider_get_all_with_no_apis(self): + def test_provider_get_all_with_no_routes(self): template = {} provider = ApiProvider(template) result = [f for f in provider.get_all()] + routes = result[0].routes - self.assertEquals(result, []) + self.assertEquals(routes, []) @parameterized.expand([("ANY"), ("any")]) def test_provider_with_any_method(self, method): @@ -281,22 +278,16 @@ def test_provider_with_any_method(self, method): provider = ApiProvider(template) - api_get = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") - api_post = Api(path="/path", method="POST", function_name="SamFunc1", cors=None, stage_name="Prod") - api_put = Api(path="/path", method="PUT", function_name="SamFunc1", cors=None, stage_name="Prod") - api_delete = Api(path="/path", method="DELETE", function_name="SamFunc1", cors=None, stage_name="Prod") - api_patch = Api(path="/path", method="PATCH", function_name="SamFunc1", cors=None, stage_name="Prod") - api_head = Api(path="/path", method="HEAD", function_name="SamFunc1", cors=None, stage_name="Prod") - api_options = Api(path="/path", method="OPTIONS", function_name="SamFunc1", cors=None, stage_name="Prod") - - self.assertEquals(len(provider.apis), 7) - self.assertIn(api_get, provider.apis) - self.assertIn(api_post, provider.apis) - self.assertIn(api_put, provider.apis) - self.assertIn(api_delete, provider.apis) - self.assertIn(api_patch, provider.apis) - self.assertIn(api_head, provider.apis) - self.assertIn(api_options, provider.apis) + api1 = Route(path="/path", methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"], function_name="SamFunc1") + + self.assertEquals(len(provider.routes), 1) + self.assertIn(api1, provider.routes) def test_provider_must_support_binary_media_types(self): template = { @@ -334,10 +325,10 @@ def test_provider_must_support_binary_media_types(self): provider = ApiProvider(template) - self.assertEquals(len(provider.apis), 1) - self.assertEquals(list(provider.apis)[0], Api(path="/path", method="GET", function_name="SamFunc1", - binary_media_types=["image/gif", "image/png"], cors=None, - stage_name="Prod")) + self.assertEquals(len(provider.routes), 1) + self.assertEquals(list(provider.routes)[0], Route(path="/path", methods=["GET"], function_name="SamFunc1")) + assertCountEqual(self, provider.api.binary_media_types, ["image/gif", "image/png"]) + self.assertEquals(provider.api.stage_name, "Prod") def test_provider_must_support_binary_media_types_with_any_method(self): template = { @@ -374,49 +365,34 @@ def test_provider_must_support_binary_media_types_with_any_method(self): binary = ["image/gif", "image/png", "text/html"] - expected_apis = [ - Api(path="/path", method="GET", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod"), - Api(path="/path", method="POST", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod"), - Api(path="/path", method="PUT", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod"), - Api(path="/path", method="DELETE", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod"), - Api(path="/path", method="HEAD", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod"), - Api(path="/path", method="OPTIONS", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod"), - Api(path="/path", method="PATCH", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod") + expected_routes = [ + Route(path="/path", methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"], function_name="SamFunc1") ] provider = ApiProvider(template) - assertCountEqual(self, provider.apis, expected_apis) - - def test_convert_event_api_with_invalid_event_properties(self): - properties = { - "Path": "/foo", - "Method": "get", - "RestApiId": { - # This is not supported. Only Ref is supported - "Fn::Sub": "foo" - } - } - - with self.assertRaises(InvalidSamDocumentException): - SamApiProvider._convert_event_api("logicalId", properties) + assertCountEqual(self, provider.routes, expected_routes) + assertCountEqual(self, provider.api.binary_media_types, binary) class TestSamApiProviderWithExplicitApis(TestCase): def setUp(self): self.binary_types = ["image/png", "image/jpg"] - self.input_apis = [ - Api(path="/path1", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod"), - Api(path="/path1", method="POST", function_name="SamFunc1", cors=None, stage_name="Prod"), - - Api(path="/path2", method="PUT", function_name="SamFunc1", cors=None, stage_name="Prod"), - Api(path="/path2", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod"), - - Api(path="/path3", method="DELETE", function_name="SamFunc1", cors=None, stage_name="Prod") + self.stage_name = "Prod" + self.input_routes = [ + Route(path="/path1", methods=["GET", "POST"], function_name="SamFunc1"), + Route(path="/path2", methods=["PUT", "GET"], function_name="SamFunc1"), + Route(path="/path3", methods=["DELETE"], function_name="SamFunc1") ] - def test_with_no_apis(self): + def test_with_no_routes(self): template = { "Resources": { @@ -431,9 +407,9 @@ def test_with_no_apis(self): provider = ApiProvider(template) - self.assertEquals(provider.apis, []) + self.assertEquals(provider.routes, []) - def test_with_inline_swagger_apis(self): + def test_with_inline_swagger_routes(self): template = { "Resources": { @@ -441,20 +417,20 @@ def test_with_inline_swagger_apis(self): "Type": "AWS::Serverless::Api", "Properties": { "StageName": "Prod", - "DefinitionBody": make_swagger(self.input_apis) + "DefinitionBody": make_swagger(self.input_routes) } } } } provider = ApiProvider(template) - assertCountEqual(self, self.input_apis, provider.apis) + assertCountEqual(self, self.input_routes, provider.routes) def test_with_swagger_as_local_file(self): with tempfile.NamedTemporaryFile(mode='w') as fp: filename = fp.name - swagger = make_swagger(self.input_apis) + swagger = make_swagger(self.input_routes) json.dump(swagger, fp) fp.flush() @@ -472,10 +448,10 @@ def test_with_swagger_as_local_file(self): } provider = ApiProvider(template) - assertCountEqual(self, self.input_apis, provider.apis) + assertCountEqual(self, self.input_routes, provider.routes) - @patch("samcli.commands.local.lib.cfn_base_api_provider.SamSwaggerReader") - def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): + @patch("samcli.commands.local.lib.cfn_base_api_provider.SwaggerReader") + def test_with_swagger_as_both_body_and_uri_called(self, SwaggerReaderMock): body = {"some": "body"} filename = "somefile.txt" @@ -493,26 +469,27 @@ def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): } } - SamSwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_apis) + SwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_routes) cwd = "foo" provider = ApiProvider(template, cwd=cwd) - assertCountEqual(self, self.input_apis, provider.apis) - SamSwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) + assertCountEqual(self, self.input_routes, provider.routes) + SwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) def test_swagger_with_any_method(self): - apis = [ - Api(path="/path", method="any", function_name="SamFunc1", cors=None) + routes = [ + Route(path="/path", methods=["any"], function_name="SamFunc1") ] - expected_apis = [ - Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod"), - Api(path="/path", method="POST", function_name="SamFunc1", cors=None, stage_name="Prod"), - Api(path="/path", method="PUT", function_name="SamFunc1", cors=None, stage_name="Prod"), - Api(path="/path", method="DELETE", function_name="SamFunc1", cors=None, stage_name="Prod"), - Api(path="/path", method="HEAD", function_name="SamFunc1", cors=None, stage_name="Prod"), - Api(path="/path", method="OPTIONS", function_name="SamFunc1", cors=None, stage_name="Prod"), - Api(path="/path", method="PATCH", function_name="SamFunc1", cors=None, stage_name="Prod") + expected_routes = [ + Route(path="/path", methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"], + function_name="SamFunc1") ] template = { @@ -521,14 +498,14 @@ def test_swagger_with_any_method(self): "Type": "AWS::Serverless::Api", "Properties": { "StageName": "Prod", - "DefinitionBody": make_swagger(apis) + "DefinitionBody": make_swagger(routes) } } } } provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) def test_with_binary_media_types(self): template = { @@ -538,34 +515,26 @@ def test_with_binary_media_types(self): "Type": "AWS::Serverless::Api", "Properties": { "StageName": "Prod", - "DefinitionBody": make_swagger(self.input_apis, binary_media_types=self.binary_types) + "DefinitionBody": make_swagger(self.input_routes, binary_media_types=self.binary_types) } } } } expected_binary_types = sorted(self.binary_types) - expected_apis = [ - Api(path="/path1", method="GET", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types, stage_name="Prod"), - Api(path="/path1", method="POST", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types, stage_name="Prod"), - - Api(path="/path2", method="PUT", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types, stage_name="Prod"), - Api(path="/path2", method="GET", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types, stage_name="Prod"), - - Api(path="/path3", method="DELETE", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types, stage_name="Prod") + expected_routes = [ + Route(path="/path1", methods=["GET", "POST"], function_name="SamFunc1"), + Route(path="/path2", methods=["GET", "PUT"], function_name="SamFunc1"), + Route(path="/path3", methods=["DELETE"], function_name="SamFunc1") ] provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) + assertCountEqual(self, provider.api.binary_media_types, expected_binary_types) def test_with_binary_media_types_in_swagger_and_on_resource(self): - input_apis = [ - Api(path="/path", method="OPTIONS", function_name="SamFunc1", stage_name="Prod"), + input_routes = [ + Route(path="/path", methods=["OPTIONS"], function_name="SamFunc1"), ] extra_binary_types = ["text/html"] @@ -577,32 +546,33 @@ def test_with_binary_media_types_in_swagger_and_on_resource(self): "Properties": { "BinaryMediaTypes": extra_binary_types, "StageName": "Prod", - "DefinitionBody": make_swagger(input_apis, binary_media_types=self.binary_types) + "DefinitionBody": make_swagger(input_routes, binary_media_types=self.binary_types) } } } } expected_binary_types = sorted(self.binary_types + extra_binary_types) - expected_apis = [ - Api(path="/path", method="OPTIONS", function_name="SamFunc1", binary_media_types=expected_binary_types, - stage_name="Prod"), + expected_routes = [ + Route(path="/path", methods=["OPTIONS"], function_name="SamFunc1"), ] provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) + assertCountEqual(self, provider.api.binary_media_types, expected_binary_types) class TestSamApiProviderWithExplicitAndImplicitApis(TestCase): def setUp(self): - self.explicit_apis = [ - Api(path="/path1", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), - Api(path="/path2", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), - Api(path="/path3", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod") + self.stage_name = "Prod" + self.explicit_routes = [ + Route(path="/path1", methods=["GET"], function_name="explicitfunction"), + Route(path="/path2", methods=["GET"], function_name="explicitfunction"), + Route(path="/path3", methods=["GET"], function_name="explicitfunction") ] - self.swagger = make_swagger(self.explicit_apis) + self.swagger = make_swagger(self.explicit_routes) self.template = { "Resources": { @@ -655,22 +625,22 @@ def test_must_union_implicit_and_explicit(self): self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = self.swagger self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = events - expected_apis = [ + expected_routes = [ # From Explicit APIs - Api(path="/path1", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), - Api(path="/path2", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), - Api(path="/path3", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), + Route(path="/path1", methods=["GET"], function_name="explicitfunction"), + Route(path="/path2", methods=["GET"], function_name="explicitfunction"), + Route(path="/path3", methods=["GET"], function_name="explicitfunction"), # From Implicit APIs - Api(path="/path1", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path2", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path3", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod") + Route(path="/path1", methods=["POST"], function_name="ImplicitFunc"), + Route(path="/path2", methods=["POST"], function_name="ImplicitFunc"), + Route(path="/path3", methods=["POST"], function_name="ImplicitFunc") ] provider = ApiProvider(self.template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) def test_must_prefer_implicit_api_over_explicit(self): - implicit_apis = { + implicit_routes = { "Event1": { "Type": "Api", "Properties": { @@ -690,24 +660,24 @@ def test_must_prefer_implicit_api_over_explicit(self): } self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = self.swagger - self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_apis + self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_routes - expected_apis = [ - Api(path="/path1", method="GET", function_name="ImplicitFunc", stage_name="Prod"), + expected_routes = [ + Route(path="/path1", methods=["GET"], function_name="ImplicitFunc"), # Comes from Implicit - Api(path="/path2", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), - Api(path="/path2", method="POST", function_name="ImplicitFunc", stage_name="Prod"), + Route(path="/path2", methods=["GET"], function_name="explicitfunction"), + Route(path="/path2", methods=["POST"], function_name="ImplicitFunc"), # Comes from implicit - Api(path="/path3", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), + Route(path="/path3", methods=["GET"], function_name="explicitfunction"), ] provider = ApiProvider(self.template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) def test_must_prefer_implicit_with_any_method(self): - implicit_apis = { + implicit_routes = { "Event1": { "Type": "Api", "Properties": { @@ -718,30 +688,31 @@ def test_must_prefer_implicit_with_any_method(self): } } - explicit_apis = [ + explicit_routes = [ # Explicit should be over masked completely by implicit, because of "ANY" - Api(path="/path", method="GET", function_name="explicitfunction", cors=None), - Api(path="/path", method="DELETE", function_name="explicitfunction", cors=None), + Route(path="/path", methods=["GET"], function_name="explicitfunction"), + Route(path="/path", methods=["DELETE"], function_name="explicitfunction"), ] - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = make_swagger(explicit_apis) - self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_apis - - expected_apis = [ - Api(path="/path", method="GET", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="PUT", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="DELETE", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="HEAD", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="OPTIONS", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="PATCH", function_name="ImplicitFunc", cors=None, stage_name="Prod") + self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = make_swagger(explicit_routes) + self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_routes + + expected_routes = [ + Route(path="/path", methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"], + function_name="ImplicitFunc") ] provider = ApiProvider(self.template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) def test_with_any_method_on_both(self): - implicit_apis = { + implicit_routes = { "Event1": { "Type": "Api", "Properties": { @@ -760,30 +731,32 @@ def test_with_any_method_on_both(self): } } - explicit_apis = [ + explicit_routes = [ # Explicit should be over masked completely by implicit, because of "ANY" - Api(path="/path", method="ANY", function_name="explicitfunction", cors=None), - Api(path="/path2", method="POST", function_name="explicitfunction", cors=None), + Route(path="/path", methods=["ANY"], function_name="explicitfunction"), + Route(path="/path2", methods=["POST"], function_name="explicitfunction"), ] - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = make_swagger(explicit_apis) - self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_apis - - expected_apis = [ - Api(path="/path", method="GET", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="PUT", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="DELETE", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="HEAD", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="OPTIONS", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="PATCH", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - - Api(path="/path2", method="GET", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path2", method="POST", function_name="explicitfunction", cors=None, stage_name="Prod") + self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = make_swagger(explicit_routes) + self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_routes + + expected_routes = [ + Route(path="/path", methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"], + function_name="ImplicitFunc"), + + Route(path="/path2", methods=["GET"], + function_name="ImplicitFunc"), + Route(path="/path2", methods=["POST"], function_name="explicitfunction") ] provider = ApiProvider(self.template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) def test_must_add_explicit_api_when_ref_with_rest_api_id(self): events = { @@ -809,20 +782,20 @@ def test_must_add_explicit_api_when_ref_with_rest_api_id(self): self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = self.swagger self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = events - expected_apis = [ + expected_routes = [ # From Explicit APIs - Api(path="/path1", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), - Api(path="/path2", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), - Api(path="/path3", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), + Route(path="/path1", methods=["GET"], function_name="explicitfunction"), + Route(path="/path2", methods=["GET"], function_name="explicitfunction"), + Route(path="/path3", methods=["GET"], function_name="explicitfunction"), # From Implicit APIs - Api(path="/newpath1", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/newpath2", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod") + Route(path="/newpath1", methods=["POST"], function_name="ImplicitFunc"), + Route(path="/newpath2", methods=["POST"], function_name="ImplicitFunc") ] provider = ApiProvider(self.template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) - def test_both_apis_must_get_binary_media_types(self): + def test_both_routes_must_get_binary_media_types(self): events = { "Event1": { "Type": "Api", @@ -855,27 +828,20 @@ def test_both_apis_must_get_binary_media_types(self): # Because of Globals, binary types will be concatenated on the explicit API expected_explicit_binary_types = ["explicit/type1", "explicit/type2", "image/gif", "image/png"] - expected_implicit_binary_types = ["image/gif", "image/png"] - expected_apis = [ + expected_routes = [ # From Explicit APIs - Api(path="/path1", method="GET", function_name="explicitfunction", - binary_media_types=expected_explicit_binary_types, stage_name="Prod"), - Api(path="/path2", method="GET", function_name="explicitfunction", - binary_media_types=expected_explicit_binary_types, stage_name="Prod"), - Api(path="/path3", method="GET", function_name="explicitfunction", - binary_media_types=expected_explicit_binary_types, stage_name="Prod"), + Route(path="/path1", methods=["GET"], function_name="explicitfunction"), + Route(path="/path2", methods=["GET"], function_name="explicitfunction"), + Route(path="/path3", methods=["GET"], function_name="explicitfunction"), # From Implicit APIs - Api(path="/newpath1", method="POST", function_name="ImplicitFunc", - binary_media_types=expected_implicit_binary_types, - stage_name="Prod"), - Api(path="/newpath2", method="POST", function_name="ImplicitFunc", - binary_media_types=expected_implicit_binary_types, - stage_name="Prod") + Route(path="/newpath1", methods=["POST"], function_name="ImplicitFunc"), + Route(path="/newpath2", methods=["POST"], function_name="ImplicitFunc") ] provider = ApiProvider(self.template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) + assertCountEqual(self, provider.api.binary_media_types, expected_explicit_binary_types) def test_binary_media_types_with_rest_api_id_reference(self): events = { @@ -911,31 +877,25 @@ def test_binary_media_types_with_rest_api_id_reference(self): # Because of Globals, binary types will be concatenated on the explicit API expected_explicit_binary_types = ["explicit/type1", "explicit/type2", "image/gif", "image/png"] - expected_implicit_binary_types = ["image/gif", "image/png"] + # expected_implicit_binary_types = ["image/gif", "image/png"] - expected_apis = [ + expected_routes = [ # From Explicit APIs - Api(path="/path1", method="GET", function_name="explicitfunction", - binary_media_types=expected_explicit_binary_types, stage_name="Prod"), - Api(path="/path2", method="GET", function_name="explicitfunction", - binary_media_types=expected_explicit_binary_types, stage_name="Prod"), - Api(path="/path3", method="GET", function_name="explicitfunction", - binary_media_types=expected_explicit_binary_types, stage_name="Prod"), + Route(path="/path1", methods=["GET"], function_name="explicitfunction"), + Route(path="/path2", methods=["GET"], function_name="explicitfunction"), + Route(path="/path3", methods=["GET"], function_name="explicitfunction"), # Because of the RestApiId, Implicit APIs will also get the binary media types inherited from # the corresponding Explicit API - Api(path="/connected-to-explicit-path", method="POST", function_name="ImplicitFunc", - binary_media_types=expected_explicit_binary_types, - stage_name="Prod"), + Route(path="/connected-to-explicit-path", methods=["POST"], function_name="ImplicitFunc"), # This is still just a true implicit API because it does not have RestApiId property - Api(path="/true-implicit-path", method="POST", function_name="ImplicitFunc", - binary_media_types=expected_implicit_binary_types, - stage_name="Prod") + Route(path="/true-implicit-path", methods=["POST"], function_name="ImplicitFunc") ] provider = ApiProvider(self.template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) + assertCountEqual(self, provider.api.binary_media_types, expected_explicit_binary_types) class TestSamStageValues(TestCase): @@ -971,11 +931,11 @@ def test_provider_parse_stage_name(self): } } provider = ApiProvider(template) - api1 = Api(path='/path', method='GET', function_name='NoApiEventFunction', cors=None, binary_media_types=[], - stage_name='dev', - stage_variables=None) + route1 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') - self.assertIn(api1, provider.apis) + self.assertIn(route1, provider.routes) + self.assertEquals(provider.api.stage_name, "dev") + self.assertEquals(provider.api.stage_variables, None) def test_provider_stage_variables(self): template = { @@ -1013,125 +973,120 @@ def test_provider_stage_variables(self): } } provider = ApiProvider(template) - api1 = Api(path='/path', method='GET', function_name='NoApiEventFunction', cors=None, binary_media_types=[], - stage_name='dev', - stage_variables={ - "vis": "data", - "random": "test", - "foo": "bar" - }) + route1 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') - self.assertIn(api1, provider.apis) + self.assertIn(route1, provider.routes) + self.assertEquals(provider.api.stage_name, "dev") + self.assertEquals(provider.api.stage_variables, { + "vis": "data", + "random": "test", + "foo": "bar" + }) def test_multi_stage_get_all(self): - template = { - "Resources": { - "TestApi": { - "Type": "AWS::Serverless::Api", - "Properties": { - "StageName": "dev", - "Variables": { - "vis": "data", - "random": "test", - "foo": "bar" - }, - "DefinitionBody": { - "paths": { - "/path2": { - "get": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } - } + template = OrderedDict({ + "Resources": {} + }) + template["Resources"]["TestApi"] = { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "dev", + "Variables": { + "vis": "data", + "random": "test", + "foo": "bar" + }, + "DefinitionBody": { + "paths": { + "/path2": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, } } } + } + } + } + + template["Resources"]["ProductionApi"] = { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "Production", + "Variables": { + "vis": "prod data", + "random": "test", + "foo": "bar" }, - "ProductionApi": { - "Type": "AWS::Serverless::Api", - "Properties": { - "StageName": "Production", - "Variables": { - "vis": "prod data", - "random": "test", - "foo": "bar" + "DefinitionBody": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } }, - "DefinitionBody": { - "paths": { - "/path": { - "get": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } + "/anotherpath": { + "post": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, }, - "/anotherpath": { - "post": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } - } - } } + } } } } + provider = ApiProvider(template) result = [f for f in provider.get_all()] + routes = result[0].routes + + route1 = Route(path='/path2', methods=['GET'], function_name='NoApiEventFunction') + route2 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') + route3 = Route(path='/anotherpath', methods=['POST'], function_name='NoApiEventFunction') + self.assertEquals(len(routes), 3) + self.assertIn(route1, routes) + self.assertIn(route2, routes) + self.assertIn(route3, routes) + + self.assertEquals(provider.api.stage_name, "Production") + self.assertEquals(provider.api.stage_variables, { + "vis": "prod data", + "random": "test", + "foo": "bar" + }) - api1 = Api(path='/path2', method='GET', function_name='NoApiEventFunction', cors=None, binary_media_types=[], - stage_name='dev', - stage_variables={ - "vis": "data", - "random": "test", - "foo": "bar" - }) - api2 = Api(path='/path', method='GET', function_name='NoApiEventFunction', cors=None, binary_media_types=[], - stage_name='Production', stage_variables={'vis': 'prod data', 'random': 'test', 'foo': 'bar'}) - api3 = Api(path='/anotherpath', method='POST', function_name='NoApiEventFunction', cors=None, - binary_media_types=[], - stage_name='Production', - stage_variables={ - "vis": "prod data", - "random": "test", - "foo": "bar" - }) - self.assertEquals(len(result), 3) - self.assertIn(api1, result) - self.assertIn(api2, result) - self.assertIn(api3, result) - - -def make_swagger(apis, binary_media_types=None): + +def make_swagger(routes, binary_media_types=None): """ Given a list of API configurations named tuples, returns a Swagger document Parameters ---------- - apis : list of samcli.commands.local.lib.provider.Api + routes : list of samcli.commands.local.agiw.local_agiw_service.Route binary_media_types : list of str Returns @@ -1145,7 +1100,7 @@ def make_swagger(apis, binary_media_types=None): } } - for api in apis: + for api in routes: swagger["paths"].setdefault(api.path, {}) integration = { @@ -1156,12 +1111,11 @@ def make_swagger(apis, binary_media_types=None): api.function_name) # NOQA } } + for method in api.methods: + if method.lower() == "any": + method = "x-amazon-apigateway-any-method" - method = api.method - if method.lower() == "any": - method = "x-amazon-apigateway-any-method" - - swagger["paths"][api.path][method] = integration + swagger["paths"][api.path][method] = integration if binary_media_types: swagger["x-amazon-apigateway-binary-media-types"] = binary_media_types diff --git a/tests/unit/local/apigw/test_local_apigw_service.py b/tests/unit/local/apigw/test_local_apigw_service.py index ba2d6316b5..9bbf52cc62 100644 --- a/tests/unit/local/apigw/test_local_apigw_service.py +++ b/tests/unit/local/apigw/test_local_apigw_service.py @@ -1,3 +1,4 @@ +import copy from unittest import TestCase from mock import Mock, patch, ANY import json @@ -6,6 +7,7 @@ from parameterized import parameterized, param from werkzeug.datastructures import Headers +from samcli.commands.local.lib.provider import Api from samcli.local.apigw.local_apigw_service import LocalApigwService, Route from samcli.local.lambdafn.exceptions import FunctionNotFound @@ -14,14 +16,15 @@ class TestApiGatewayService(TestCase): def setUp(self): self.function_name = Mock() - self.api_gateway_route = Route(['GET'], self.function_name, '/') + self.api_gateway_route = Route(methods=['GET'], function_name=self.function_name, path='/') self.list_of_routes = [self.api_gateway_route] self.lambda_runner = Mock() self.lambda_runner.is_debugging.return_value = False self.stderr = Mock() - self.service = LocalApigwService(self.list_of_routes, + self.api = Api(routes=self.list_of_routes) + self.service = LocalApigwService(self.api, self.lambda_runner, port=3000, host='127.0.0.1', @@ -102,14 +105,15 @@ def test_request_handler_returns_make_response(self): def test_create_creates_dict_of_routes(self): function_name_1 = Mock() function_name_2 = Mock() - api_gateway_route_1 = Route(['GET'], function_name_1, '/') - api_gateway_route_2 = Route(['POST'], function_name_2, '/') + api_gateway_route_1 = Route(methods=["GET"], function_name=function_name_1, path='/') + api_gateway_route_2 = Route(methods=["POST"], function_name=function_name_2, path='/') list_of_routes = [api_gateway_route_1, api_gateway_route_2] lambda_runner = Mock() - service = LocalApigwService(list_of_routes, lambda_runner) + api = Api(routes=list_of_routes) + service = LocalApigwService(api, lambda_runner) service.create() @@ -135,16 +139,16 @@ def test_create_creates_flask_app_with_url_rules(self, flask): def test_initalize_creates_default_values(self): self.assertEquals(self.service.port, 3000) self.assertEquals(self.service.host, '127.0.0.1') - self.assertEquals(self.service.routing_list, self.list_of_routes) + self.assertEquals(self.service.api.routes, self.list_of_routes) self.assertIsNone(self.service.static_dir) self.assertEquals(self.service.lambda_runner, self.lambda_runner) def test_initalize_with_values(self): lambda_runner = Mock() - local_service = LocalApigwService([], lambda_runner, static_dir='dir/static', port=5000, host='129.0.0.0') + local_service = LocalApigwService(Api(), lambda_runner, static_dir='dir/static', port=5000, host='129.0.0.0') self.assertEquals(local_service.port, 5000) self.assertEquals(local_service.host, '129.0.0.0') - self.assertEquals(local_service.routing_list, []) + self.assertEquals(local_service.api.routes, []) self.assertEquals(local_service.static_dir, 'dir/static') self.assertEquals(local_service.lambda_runner, lambda_runner) @@ -250,19 +254,12 @@ class TestApiGatewayModel(TestCase): def setUp(self): self.function_name = "name" - self.stage_name = "Dev" - self.stage_variables = { - "test": "sample" - } - self.api_gateway = Route(['POST'], self.function_name, '/', stage_name=self.stage_name, - stage_variables=self.stage_variables) + self.api_gateway = Route(function_name=self.function_name, methods=["Post"], path="/") def test_class_initialization(self): self.assertEquals(self.api_gateway.methods, ['POST']) self.assertEquals(self.api_gateway.function_name, self.function_name) self.assertEquals(self.api_gateway.path, '/') - self.assertEqual(self.api_gateway.stage_name, "Dev") - self.assertEqual(self.api_gateway.stage_variables, {"test": "sample"}) class TestLambdaHeaderDictionaryMerge(TestCase): @@ -488,7 +485,7 @@ def setUp(self): '"Custom User Agent String", "caller": null, "cognitoAuthenticationType": null, "sourceIp": ' \ '"190.0.0.0", "user": null}, "accountId": "123456789012"}, "headers": {"Content-Type": ' \ '"application/json", "X-Test": "Value", "X-Forwarded-Port": "3000", "X-Forwarded-Proto": "http"}, ' \ - '"multiValueHeaders": {"Content-Type": ["application/json"], "X-Test": ["Value"], '\ + '"multiValueHeaders": {"Content-Type": ["application/json"], "X-Test": ["Value"], ' \ '"X-Forwarded-Port": ["3000"], "X-Forwarded-Proto": ["http"]}, ' \ '"stageVariables": null, "path": "path", "pathParameters": {"path": "params"}, ' \ '"isBase64Encoded": false}' @@ -590,3 +587,60 @@ def test_should_base64_encode_returns_true(self, test_case_name, binary_types, m ]) def test_should_base64_encode_returns_false(self, test_case_name, binary_types, mimetype): self.assertFalse(LocalApigwService._should_base64_encode(binary_types, mimetype)) + + +class TestRouteEqualsHash(TestCase): + + def test_route_in_list(self): + route = Route(function_name="test", path="/test", methods=["POST"]) + routes = [route] + self.assertIn(route, routes) + + def test_route_method_order_equals(self): + route1 = Route(function_name="test", path="/test", methods=["POST", "GET"]) + route2 = Route(function_name="test", path="/test", methods=["GET", "POST"]) + self.assertEquals(route1, route2) + + def test_route_hash(self): + route1 = Route(function_name="test", path="/test", methods=["POST", "GET"]) + dic = {route1: "test"} + self.assertEquals(dic[route1], "test") + + def test_route_object_equals(self): + route1 = Route(function_name="test", path="/test", methods=["POST", "GET"]) + route2 = type('obj', (object,), {'function_name': 'test', "path": "/test", "methods": ["GET", "POST"]}) + + self.assertNotEqual(route1, route2) + + def test_route_function_name_equals(self): + route1 = Route(function_name="test1", path="/test", methods=["GET", "POST"]) + route2 = Route(function_name="test2", path="/test", methods=["GET", "POST"]) + self.assertNotEqual(route1, route2) + + def test_route_different_path_equals(self): + route1 = Route(function_name="test", path="/test1", methods=["GET", "POST"]) + route2 = Route(function_name="test", path="/test2", methods=["GET", "POST"]) + self.assertNotEqual(route1, route2) + + def test_same_object_equals(self): + route1 = Route(function_name="test", path="/test", methods=["POST", "GET"]) + self.assertEquals(route1, copy.deepcopy(route1)) + + def test_route_function_name_hash(self): + route1 = Route(function_name="test1", path="/test", methods=["GET", "POST"]) + route2 = Route(function_name="test2", path="/test", methods=["GET", "POST"]) + self.assertNotEqual(route1.__hash__(), route2.__hash__()) + + def test_route_different_path_hash(self): + route1 = Route(function_name="test", path="/test1", methods=["GET", "POST"]) + route2 = Route(function_name="test", path="/test2", methods=["GET", "POST"]) + self.assertNotEqual(route1.__hash__(), route2.__hash__()) + + def test_same_object_hash(self): + route1 = Route(function_name="test", path="/test", methods=["POST", "GET"]) + self.assertEquals(route1.__hash__(), copy.deepcopy(route1).__hash__()) + + def test_route_method_order_hash(self): + route1 = Route(function_name="test", path="/test", methods=["POST", "GET"]) + route2 = Route(function_name="test", path="/test", methods=["GET", "POST"]) + self.assertEquals(route1.__hash__(), route2.__hash__()) From fa5dad3a1e73797d0098572f5d9129d331aa6cea Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Fri, 26 Jul 2019 12:32:38 -0700 Subject: [PATCH 14/30] Update cors tests --- samcli/commands/local/lib/provider.py | 26 ++++++++++- samcli/commands/local/lib/sam_api_provider.py | 4 +- samcli/local/apigw/local_apigw_service.py | 42 +++++++----------- .../local/start_api/start_api_integ_base.py | 6 +-- .../local/start_api/test_start_api.py | 3 ++ .../local/apigw/test_local_apigw_service.py | 44 ++++++++++++------- 6 files changed, 76 insertions(+), 49 deletions(-) diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index 971e1c9e0c..1283a6b577 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -232,12 +232,34 @@ def binary_media_types(self): _CorsTuple.__new__.__defaults__ = (None, # Allow Origin defaults to None None, # Allow Methods is optional and defaults to empty None, # Allow Headers is optional and defaults to empty - None # MaxAge is optional and defaults to empty + None # MaxAge is optional and defaults to empty ) class Cors(_CorsTuple): - pass + + @staticmethod + def cors_to_headers(cors): + """ + Convert CORS object to headers dictionary + Parameters + ---------- + cors list(samcli.commands.local.lib.provider.Cors) + CORS configuration objcet + Returns + ------- + Dictionary with CORS headers + """ + if not cors: + return {} + headers = { + 'Access-Control-Allow-Origin': cors.allow_origin, + 'Access-Control-Allow-Methods': cors.allow_methods, + 'Access-Control-Allow-Headers': cors.allow_headers, + 'Access-Control-Max-Age': cors.max_age + } + + return {h_key: h_value for h_key, h_value in headers.items() if h_value is not None} class AbstractApiProvider(object): diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index 9ab9fa0639..9e8047851a 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -104,11 +104,11 @@ def _extract_cors(self, cors_prop): cors_prop : dict Resource properties for Cors """ - cors = {} + cors = None if cors_prop and isinstance(cors_prop, dict): allow_methods = cors_prop.get("AllowMethods", ','.join(Route.ANY_HTTP_METHODS)) - if allow_methods and "OPTIONS" not in allow_methods: + if allow_methods and "OPTIONS" not in allow_methods and "options" not in allow_methods: allow_methods += ",OPTIONS" cors = Cors( diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index e750fa846a..3a5bb6fade 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -7,6 +7,7 @@ from flask import Flask, request from werkzeug.datastructures import Headers +from samcli.commands.local.lib.provider import Cors from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser from samcli.lib.utils.stream_writer import StreamWriter from samcli.local.lambdafn.exceptions import FunctionNotFound @@ -169,9 +170,10 @@ def _request_handler(self, **kwargs): Response object """ route = self._get_current_route(request) - cors_headers = LocalApigwService.cors_to_headers(route.cors) + cors_headers = Cors.cors_to_headers(self.api.cors) - if 'OPTIONS' in route.methods: + method, _ = self.get_request_methods_endpoints(request) + if method == 'OPTIONS': headers = Headers(cors_headers) return self.service_response('', headers, 200) @@ -215,8 +217,7 @@ def _get_current_route(self, flask_request): :param request flask_request: Flask Request :return: Route matching the endpoint and method of the request """ - endpoint = flask_request.endpoint - method = flask_request.method + endpoint, method = self.get_request_methods_endpoints(flask_request) route_key = self._route_key(method, endpoint) route = self._dict_of_routes.get(route_key, None) @@ -229,6 +230,16 @@ def _get_current_route(self, flask_request): return route + def get_request_methods_endpoints(self, flask_request): + """ + Separated out for testing requests in request handler + :param request flask_request: Flask Request + :return: the request's endpoint and method + """ + endpoint = flask_request.endpoint + method = flask_request.method + return method, endpoint + # Consider moving this out to its own class. Logic is started to get dense and looks messy @jfuss @staticmethod def _parse_lambda_output(lambda_output, binary_types, flask_request): @@ -493,26 +504,3 @@ def _should_base64_encode(binary_types, request_mimetype): """ return request_mimetype in binary_types or "*/*" in binary_types - - @staticmethod - def cors_to_headers(cors): - """ - Convert CORS object to headers dictionary - Parameters - ---------- - cors list(samcli.commands.local.lib.provider.Cors) - CORS configuration objcet - Returns - ------- - Dictionary with CORS headers - """ - if not cors: - return {} - headers = { - 'Access-Control-Allow-Origin': cors.allow_origin, - 'Access-Control-Allow-Methods': cors.allow_methods, - 'Access-Control-Allow-Headers': cors.allow_headers, - 'Access-Control-Max-Age': cors.max_age - } - - return {h_key: h_value for h_key, h_value in headers.items() if h_value is not None} diff --git a/tests/integration/local/start_api/start_api_integ_base.py b/tests/integration/local/start_api/start_api_integ_base.py index 08306a2649..c9ee1cbdec 100644 --- a/tests/integration/local/start_api/start_api_integ_base.py +++ b/tests/integration/local/start_api/start_api_integ_base.py @@ -33,9 +33,9 @@ def setUpClass(cls): @classmethod def start_api(cls): - command = "sam" - if os.getenv("SAM_CLI_DEV"): - command = "samdev" + # command = "sam" + # if os.getenv("SAM_CLI_DEV"): + command = "samdev" cls.start_api_process = Popen([command, "local", "start-api", "-t", cls.template, "-p", cls.port, "--debug"]) # we need to wait some time for start-api to start, hence the sleep diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index ff66a3ba7f..b68b9ca561 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -665,6 +665,9 @@ class TestServiceCorsSwaggerRequests(StartApiIntegBaseClass): template_path = "/testdata/start_api/swagger-template.yaml" binary_data_file = "testdata/start_api/binarydata.gif" + def setUp(self): + self.url = "http://127.0.0.1:{}".format(self.port) + def test_cors_swagger_options(self): """ This tests that the Cors are added to option requests in the swagger template diff --git a/tests/unit/local/apigw/test_local_apigw_service.py b/tests/unit/local/apigw/test_local_apigw_service.py index 51e715a8b9..30143c8f0b 100644 --- a/tests/unit/local/apigw/test_local_apigw_service.py +++ b/tests/unit/local/apigw/test_local_apigw_service.py @@ -1,14 +1,14 @@ +import base64 import copy -from unittest import TestCase -from mock import Mock, patch, ANY, MagicMock import json -import base64 +from unittest import TestCase +from mock import Mock, patch, ANY, MagicMock from parameterized import parameterized, param from werkzeug.datastructures import Headers -from samcli.commands.local.lib.provider import Cors from samcli.commands.local.lib.provider import Api +from samcli.commands.local.lib.provider import Cors from samcli.local.apigw.local_apigw_service import LocalApigwService, Route from samcli.local.lambdafn.exceptions import FunctionNotFound @@ -31,7 +31,8 @@ def setUp(self): host='127.0.0.1', stderr=self.stderr) - def test_request_must_invoke_lambda(self): + @patch.object(LocalApigwService, "get_request_methods_endpoints") + def test_request_must_invoke_lambda(self, request_mock): make_response_mock = Mock() self.service.service_response = make_response_mock @@ -47,6 +48,8 @@ def test_request_must_invoke_lambda(self): service_response_mock.return_value = make_response_mock self.service.service_response = service_response_mock + request_mock.return_value = ('test', 'test') + result = self.service._request_handler() self.assertEquals(result, make_response_mock) @@ -55,10 +58,11 @@ def test_request_must_invoke_lambda(self): stdout=ANY, stderr=self.stderr) + @patch.object(LocalApigwService, "get_request_methods_endpoints") @patch('samcli.local.apigw.local_apigw_service.LambdaOutputParser') - def test_request_handler_returns_process_stdout_when_making_response(self, lambda_output_parser_mock): + def test_request_handler_returns_process_stdout_when_making_response(self, lambda_output_parser_mock, request_mock): make_response_mock = Mock() - + request_mock.return_value = ('test', 'test') self.service.service_response = make_response_mock self.service._get_current_route = MagicMock() self.service._get_current_route.methods = [] @@ -87,7 +91,8 @@ def test_request_handler_returns_process_stdout_when_making_response(self, lambd # Make sure the logs are written to stderr self.stderr.write.assert_called_with(lambda_logs) - def test_request_handler_returns_make_response(self): + @patch.object(LocalApigwService, "get_request_methods_endpoints") + def test_request_handler_returns_make_response(self, request_mock): make_response_mock = Mock() self.service.service_response = make_response_mock @@ -103,6 +108,7 @@ def test_request_handler_returns_make_response(self): service_response_mock.return_value = make_response_mock self.service.service_response = service_response_mock + request_mock.return_value = ('test', 'test') result = self.service._request_handler() self.assertEquals(result, make_response_mock) @@ -157,8 +163,9 @@ def test_initalize_with_values(self): self.assertEquals(local_service.static_dir, 'dir/static') self.assertEquals(local_service.lambda_runner, lambda_runner) + @patch.object(LocalApigwService, "get_request_methods_endpoints") @patch('samcli.local.apigw.local_apigw_service.ServiceErrorResponses') - def test_request_handles_error_when_invoke_cant_find_function(self, service_error_responses_patch): + def test_request_handles_error_when_invoke_cant_find_function(self, service_error_responses_patch, request_mock): not_found_response_mock = Mock() self.service._construct_event = Mock() self.service._get_current_route = MagicMock() @@ -167,22 +174,26 @@ def test_request_handles_error_when_invoke_cant_find_function(self, service_erro service_error_responses_patch.lambda_not_found_response.return_value = not_found_response_mock self.lambda_runner.invoke.side_effect = FunctionNotFound() - + request_mock.return_value = ('test', 'test') response = self.service._request_handler() self.assertEquals(response, not_found_response_mock) - def test_request_throws_when_invoke_fails(self): + @patch.object(LocalApigwService, "get_request_methods_endpoints") + def test_request_throws_when_invoke_fails(self, request_mock): self.lambda_runner.invoke.side_effect = Exception() self.service._construct_event = Mock() self.service._get_current_route = Mock() + request_mock.return_value = ('test', 'test') with self.assertRaises(Exception): self.service._request_handler() + @patch.object(LocalApigwService, "get_request_methods_endpoints") @patch('samcli.local.apigw.local_apigw_service.ServiceErrorResponses') - def test_request_handler_errors_when_parse_lambda_output_raises_keyerror(self, service_error_responses_patch): + def test_request_handler_errors_when_parse_lambda_output_raises_keyerror(self, service_error_responses_patch, + request_mock): parse_output_mock = Mock() parse_output_mock.side_effect = KeyError() self.service._parse_lambda_output = parse_output_mock @@ -195,6 +206,7 @@ def test_request_handler_errors_when_parse_lambda_output_raises_keyerror(self, s self.service._get_current_route = MagicMock() self.service._get_current_route.methods = [] + request_mock.return_value = ('test', 'test') result = self.service._request_handler() self.assertEquals(result, failure_response_mock) @@ -208,8 +220,9 @@ def test_request_handler_errors_when_get_current_route_fails(self, service_error with self.assertRaises(KeyError): self.service._request_handler() + @patch.object(LocalApigwService, "get_request_methods_endpoints") @patch('samcli.local.apigw.local_apigw_service.ServiceErrorResponses') - def test_request_handler_errors_when_unable_to_read_binary_data(self, service_error_responses_patch): + def test_request_handler_errors_when_unable_to_read_binary_data(self, service_error_responses_patch, request_mock): _construct_event = Mock() _construct_event.side_effect = UnicodeDecodeError("utf8", b"obj", 1, 2, "reason") self.service._get_current_route = MagicMock() @@ -220,6 +233,7 @@ def test_request_handler_errors_when_unable_to_read_binary_data(self, service_er failure_mock = Mock() service_error_responses_patch.lambda_failure_response.return_value = failure_mock + request_mock.return_value = ('test', 'test') result = self.service._request_handler() self.assertEquals(result, failure_mock) @@ -603,14 +617,14 @@ class TestServiceCorsToHeaders(TestCase): def test_basic_conversion(self): cors = Cors(allow_origin="*", allow_methods=','.join(["POST", "OPTIONS"]), allow_headers="UPGRADE-HEADER", max_age=6) - headers = LocalApigwService.cors_to_headers(cors) + headers = Cors.cors_to_headers(cors) self.assertEquals(headers, {'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Methods': 'POST,OPTIONS', 'Access-Control-Allow-Headers': 'UPGRADE-HEADER', 'Access-Control-Max-Age': 6}) def test_empty_elements(self): cors = Cors(allow_origin="www.domain.com", allow_methods=','.join(["GET", "POST", "OPTIONS"])) - headers = LocalApigwService.cors_to_headers(cors) + headers = Cors.cors_to_headers(cors) self.assertEquals(headers, {'Access-Control-Allow-Origin': 'www.domain.com', From 5853a33990d81a0dc725616fdd0027c470d37c0a Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Fri, 26 Jul 2019 12:32:38 -0700 Subject: [PATCH 15/30] Update cors tests --- samcli/commands/local/lib/provider.py | 26 ++++++++++- samcli/commands/local/lib/sam_api_provider.py | 4 +- samcli/local/apigw/local_apigw_service.py | 42 +++++++----------- .../local/start_api/test_start_api.py | 3 ++ .../local/apigw/test_local_apigw_service.py | 44 ++++++++++++------- 5 files changed, 73 insertions(+), 46 deletions(-) diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index 971e1c9e0c..1283a6b577 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -232,12 +232,34 @@ def binary_media_types(self): _CorsTuple.__new__.__defaults__ = (None, # Allow Origin defaults to None None, # Allow Methods is optional and defaults to empty None, # Allow Headers is optional and defaults to empty - None # MaxAge is optional and defaults to empty + None # MaxAge is optional and defaults to empty ) class Cors(_CorsTuple): - pass + + @staticmethod + def cors_to_headers(cors): + """ + Convert CORS object to headers dictionary + Parameters + ---------- + cors list(samcli.commands.local.lib.provider.Cors) + CORS configuration objcet + Returns + ------- + Dictionary with CORS headers + """ + if not cors: + return {} + headers = { + 'Access-Control-Allow-Origin': cors.allow_origin, + 'Access-Control-Allow-Methods': cors.allow_methods, + 'Access-Control-Allow-Headers': cors.allow_headers, + 'Access-Control-Max-Age': cors.max_age + } + + return {h_key: h_value for h_key, h_value in headers.items() if h_value is not None} class AbstractApiProvider(object): diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index 9ab9fa0639..9e8047851a 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -104,11 +104,11 @@ def _extract_cors(self, cors_prop): cors_prop : dict Resource properties for Cors """ - cors = {} + cors = None if cors_prop and isinstance(cors_prop, dict): allow_methods = cors_prop.get("AllowMethods", ','.join(Route.ANY_HTTP_METHODS)) - if allow_methods and "OPTIONS" not in allow_methods: + if allow_methods and "OPTIONS" not in allow_methods and "options" not in allow_methods: allow_methods += ",OPTIONS" cors = Cors( diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index e750fa846a..3a5bb6fade 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -7,6 +7,7 @@ from flask import Flask, request from werkzeug.datastructures import Headers +from samcli.commands.local.lib.provider import Cors from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser from samcli.lib.utils.stream_writer import StreamWriter from samcli.local.lambdafn.exceptions import FunctionNotFound @@ -169,9 +170,10 @@ def _request_handler(self, **kwargs): Response object """ route = self._get_current_route(request) - cors_headers = LocalApigwService.cors_to_headers(route.cors) + cors_headers = Cors.cors_to_headers(self.api.cors) - if 'OPTIONS' in route.methods: + method, _ = self.get_request_methods_endpoints(request) + if method == 'OPTIONS': headers = Headers(cors_headers) return self.service_response('', headers, 200) @@ -215,8 +217,7 @@ def _get_current_route(self, flask_request): :param request flask_request: Flask Request :return: Route matching the endpoint and method of the request """ - endpoint = flask_request.endpoint - method = flask_request.method + endpoint, method = self.get_request_methods_endpoints(flask_request) route_key = self._route_key(method, endpoint) route = self._dict_of_routes.get(route_key, None) @@ -229,6 +230,16 @@ def _get_current_route(self, flask_request): return route + def get_request_methods_endpoints(self, flask_request): + """ + Separated out for testing requests in request handler + :param request flask_request: Flask Request + :return: the request's endpoint and method + """ + endpoint = flask_request.endpoint + method = flask_request.method + return method, endpoint + # Consider moving this out to its own class. Logic is started to get dense and looks messy @jfuss @staticmethod def _parse_lambda_output(lambda_output, binary_types, flask_request): @@ -493,26 +504,3 @@ def _should_base64_encode(binary_types, request_mimetype): """ return request_mimetype in binary_types or "*/*" in binary_types - - @staticmethod - def cors_to_headers(cors): - """ - Convert CORS object to headers dictionary - Parameters - ---------- - cors list(samcli.commands.local.lib.provider.Cors) - CORS configuration objcet - Returns - ------- - Dictionary with CORS headers - """ - if not cors: - return {} - headers = { - 'Access-Control-Allow-Origin': cors.allow_origin, - 'Access-Control-Allow-Methods': cors.allow_methods, - 'Access-Control-Allow-Headers': cors.allow_headers, - 'Access-Control-Max-Age': cors.max_age - } - - return {h_key: h_value for h_key, h_value in headers.items() if h_value is not None} diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index ff66a3ba7f..b68b9ca561 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -665,6 +665,9 @@ class TestServiceCorsSwaggerRequests(StartApiIntegBaseClass): template_path = "/testdata/start_api/swagger-template.yaml" binary_data_file = "testdata/start_api/binarydata.gif" + def setUp(self): + self.url = "http://127.0.0.1:{}".format(self.port) + def test_cors_swagger_options(self): """ This tests that the Cors are added to option requests in the swagger template diff --git a/tests/unit/local/apigw/test_local_apigw_service.py b/tests/unit/local/apigw/test_local_apigw_service.py index 51e715a8b9..30143c8f0b 100644 --- a/tests/unit/local/apigw/test_local_apigw_service.py +++ b/tests/unit/local/apigw/test_local_apigw_service.py @@ -1,14 +1,14 @@ +import base64 import copy -from unittest import TestCase -from mock import Mock, patch, ANY, MagicMock import json -import base64 +from unittest import TestCase +from mock import Mock, patch, ANY, MagicMock from parameterized import parameterized, param from werkzeug.datastructures import Headers -from samcli.commands.local.lib.provider import Cors from samcli.commands.local.lib.provider import Api +from samcli.commands.local.lib.provider import Cors from samcli.local.apigw.local_apigw_service import LocalApigwService, Route from samcli.local.lambdafn.exceptions import FunctionNotFound @@ -31,7 +31,8 @@ def setUp(self): host='127.0.0.1', stderr=self.stderr) - def test_request_must_invoke_lambda(self): + @patch.object(LocalApigwService, "get_request_methods_endpoints") + def test_request_must_invoke_lambda(self, request_mock): make_response_mock = Mock() self.service.service_response = make_response_mock @@ -47,6 +48,8 @@ def test_request_must_invoke_lambda(self): service_response_mock.return_value = make_response_mock self.service.service_response = service_response_mock + request_mock.return_value = ('test', 'test') + result = self.service._request_handler() self.assertEquals(result, make_response_mock) @@ -55,10 +58,11 @@ def test_request_must_invoke_lambda(self): stdout=ANY, stderr=self.stderr) + @patch.object(LocalApigwService, "get_request_methods_endpoints") @patch('samcli.local.apigw.local_apigw_service.LambdaOutputParser') - def test_request_handler_returns_process_stdout_when_making_response(self, lambda_output_parser_mock): + def test_request_handler_returns_process_stdout_when_making_response(self, lambda_output_parser_mock, request_mock): make_response_mock = Mock() - + request_mock.return_value = ('test', 'test') self.service.service_response = make_response_mock self.service._get_current_route = MagicMock() self.service._get_current_route.methods = [] @@ -87,7 +91,8 @@ def test_request_handler_returns_process_stdout_when_making_response(self, lambd # Make sure the logs are written to stderr self.stderr.write.assert_called_with(lambda_logs) - def test_request_handler_returns_make_response(self): + @patch.object(LocalApigwService, "get_request_methods_endpoints") + def test_request_handler_returns_make_response(self, request_mock): make_response_mock = Mock() self.service.service_response = make_response_mock @@ -103,6 +108,7 @@ def test_request_handler_returns_make_response(self): service_response_mock.return_value = make_response_mock self.service.service_response = service_response_mock + request_mock.return_value = ('test', 'test') result = self.service._request_handler() self.assertEquals(result, make_response_mock) @@ -157,8 +163,9 @@ def test_initalize_with_values(self): self.assertEquals(local_service.static_dir, 'dir/static') self.assertEquals(local_service.lambda_runner, lambda_runner) + @patch.object(LocalApigwService, "get_request_methods_endpoints") @patch('samcli.local.apigw.local_apigw_service.ServiceErrorResponses') - def test_request_handles_error_when_invoke_cant_find_function(self, service_error_responses_patch): + def test_request_handles_error_when_invoke_cant_find_function(self, service_error_responses_patch, request_mock): not_found_response_mock = Mock() self.service._construct_event = Mock() self.service._get_current_route = MagicMock() @@ -167,22 +174,26 @@ def test_request_handles_error_when_invoke_cant_find_function(self, service_erro service_error_responses_patch.lambda_not_found_response.return_value = not_found_response_mock self.lambda_runner.invoke.side_effect = FunctionNotFound() - + request_mock.return_value = ('test', 'test') response = self.service._request_handler() self.assertEquals(response, not_found_response_mock) - def test_request_throws_when_invoke_fails(self): + @patch.object(LocalApigwService, "get_request_methods_endpoints") + def test_request_throws_when_invoke_fails(self, request_mock): self.lambda_runner.invoke.side_effect = Exception() self.service._construct_event = Mock() self.service._get_current_route = Mock() + request_mock.return_value = ('test', 'test') with self.assertRaises(Exception): self.service._request_handler() + @patch.object(LocalApigwService, "get_request_methods_endpoints") @patch('samcli.local.apigw.local_apigw_service.ServiceErrorResponses') - def test_request_handler_errors_when_parse_lambda_output_raises_keyerror(self, service_error_responses_patch): + def test_request_handler_errors_when_parse_lambda_output_raises_keyerror(self, service_error_responses_patch, + request_mock): parse_output_mock = Mock() parse_output_mock.side_effect = KeyError() self.service._parse_lambda_output = parse_output_mock @@ -195,6 +206,7 @@ def test_request_handler_errors_when_parse_lambda_output_raises_keyerror(self, s self.service._get_current_route = MagicMock() self.service._get_current_route.methods = [] + request_mock.return_value = ('test', 'test') result = self.service._request_handler() self.assertEquals(result, failure_response_mock) @@ -208,8 +220,9 @@ def test_request_handler_errors_when_get_current_route_fails(self, service_error with self.assertRaises(KeyError): self.service._request_handler() + @patch.object(LocalApigwService, "get_request_methods_endpoints") @patch('samcli.local.apigw.local_apigw_service.ServiceErrorResponses') - def test_request_handler_errors_when_unable_to_read_binary_data(self, service_error_responses_patch): + def test_request_handler_errors_when_unable_to_read_binary_data(self, service_error_responses_patch, request_mock): _construct_event = Mock() _construct_event.side_effect = UnicodeDecodeError("utf8", b"obj", 1, 2, "reason") self.service._get_current_route = MagicMock() @@ -220,6 +233,7 @@ def test_request_handler_errors_when_unable_to_read_binary_data(self, service_er failure_mock = Mock() service_error_responses_patch.lambda_failure_response.return_value = failure_mock + request_mock.return_value = ('test', 'test') result = self.service._request_handler() self.assertEquals(result, failure_mock) @@ -603,14 +617,14 @@ class TestServiceCorsToHeaders(TestCase): def test_basic_conversion(self): cors = Cors(allow_origin="*", allow_methods=','.join(["POST", "OPTIONS"]), allow_headers="UPGRADE-HEADER", max_age=6) - headers = LocalApigwService.cors_to_headers(cors) + headers = Cors.cors_to_headers(cors) self.assertEquals(headers, {'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Methods': 'POST,OPTIONS', 'Access-Control-Allow-Headers': 'UPGRADE-HEADER', 'Access-Control-Max-Age': 6}) def test_empty_elements(self): cors = Cors(allow_origin="www.domain.com", allow_methods=','.join(["GET", "POST", "OPTIONS"])) - headers = LocalApigwService.cors_to_headers(cors) + headers = Cors.cors_to_headers(cors) self.assertEquals(headers, {'Access-Control-Allow-Origin': 'www.domain.com', From 5b89f27c6ff0e51715e6626897948c9a13de3385 Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Fri, 26 Jul 2019 12:33:47 -0700 Subject: [PATCH 16/30] update test --- tests/integration/local/start_api/start_api_integ_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integration/local/start_api/start_api_integ_base.py b/tests/integration/local/start_api/start_api_integ_base.py index c9ee1cbdec..08306a2649 100644 --- a/tests/integration/local/start_api/start_api_integ_base.py +++ b/tests/integration/local/start_api/start_api_integ_base.py @@ -33,9 +33,9 @@ def setUpClass(cls): @classmethod def start_api(cls): - # command = "sam" - # if os.getenv("SAM_CLI_DEV"): - command = "samdev" + command = "sam" + if os.getenv("SAM_CLI_DEV"): + command = "samdev" cls.start_api_process = Popen([command, "local", "start-api", "-t", cls.template, "-p", cls.port, "--debug"]) # we need to wait some time for start-api to start, hence the sleep From 41da0eb23afb45ccb522ab2a10fcb6a6a8bd3220 Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Fri, 26 Jul 2019 12:49:21 -0700 Subject: [PATCH 17/30] Update cors with comments --- samcli/commands/local/lib/provider.py | 3 ++- samcli/local/apigw/local_apigw_service.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index 1283a6b577..7a3762c469 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -258,7 +258,8 @@ def cors_to_headers(cors): 'Access-Control-Allow-Headers': cors.allow_headers, 'Access-Control-Max-Age': cors.max_age } - + # Filters out items in the headers dictionary that isn't empty. + # This is required because the flask Headers dict will send an invalid 'None' string return {h_key: h_value for h_key, h_value in headers.items() if h_value is not None} diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index 3a5bb6fade..f6929366db 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -217,7 +217,7 @@ def _get_current_route(self, flask_request): :param request flask_request: Flask Request :return: Route matching the endpoint and method of the request """ - endpoint, method = self.get_request_methods_endpoints(flask_request) + method, endpoint = self.get_request_methods_endpoints(flask_request) route_key = self._route_key(method, endpoint) route = self._dict_of_routes.get(route_key, None) From 161de22988793422e0bf03cb80391f26af7e74be Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Mon, 29 Jul 2019 10:03:11 -0700 Subject: [PATCH 18/30] Fix rebase error --- samcli/commands/local/lib/provider.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index 7a3762c469..e9717b54ea 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -275,12 +275,3 @@ def get_all(self): :yields Api: namedtuple containing the API information """ raise NotImplementedError("not implemented") - - @staticmethod - def normalize_http_methods(api): - """ - Normalizes Http Methods. Api Gateway allows a Http Methods of ANY. This is a special verb to denote all - supported Http Methods on Api Gateway. - :param api api: Api - :yield str: Either the input http_method or one of the _ANY_HTTP_METHODS (normalized Http Methods) - """ From e14e84f6af4ffc934720753f9d69c8cc62666116 Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Mon, 29 Jul 2019 10:50:15 -0700 Subject: [PATCH 19/30] Remove multi value headers --- .vscode/settings.json | 3 +++ samcli/local/apigw/local_apigw_service.py | 1 - 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000000..53d8ec25db --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.pythonPath": "venv/bin/python3.7" +} \ No newline at end of file diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index f6929366db..51bcf6d59d 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -483,7 +483,6 @@ def _event_headers(flask_request, port, cors_headers): multi_value_headers_dict["X-Forwarded-Port"] = [str(port)] if cors_headers: headers_dict.update(cors_headers) - multi_value_headers_dict.update(cors_headers) return headers_dict, multi_value_headers_dict @staticmethod From a4a8ee6f8014ea8b42919f076d8f444651d85f87 Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Mon, 29 Jul 2019 13:58:22 -0700 Subject: [PATCH 20/30] Update cors allow methods --- samcli/commands/local/lib/sam_api_provider.py | 33 ++++++++++++++++--- samcli/local/apigw/local_apigw_service.py | 10 +++--- .../local/apigw/test_local_apigw_service.py | 4 +-- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index 9e8047851a..1360d31182 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -107,10 +107,7 @@ def _extract_cors(self, cors_prop): cors = None if cors_prop and isinstance(cors_prop, dict): allow_methods = cors_prop.get("AllowMethods", ','.join(Route.ANY_HTTP_METHODS)) - - if allow_methods and "OPTIONS" not in allow_methods and "options" not in allow_methods: - allow_methods += ",OPTIONS" - + allow_methods = self.normalize_cors_allow_methods(allow_methods) cors = Cors( allow_origin=cors_prop.get("AllowOrigin"), allow_methods=allow_methods, @@ -126,6 +123,34 @@ def _extract_cors(self, cors_prop): ) return cors + def normalize_cors_allow_methods(self, allow_methods): + """ + Normalize cors AllowMethods and Options to the methods if it's missing. + + Parameters + ---------- + allow_methods : str + The allow_methods string provided in the query + + Return + ------- + A string with normalized route + """ + if allow_methods == "*": + return allow_methods + methods = allow_methods.split(",") + normalized_methods = [] + for method in methods: + normalized_method = method.upper() + if normalized_method not in Route.ANY_HTTP_METHODS: + raise InvalidSamDocumentException("The method {} is not a valid CORS method".format(normalized_method)) + normalized_methods.append(normalized_method) + + if "OPTIONS" not in normalized_methods: + normalized_methods.append("OPTIONS") + + return ','.join(normalized_methods) + def _extract_routes_from_function(self, logical_id, function_resource, collector): """ Fetches a list of routes configured for this SAM Function resource. diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index 51bcf6d59d..4b93121aa3 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -179,7 +179,7 @@ def _request_handler(self, **kwargs): try: event = self._construct_event(request, self.port, self.api.binary_media_types, self.api.stage_name, - self.api.stage_variables, cors_headers) + self.api.stage_variables) except UnicodeDecodeError: return ServiceErrorResponses.lambda_failure_response() @@ -360,7 +360,7 @@ def _merge_response_headers(headers, multi_headers): return processed_headers @staticmethod - def _construct_event(flask_request, port, binary_types, stage_name=None, stage_variables=None, cors_headers=None): + def _construct_event(flask_request, port, binary_types, stage_name=None, stage_variables=None): """ Helper method that constructs the Event to be passed to Lambda @@ -394,7 +394,7 @@ def _construct_event(flask_request, port, binary_types, stage_name=None, stage_v identity=identity, path=endpoint) - headers_dict, multi_value_headers_dict = LocalApigwService._event_headers(flask_request, port, cors_headers) + headers_dict, multi_value_headers_dict = LocalApigwService._event_headers(flask_request, port) query_string_dict, multi_value_query_string_dict = LocalApigwService._query_string_params(flask_request) @@ -449,7 +449,7 @@ def _query_string_params(flask_request): return query_string_dict, multi_value_query_string_dict @staticmethod - def _event_headers(flask_request, port, cors_headers): + def _event_headers(flask_request, port): """ Constructs an APIGW equivalent headers dictionary @@ -481,8 +481,6 @@ def _event_headers(flask_request, port, cors_headers): headers_dict["X-Forwarded-Port"] = str(port) multi_value_headers_dict["X-Forwarded-Port"] = [str(port)] - if cors_headers: - headers_dict.update(cors_headers) return headers_dict, multi_value_headers_dict @staticmethod diff --git a/tests/unit/local/apigw/test_local_apigw_service.py b/tests/unit/local/apigw/test_local_apigw_service.py index 30143c8f0b..6ff3ee025d 100644 --- a/tests/unit/local/apigw/test_local_apigw_service.py +++ b/tests/unit/local/apigw/test_local_apigw_service.py @@ -549,7 +549,7 @@ def test_event_headers_with_empty_list(self): request_mock.headers = headers_mock request_mock.scheme = "http" - actual_query_string = LocalApigwService._event_headers(request_mock, "3000", {}) + actual_query_string = LocalApigwService._event_headers(request_mock, "3000") self.assertEquals(actual_query_string, ({"X-Forwarded-Proto": "http", "X-Forwarded-Port": "3000"}, {"X-Forwarded-Proto": ["http"], "X-Forwarded-Port": ["3000"]})) @@ -562,7 +562,7 @@ def test_event_headers_with_non_empty_list(self): request_mock.headers = headers_mock request_mock.scheme = "http" - actual_query_string = LocalApigwService._event_headers(request_mock, "3000", {}) + actual_query_string = LocalApigwService._event_headers(request_mock, "3000") self.assertEquals(actual_query_string, ({"Content-Type": "application/json", "X-Test": "Value", "X-Forwarded-Proto": "http", "X-Forwarded-Port": "3000"}, {"Content-Type": ["application/json"], "X-Test": ["Value"], From e7ab5f2f8cc6dcd6c999ecbcbba402bd08117f05 Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Mon, 29 Jul 2019 14:13:44 -0700 Subject: [PATCH 21/30] Update cors allow methods tests --- samcli/commands/local/lib/sam_api_provider.py | 11 +- .../local/start_api/test_start_api.py | 17 --- .../local/lib/test_sam_api_provider.py | 123 ++++++++++++++++-- 3 files changed, 120 insertions(+), 31 deletions(-) diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index 1360d31182..5c837019aa 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -80,7 +80,7 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd= body = properties.get("DefinitionBody") uri = properties.get("DefinitionUri") binary_media = properties.get("BinaryMediaTypes", []) - cors = self._extract_cors(properties.get("Cors", {})) + cors = self.extract_cors(properties.get("Cors", {})) stage_name = properties.get("StageName") stage_variables = properties.get("Variables") if not body and not uri: @@ -94,7 +94,7 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd= collector.stage_variables = stage_variables collector.cors = cors - def _extract_cors(self, cors_prop): + def extract_cors(self, cors_prop): """ Extract Cors property from AWS::Serverless::Api resource by reading and parsing Swagger documents. The result is added to the Api. @@ -123,7 +123,8 @@ def _extract_cors(self, cors_prop): ) return cors - def normalize_cors_allow_methods(self, allow_methods): + @staticmethod + def normalize_cors_allow_methods(allow_methods): """ Normalize cors AllowMethods and Options to the methods if it's missing. @@ -141,7 +142,7 @@ def normalize_cors_allow_methods(self, allow_methods): methods = allow_methods.split(",") normalized_methods = [] for method in methods: - normalized_method = method.upper() + normalized_method = method.strip().upper() if normalized_method not in Route.ANY_HTTP_METHODS: raise InvalidSamDocumentException("The method {} is not a valid CORS method".format(normalized_method)) normalized_methods.append(normalized_method) @@ -149,7 +150,7 @@ def normalize_cors_allow_methods(self, allow_methods): if "OPTIONS" not in normalized_methods: normalized_methods.append("OPTIONS") - return ','.join(normalized_methods) + return ','.join(sorted(normalized_methods)) def _extract_routes_from_function(self, logical_id, function_resource, collector): """ diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index b68b9ca561..20a287d7c0 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -681,23 +681,6 @@ def test_cors_swagger_options(self): self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), "GET,OPTIONS") self.assertEquals(response.headers.get("Access-Control-Max-Age"), '510') - def test_cors_swagger_post(self): - """ - This tests that the Cors are added to post requests in the swagger template - """ - input_data = self.get_binary_data(self.binary_data_file) - response = requests.post(self.url + '/echobase64eventbody', - headers={"Content-Type": "image/gif"}, - data=input_data) - - self.assertEquals(response.status_code, 200) - self.assertEquals(response.headers.get("Content-Type"), "image/gif") - self.assertEquals(response.content, input_data) - self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*") - self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), "origin, x-requested-with") - self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), "GET,OPTIONS") - self.assertEquals(response.headers.get("Access-Control-Max-Age"), '510') - class TestServiceCorsGlobalRequests(StartApiIntegBaseClass): """ diff --git a/tests/unit/commands/local/lib/test_sam_api_provider.py b/tests/unit/commands/local/lib/test_sam_api_provider.py index fba66f56bc..aa53b2b7cf 100644 --- a/tests/unit/commands/local/lib/test_sam_api_provider.py +++ b/tests/unit/commands/local/lib/test_sam_api_provider.py @@ -7,6 +7,7 @@ from nose_parameterized import parameterized from six import assertCountEqual +from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException from samcli.commands.local.lib.api_provider import ApiProvider from samcli.commands.local.lib.provider import Cors @@ -1150,7 +1151,7 @@ def test_provider_parse_cors_dict(self): "Properties": { "StageName": "Prod", "Cors": { - "AllowMethods": "POST", + "AllowMethods": "POST, GET", "AllowOrigin": "*", "AllowHeaders": "Upgrade-Insecure-Requests", "MaxAge": 600 @@ -1192,7 +1193,7 @@ def test_provider_parse_cors_dict(self): routes = provider.routes cors = Cors(allow_origin="*", - allow_methods=','.join(["POST", "OPTIONS"]), + allow_methods=','.join(sorted(["POST", "GET", "OPTIONS"])), allow_headers="Upgrade-Insecure-Requests", max_age=600) route1 = Route(path='/path2', methods=['POST', 'OPTIONS'], function_name='NoApiEventFunction') @@ -1203,6 +1204,116 @@ def test_provider_parse_cors_dict(self): self.assertIn(route2, routes) self.assertEquals(provider.api.cors, cors) + def test_provider_parse_cors_dict_star_allow(self): + template = { + "Resources": { + "TestApi": { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "Prod", + "Cors": { + "AllowMethods": "*", + "AllowOrigin": "*", + "AllowHeaders": "Upgrade-Insecure-Requests", + "MaxAge": 600 + }, + "DefinitionBody": { + "paths": { + "/path2": { + "post": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + }, + "/path": { + "post": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + } + } + } + } + } + } + + provider = ApiProvider(template) + + routes = provider.routes + cors = Cors(allow_origin="*", + allow_methods='*', + allow_headers="Upgrade-Insecure-Requests", + max_age=600) + route1 = Route(path='/path2', methods=['POST', 'OPTIONS'], function_name='NoApiEventFunction') + route2 = Route(path='/path', methods=['POST', 'OPTIONS'], function_name='NoApiEventFunction') + + self.assertEquals(len(routes), 2) + self.assertIn(route1, routes) + self.assertIn(route2, routes) + self.assertEquals(provider.api.cors, cors) + + def test_invalid_cors_dict_allow_methods(self): + template = { + "Resources": { + "TestApi": { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "Prod", + "Cors": { + "AllowMethods": "GET, INVALID_METHOD", + "AllowOrigin": "*", + "AllowHeaders": "Upgrade-Insecure-Requests", + "MaxAge": 600 + }, + "DefinitionBody": { + "paths": { + "/path2": { + "post": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + }, + "/path": { + "post": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + } + } + } + } + } + } + with self.assertRaises(InvalidSamDocumentException, + msg="ApiProvider should fail for Invalid Cors Allow method"): + ApiProvider(template) + def test_default_cors_dict_prop(self): template = { "Resources": { @@ -1238,13 +1349,7 @@ def test_default_cors_dict_prop(self): provider = ApiProvider(template) routes = provider.routes - cors = Cors(allow_origin="www.domain.com", allow_methods=','.join(["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"])) + cors = Cors(allow_origin="www.domain.com", allow_methods=','.join(sorted(Route.ANY_HTTP_METHODS))) route1 = Route(path='/path2', methods=['GET', 'OPTIONS'], function_name='NoApiEventFunction') self.assertEquals(len(routes), 1) self.assertIn(route1, routes) From 107d2da599d552d6b3c6b61d6bd507c030489703 Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Mon, 29 Jul 2019 14:14:54 -0700 Subject: [PATCH 22/30] Update cors integ test --- tests/integration/local/start_api/test_start_api.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index 20a287d7c0..7dd732e4d8 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -2,6 +2,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from time import time +from local.apigw.local_apigw_service import Route from .start_api_integ_base import StartApiIntegBaseClass @@ -701,7 +702,7 @@ def test_cors_global(self): self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*") self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), None) self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), - "GET,DELETE,PUT,POST,HEAD,OPTIONS,PATCH") + sorted(Route.ANY_HTTP_METHODS)) self.assertEquals(response.headers.get("Access-Control-Max-Age"), None) def test_cors_global_get(self): From 49ae556a0d02092dedcc1e261c55d31924a8f32a Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Tue, 30 Jul 2019 09:54:11 -0700 Subject: [PATCH 23/30] Update * Allow Methods --- .vscode/settings.json | 3 --- samcli/commands/local/lib/sam_api_provider.py | 2 +- samcli/local/apigw/local_apigw_service.py | 2 -- 3 files changed, 1 insertion(+), 6 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 53d8ec25db..0000000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "python.pythonPath": "venv/bin/python3.7" -} \ No newline at end of file diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index 5c837019aa..fb7bceaa2c 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -138,7 +138,7 @@ def normalize_cors_allow_methods(allow_methods): A string with normalized route """ if allow_methods == "*": - return allow_methods + return ','.join(sorted(Route.ANY_HTTP_METHODS)) methods = allow_methods.split(",") normalized_methods = [] for method in methods: diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index 4b93121aa3..45f53541f9 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -201,8 +201,6 @@ def _request_handler(self, **kwargs): (status_code, headers, body) = self._parse_lambda_output(lambda_response, self.api.binary_media_types, request) - if cors_headers: - headers.extend(cors_headers) except (KeyError, TypeError, ValueError): LOG.error("Function returned an invalid response (must include one of: body, headers, multiValueHeaders or " "statusCode in the response object). Response received: %s", lambda_response) From b7292fe979919ced5e4f4a7962e5b9d15546a4fd Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Tue, 30 Jul 2019 09:56:44 -0700 Subject: [PATCH 24/30] Update * Allow Methods --- tests/unit/commands/local/lib/test_sam_api_provider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/commands/local/lib/test_sam_api_provider.py b/tests/unit/commands/local/lib/test_sam_api_provider.py index aa53b2b7cf..752c5f2b88 100644 --- a/tests/unit/commands/local/lib/test_sam_api_provider.py +++ b/tests/unit/commands/local/lib/test_sam_api_provider.py @@ -1254,7 +1254,7 @@ def test_provider_parse_cors_dict_star_allow(self): routes = provider.routes cors = Cors(allow_origin="*", - allow_methods='*', + allow_methods=','.join(sorted(Route.ANY_HTTP_METHODS)), allow_headers="Upgrade-Insecure-Requests", max_age=600) route1 = Route(path='/path2', methods=['POST', 'OPTIONS'], function_name='NoApiEventFunction') From 1b1fc73ab21e64b063ebaf7f96a69a25209b6c91 Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Tue, 30 Jul 2019 10:19:32 -0700 Subject: [PATCH 25/30] Update start_api import --- tests/integration/local/start_api/test_start_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index 7dd732e4d8..7fac9c7edb 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -2,7 +2,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from time import time -from local.apigw.local_apigw_service import Route +from samcli.local.apigw.local_apigw_service import Route from .start_api_integ_base import StartApiIntegBaseClass From 6873f54b9a637753b06b38bdd52c967428858c36 Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Wed, 31 Jul 2019 08:42:54 -0700 Subject: [PATCH 26/30] Update tests to pass --- samcli/commands/local/lib/sam_api_provider.py | 4 ++-- tests/integration/local/start_api/test_start_api.py | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index fb7bceaa2c..09b4871b89 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -106,7 +106,7 @@ def extract_cors(self, cors_prop): """ cors = None if cors_prop and isinstance(cors_prop, dict): - allow_methods = cors_prop.get("AllowMethods", ','.join(Route.ANY_HTTP_METHODS)) + allow_methods = cors_prop.get("AllowMethods", ','.join(sorted(Route.ANY_HTTP_METHODS))) allow_methods = self.normalize_cors_allow_methods(allow_methods) cors = Cors( allow_origin=cors_prop.get("AllowOrigin"), @@ -117,7 +117,7 @@ def extract_cors(self, cors_prop): elif cors_prop and isinstance(cors_prop, string_types): cors = Cors( allow_origin=cors_prop, - allow_methods=','.join(Route.ANY_HTTP_METHODS), + allow_methods=','.join(sorted(Route.ANY_HTTP_METHODS)), allow_headers=None, max_age=None ) diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index 7fac9c7edb..88744dccb4 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -702,7 +702,7 @@ def test_cors_global(self): self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*") self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), None) self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), - sorted(Route.ANY_HTTP_METHODS)) + ','.join(sorted(Route.ANY_HTTP_METHODS))) self.assertEquals(response.headers.get("Access-Control-Max-Age"), None) def test_cors_global_get(self): @@ -714,10 +714,9 @@ def test_cors_global_get(self): self.assertEquals(response.status_code, 200) self.assertEquals(response.content.decode('utf-8'), "no data") self.assertEquals(response.headers.get("Content-Type"), "application/json") - self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*") + self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), None) self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), None) - self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), - "GET,DELETE,PUT,POST,HEAD,OPTIONS,PATCH") + self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), None) self.assertEquals(response.headers.get("Access-Control-Max-Age"), None) From 9ed522420958655a0d231b9e2387f5cc835e368b Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Wed, 31 Jul 2019 13:14:08 -0700 Subject: [PATCH 27/30] Update bad unit test --- .../commands/local/lib/test_sam_api_provider.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/unit/commands/local/lib/test_sam_api_provider.py b/tests/unit/commands/local/lib/test_sam_api_provider.py index 752c5f2b88..b0c9df022d 100644 --- a/tests/unit/commands/local/lib/test_sam_api_provider.py +++ b/tests/unit/commands/local/lib/test_sam_api_provider.py @@ -1128,13 +1128,13 @@ def test_provider_parse_cors_string(self): provider = ApiProvider(template) routes = provider.routes - cors = Cors(allow_origin="*", allow_methods=','.join(["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"])) + cors = Cors(allow_origin="*", allow_methods=','.join(sorted(["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"]))) route1 = Route(path='/path2', methods=['POST', 'OPTIONS'], function_name='NoApiEventFunction') route2 = Route(path='/path', methods=['GET', 'OPTIONS'], function_name='NoApiEventFunction') From 9f820f93e275836b3e24289b931137275d0307ea Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Thu, 1 Aug 2019 09:40:15 -0700 Subject: [PATCH 28/30] Trigger From a355959a2db47074ef13d4327708cc924dd35c42 Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Fri, 9 Aug 2019 10:09:13 -0700 Subject: [PATCH 29/30] Update Style for cors pr --- samcli/local/apigw/local_apigw_service.py | 1 - 1 file changed, 1 deletion(-) diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index c3cc8c7c8b..478a9a3688 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -27,7 +27,6 @@ class Route(object): "OPTIONS", "PATCH"] - def __init__(self, function_name, path, methods): """ Creates an ApiGatewayRoute From 3718fcf5adef4ead274d9127420e90716a5ff39f Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa Date: Sun, 11 Aug 2019 17:26:18 -0700 Subject: [PATCH 30/30] Fix style for flake8 --- .../local/lib/test_sam_api_provider.py | 796 +++++++++--------- 1 file changed, 381 insertions(+), 415 deletions(-) diff --git a/tests/unit/commands/local/lib/test_sam_api_provider.py b/tests/unit/commands/local/lib/test_sam_api_provider.py index 85d7d0ba9c..3e582ebac4 100644 --- a/tests/unit/commands/local/lib/test_sam_api_provider.py +++ b/tests/unit/commands/local/lib/test_sam_api_provider.py @@ -14,16 +14,8 @@ class TestSamApiProviderWithImplicitApis(TestCase): - def test_provider_with_no_resource_properties(self): - template = { - "Resources": { - - "SamFunc1": { - "Type": "AWS::Lambda::Function" - } - } - } + template = {"Resources": {"SamFunc1": {"Type": "AWS::Lambda::Function"}}} provider = ApiProvider(template) @@ -33,7 +25,6 @@ def test_provider_with_no_resource_properties(self): def test_provider_has_correct_api(self, method): template = { "Resources": { - "SamFunc1": { "Type": "AWS::Serverless::Function", "Properties": { @@ -43,13 +34,10 @@ def test_provider_has_correct_api(self, method): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": method - } + "Properties": {"Path": "/path", "Method": method}, } - } - } + }, + }, } } } @@ -57,12 +45,14 @@ def test_provider_has_correct_api(self, method): provider = ApiProvider(template) self.assertEquals(len(provider.routes), 1) - self.assertEquals(list(provider.routes)[0], Route(path="/path", methods=["GET"], function_name="SamFunc1")) + self.assertEquals( + list(provider.routes)[0], + Route(path="/path", methods=["GET"], function_name="SamFunc1"), + ) def test_provider_creates_api_for_all_events(self): template = { "Resources": { - "SamFunc1": { "Type": "AWS::Serverless::Function", "Properties": { @@ -72,20 +62,14 @@ def test_provider_creates_api_for_all_events(self): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": "GET" - } + "Properties": {"Path": "/path", "Method": "GET"}, }, "Event2": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": "POST" - } - } - } - } + "Properties": {"Path": "/path", "Method": "POST"}, + }, + }, + }, } } } @@ -100,7 +84,6 @@ def test_provider_creates_api_for_all_events(self): def test_provider_has_correct_template(self): template = { "Resources": { - "SamFunc1": { "Type": "AWS::Serverless::Function", "Properties": { @@ -110,13 +93,10 @@ def test_provider_has_correct_template(self): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": "GET" - } + "Properties": {"Path": "/path", "Method": "GET"}, } - } - } + }, + }, }, "SamFunc2": { "Type": "AWS::Serverless::Function", @@ -127,14 +107,11 @@ def test_provider_has_correct_template(self): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": "POST" - } + "Properties": {"Path": "/path", "Method": "POST"}, } - } - } - } + }, + }, + }, } } @@ -149,7 +126,6 @@ def test_provider_has_correct_template(self): def test_provider_with_no_api_events(self): template = { "Resources": { - "SamFunc1": { "Type": "AWS::Serverless::Function", "Properties": { @@ -159,12 +135,10 @@ def test_provider_with_no_api_events(self): "Events": { "Event1": { "Type": "S3", - "Properties": { - "Property1": "value" - } + "Properties": {"Property1": "value"}, } - } - } + }, + }, } } } @@ -176,14 +150,13 @@ def test_provider_with_no_api_events(self): def test_provider_with_no_serverless_function(self): template = { "Resources": { - "SamFunc1": { "Type": "AWS::Lambda::Function", "Properties": { "CodeUri": "/usr/foo/bar", "Runtime": "nodejs4.3", - "Handler": "index.handler" - } + "Handler": "index.handler", + }, } } } @@ -195,7 +168,6 @@ def test_provider_with_no_serverless_function(self): def test_provider_get_all(self): template = { "Resources": { - "SamFunc1": { "Type": "AWS::Serverless::Function", "Properties": { @@ -205,13 +177,10 @@ def test_provider_get_all(self): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": "GET" - } + "Properties": {"Path": "/path", "Method": "GET"}, } - } - } + }, + }, }, "SamFunc2": { "Type": "AWS::Serverless::Function", @@ -222,14 +191,11 @@ def test_provider_get_all(self): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": "POST" - } + "Properties": {"Path": "/path", "Method": "POST"}, } - } - } - } + }, + }, + }, } } @@ -257,7 +223,6 @@ def test_provider_get_all_with_no_routes(self): def test_provider_with_any_method(self, method): template = { "Resources": { - "SamFunc1": { "Type": "AWS::Serverless::Function", "Properties": { @@ -267,26 +232,21 @@ def test_provider_with_any_method(self, method): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": method - } + "Properties": {"Path": "/path", "Method": method}, } - } - } + }, + }, } } } provider = ApiProvider(template) - api1 = Route(path="/path", methods=["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"], function_name="SamFunc1") + api1 = Route( + path="/path", + methods=["GET", "DELETE", "PUT", "POST", "HEAD", "OPTIONS", "PATCH"], + function_name="SamFunc1", + ) self.assertEquals(len(provider.routes), 1) self.assertIn(api1, provider.routes) @@ -299,12 +259,11 @@ def test_provider_must_support_binary_media_types(self): "image~1gif", "image~1png", "image~1png", # Duplicates must be ignored - {"Ref": "SomeParameter"} # Refs are ignored as well + {"Ref": "SomeParameter"}, # Refs are ignored as well ] } }, "Resources": { - "SamFunc1": { "Type": "AWS::Serverless::Function", "Properties": { @@ -314,37 +273,32 @@ def test_provider_must_support_binary_media_types(self): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": "get" - } + "Properties": {"Path": "/path", "Method": "get"}, } - } - } + }, + }, } - } + }, } provider = ApiProvider(template) self.assertEquals(len(provider.routes), 1) - self.assertEquals(list(provider.routes)[0], Route(path="/path", methods=["GET"], function_name="SamFunc1")) - assertCountEqual(self, provider.api.binary_media_types, ["image/gif", "image/png"]) + self.assertEquals( + list(provider.routes)[0], + Route(path="/path", methods=["GET"], function_name="SamFunc1"), + ) + assertCountEqual( + self, provider.api.binary_media_types, ["image/gif", "image/png"] + ) self.assertEquals(provider.api.stage_name, "Prod") def test_provider_must_support_binary_media_types_with_any_method(self): template = { "Globals": { - "Api": { - "BinaryMediaTypes": [ - "image~1gif", - "image~1png", - "text/html" - ] - } + "Api": {"BinaryMediaTypes": ["image~1gif", "image~1png", "text/html"]} }, "Resources": { - "SamFunc1": { "Type": "AWS::Serverless::Function", "Properties": { @@ -354,27 +308,22 @@ def test_provider_must_support_binary_media_types_with_any_method(self): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": "any" - } + "Properties": {"Path": "/path", "Method": "any"}, } - } - } + }, + }, } - } + }, } binary = ["image/gif", "image/png", "text/html"] expected_routes = [ - Route(path="/path", methods=["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"], function_name="SamFunc1") + Route( + path="/path", + methods=["GET", "DELETE", "PUT", "POST", "HEAD", "OPTIONS", "PATCH"], + function_name="SamFunc1", + ) ] provider = ApiProvider(template) @@ -384,25 +333,21 @@ def test_provider_must_support_binary_media_types_with_any_method(self): class TestSamApiProviderWithExplicitApis(TestCase): - def setUp(self): self.binary_types = ["image/png", "image/jpg"] self.stage_name = "Prod" self.input_routes = [ Route(path="/path1", methods=["GET", "POST"], function_name="SamFunc1"), Route(path="/path2", methods=["PUT", "GET"], function_name="SamFunc1"), - Route(path="/path3", methods=["DELETE"], function_name="SamFunc1") + Route(path="/path3", methods=["DELETE"], function_name="SamFunc1"), ] def test_with_no_routes(self): template = { "Resources": { - "Api1": { "Type": "AWS::Serverless::Api", - "Properties": { - "StageName": "Prod" - } + "Properties": {"StageName": "Prod"}, } } } @@ -414,13 +359,12 @@ def test_with_no_routes(self): def test_with_inline_swagger_routes(self): template = { "Resources": { - "Api1": { "Type": "AWS::Serverless::Api", "Properties": { "StageName": "Prod", - "DefinitionBody": make_swagger(self.input_routes) - } + "DefinitionBody": make_swagger(self.input_routes), + }, } } } @@ -429,7 +373,7 @@ def test_with_inline_swagger_routes(self): assertCountEqual(self, self.input_routes, provider.routes) def test_with_swagger_as_local_file(self): - with tempfile.NamedTemporaryFile(mode='w', delete=False) as fp: + with tempfile.NamedTemporaryFile(mode="w", delete=False) as fp: filename = fp.name swagger = make_swagger(self.input_routes) @@ -438,13 +382,9 @@ def test_with_swagger_as_local_file(self): template = { "Resources": { - "Api1": { "Type": "AWS::Serverless::Api", - "Properties": { - "StageName": "Prod", - "DefinitionUri": filename - } + "Properties": {"StageName": "Prod", "DefinitionUri": filename}, } } } @@ -459,39 +399,37 @@ def test_with_swagger_as_both_body_and_uri_called(self, SwaggerReaderMock): template = { "Resources": { - "Api1": { "Type": "AWS::Serverless::Api", "Properties": { "StageName": "Prod", "DefinitionUri": filename, - "DefinitionBody": body - } + "DefinitionBody": body, + }, } } } - SwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_routes) + SwaggerReaderMock.return_value.read.return_value = make_swagger( + self.input_routes + ) cwd = "foo" provider = ApiProvider(template, cwd=cwd) assertCountEqual(self, self.input_routes, provider.routes) - SwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) + SwaggerReaderMock.assert_called_with( + definition_body=body, definition_uri=filename, working_dir=cwd + ) def test_swagger_with_any_method(self): - routes = [ - Route(path="/path", methods=["any"], function_name="SamFunc1") - ] + routes = [Route(path="/path", methods=["any"], function_name="SamFunc1")] expected_routes = [ - Route(path="/path", methods=["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"], - function_name="SamFunc1") + Route( + path="/path", + methods=["GET", "DELETE", "PUT", "POST", "HEAD", "OPTIONS", "PATCH"], + function_name="SamFunc1", + ) ] template = { @@ -500,8 +438,8 @@ def test_swagger_with_any_method(self): "Type": "AWS::Serverless::Api", "Properties": { "StageName": "Prod", - "DefinitionBody": make_swagger(routes) - } + "DefinitionBody": make_swagger(routes), + }, } } } @@ -512,13 +450,14 @@ def test_swagger_with_any_method(self): def test_with_binary_media_types(self): template = { "Resources": { - "Api1": { "Type": "AWS::Serverless::Api", "Properties": { "StageName": "Prod", - "DefinitionBody": make_swagger(self.input_routes, binary_media_types=self.binary_types) - } + "DefinitionBody": make_swagger( + self.input_routes, binary_media_types=self.binary_types + ), + }, } } } @@ -527,7 +466,7 @@ def test_with_binary_media_types(self): expected_routes = [ Route(path="/path1", methods=["GET", "POST"], function_name="SamFunc1"), Route(path="/path2", methods=["GET", "PUT"], function_name="SamFunc1"), - Route(path="/path3", methods=["DELETE"], function_name="SamFunc1") + Route(path="/path3", methods=["DELETE"], function_name="SamFunc1"), ] provider = ApiProvider(template) @@ -536,27 +475,28 @@ def test_with_binary_media_types(self): def test_with_binary_media_types_in_swagger_and_on_resource(self): input_routes = [ - Route(path="/path", methods=["OPTIONS"], function_name="SamFunc1"), + Route(path="/path", methods=["OPTIONS"], function_name="SamFunc1") ] extra_binary_types = ["text/html"] template = { "Resources": { - "Api1": { "Type": "AWS::Serverless::Api", "Properties": { "BinaryMediaTypes": extra_binary_types, "StageName": "Prod", - "DefinitionBody": make_swagger(input_routes, binary_media_types=self.binary_types) - } + "DefinitionBody": make_swagger( + input_routes, binary_media_types=self.binary_types + ), + }, } } } expected_binary_types = sorted(self.binary_types + extra_binary_types) expected_routes = [ - Route(path="/path", methods=["OPTIONS"], function_name="SamFunc1"), + Route(path="/path", methods=["OPTIONS"], function_name="SamFunc1") ] provider = ApiProvider(template) @@ -565,35 +505,30 @@ def test_with_binary_media_types_in_swagger_and_on_resource(self): class TestSamApiProviderWithExplicitAndImplicitApis(TestCase): - def setUp(self): self.stage_name = "Prod" self.explicit_routes = [ Route(path="/path1", methods=["GET"], function_name="explicitfunction"), Route(path="/path2", methods=["GET"], function_name="explicitfunction"), - Route(path="/path3", methods=["GET"], function_name="explicitfunction") + Route(path="/path3", methods=["GET"], function_name="explicitfunction"), ] self.swagger = make_swagger(self.explicit_routes) self.template = { "Resources": { - "Api1": { "Type": "AWS::Serverless::Api", - "Properties": { - "StageName": "Prod", - } + "Properties": {"StageName": "Prod"}, }, - "ImplicitFunc": { "Type": "AWS::Serverless::Function", "Properties": { "CodeUri": "/usr/foo/bar", "Runtime": "nodejs4.3", - "Handler": "index.handler" - } - } + "Handler": "index.handler", + }, + }, } } @@ -601,30 +536,21 @@ def test_must_union_implicit_and_explicit(self): events = { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path1", - "Method": "POST" - } + "Properties": {"Path": "/path1", "Method": "POST"}, }, - "Event2": { "Type": "Api", - "Properties": { - "Path": "/path2", - "Method": "POST" - } + "Properties": {"Path": "/path2", "Method": "POST"}, }, - "Event3": { "Type": "Api", - "Properties": { - "Path": "/path3", - "Method": "POST" - } - } + "Properties": {"Path": "/path3", "Method": "POST"}, + }, } - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = self.swagger + self.template["Resources"]["Api1"]["Properties"][ + "DefinitionBody" + ] = self.swagger self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = events expected_routes = [ @@ -635,7 +561,7 @@ def test_must_union_implicit_and_explicit(self): # From Implicit APIs Route(path="/path1", methods=["POST"], function_name="ImplicitFunc"), Route(path="/path2", methods=["POST"], function_name="ImplicitFunc"), - Route(path="/path3", methods=["POST"], function_name="ImplicitFunc") + Route(path="/path3", methods=["POST"], function_name="ImplicitFunc"), ] provider = ApiProvider(self.template) @@ -648,30 +574,28 @@ def test_must_prefer_implicit_api_over_explicit(self): "Properties": { # This API is duplicated between implicit & explicit "Path": "/path1", - "Method": "get" - } + "Method": "get", + }, }, - "Event2": { "Type": "Api", - "Properties": { - "Path": "/path2", - "Method": "POST" - } - } + "Properties": {"Path": "/path2", "Method": "POST"}, + }, } - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = self.swagger - self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_routes + self.template["Resources"]["Api1"]["Properties"][ + "DefinitionBody" + ] = self.swagger + self.template["Resources"]["ImplicitFunc"]["Properties"][ + "Events" + ] = implicit_routes expected_routes = [ Route(path="/path1", methods=["GET"], function_name="ImplicitFunc"), # Comes from Implicit - Route(path="/path2", methods=["GET"], function_name="explicitfunction"), Route(path="/path2", methods=["POST"], function_name="ImplicitFunc"), # Comes from implicit - Route(path="/path3", methods=["GET"], function_name="explicitfunction"), ] @@ -685,8 +609,8 @@ def test_must_prefer_implicit_with_any_method(self): "Properties": { # This API is duplicated between implicit & explicit "Path": "/path", - "Method": "ANY" - } + "Method": "ANY", + }, } } @@ -696,18 +620,19 @@ def test_must_prefer_implicit_with_any_method(self): Route(path="/path", methods=["DELETE"], function_name="explicitfunction"), ] - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = make_swagger(explicit_routes) - self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_routes + self.template["Resources"]["Api1"]["Properties"][ + "DefinitionBody" + ] = make_swagger(explicit_routes) + self.template["Resources"]["ImplicitFunc"]["Properties"][ + "Events" + ] = implicit_routes expected_routes = [ - Route(path="/path", methods=["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"], - function_name="ImplicitFunc") + Route( + path="/path", + methods=["GET", "DELETE", "PUT", "POST", "HEAD", "OPTIONS", "PATCH"], + function_name="ImplicitFunc", + ) ] provider = ApiProvider(self.template) @@ -720,17 +645,17 @@ def test_with_any_method_on_both(self): "Properties": { # This API is duplicated between implicit & explicit "Path": "/path", - "Method": "ANY" - } + "Method": "ANY", + }, }, "Event2": { "Type": "Api", "Properties": { # This API is duplicated between implicit & explicit "Path": "/path2", - "Method": "GET" - } - } + "Method": "GET", + }, + }, } explicit_routes = [ @@ -739,22 +664,21 @@ def test_with_any_method_on_both(self): Route(path="/path2", methods=["POST"], function_name="explicitfunction"), ] - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = make_swagger(explicit_routes) - self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_routes + self.template["Resources"]["Api1"]["Properties"][ + "DefinitionBody" + ] = make_swagger(explicit_routes) + self.template["Resources"]["ImplicitFunc"]["Properties"][ + "Events" + ] = implicit_routes expected_routes = [ - Route(path="/path", methods=["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"], - function_name="ImplicitFunc"), - - Route(path="/path2", methods=["GET"], - function_name="ImplicitFunc"), - Route(path="/path2", methods=["POST"], function_name="explicitfunction") + Route( + path="/path", + methods=["GET", "DELETE", "PUT", "POST", "HEAD", "OPTIONS", "PATCH"], + function_name="ImplicitFunc", + ), + Route(path="/path2", methods=["GET"], function_name="ImplicitFunc"), + Route(path="/path2", methods=["POST"], function_name="explicitfunction"), ] provider = ApiProvider(self.template) @@ -767,21 +691,24 @@ def test_must_add_explicit_api_when_ref_with_rest_api_id(self): "Properties": { "Path": "/newpath1", "Method": "POST", - "RestApiId": "Api1" # This path must get added to this API - } + "RestApiId": "Api1", # This path must get added to this API + }, }, - "Event2": { "Type": "Api", "Properties": { "Path": "/newpath2", "Method": "POST", - "RestApiId": {"Ref": "Api1"} # This path must get added to this API - } - } + "RestApiId": { + "Ref": "Api1" + }, # This path must get added to this API + }, + }, } - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = self.swagger + self.template["Resources"]["Api1"]["Properties"][ + "DefinitionBody" + ] = self.swagger self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = events expected_routes = [ @@ -791,7 +718,7 @@ def test_must_add_explicit_api_when_ref_with_rest_api_id(self): Route(path="/path3", methods=["GET"], function_name="explicitfunction"), # From Implicit APIs Route(path="/newpath1", methods=["POST"], function_name="ImplicitFunc"), - Route(path="/newpath2", methods=["POST"], function_name="ImplicitFunc") + Route(path="/newpath2", methods=["POST"], function_name="ImplicitFunc"), ] provider = ApiProvider(self.template) @@ -801,35 +728,36 @@ def test_both_routes_must_get_binary_media_types(self): events = { "Event1": { "Type": "Api", - "Properties": { - "Path": "/newpath1", - "Method": "POST" - } + "Properties": {"Path": "/newpath1", "Method": "POST"}, }, - "Event2": { "Type": "Api", - "Properties": { - "Path": "/newpath2", - "Method": "POST" - } - } + "Properties": {"Path": "/newpath2", "Method": "POST"}, + }, } # Binary type for implicit self.template["Globals"] = { - "Api": { - "BinaryMediaTypes": ["image~1gif", "image~1png"] - } + "Api": {"BinaryMediaTypes": ["image~1gif", "image~1png"]} } self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = events - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = self.swagger + self.template["Resources"]["Api1"]["Properties"][ + "DefinitionBody" + ] = self.swagger # Binary type for explicit - self.template["Resources"]["Api1"]["Properties"]["BinaryMediaTypes"] = ["explicit/type1", "explicit/type2"] + self.template["Resources"]["Api1"]["Properties"]["BinaryMediaTypes"] = [ + "explicit/type1", + "explicit/type2", + ] # Because of Globals, binary types will be concatenated on the explicit API - expected_explicit_binary_types = ["explicit/type1", "explicit/type2", "image/gif", "image/png"] + expected_explicit_binary_types = [ + "explicit/type1", + "explicit/type2", + "image/gif", + "image/png", + ] expected_routes = [ # From Explicit APIs @@ -838,12 +766,14 @@ def test_both_routes_must_get_binary_media_types(self): Route(path="/path3", methods=["GET"], function_name="explicitfunction"), # From Implicit APIs Route(path="/newpath1", methods=["POST"], function_name="ImplicitFunc"), - Route(path="/newpath2", methods=["POST"], function_name="ImplicitFunc") + Route(path="/newpath2", methods=["POST"], function_name="ImplicitFunc"), ] provider = ApiProvider(self.template) assertCountEqual(self, expected_routes, provider.routes) - assertCountEqual(self, provider.api.binary_media_types, expected_explicit_binary_types) + assertCountEqual( + self, provider.api.binary_media_types, expected_explicit_binary_types + ) def test_binary_media_types_with_rest_api_id_reference(self): events = { @@ -852,33 +782,37 @@ def test_binary_media_types_with_rest_api_id_reference(self): "Properties": { "Path": "/connected-to-explicit-path", "Method": "POST", - "RestApiId": "Api1" - } + "RestApiId": "Api1", + }, }, - "Event2": { "Type": "Api", - "Properties": { - "Path": "/true-implicit-path", - "Method": "POST" - } - } + "Properties": {"Path": "/true-implicit-path", "Method": "POST"}, + }, } # Binary type for implicit self.template["Globals"] = { - "Api": { - "BinaryMediaTypes": ["image~1gif", "image~1png"] - } + "Api": {"BinaryMediaTypes": ["image~1gif", "image~1png"]} } self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = events - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = self.swagger + self.template["Resources"]["Api1"]["Properties"][ + "DefinitionBody" + ] = self.swagger # Binary type for explicit - self.template["Resources"]["Api1"]["Properties"]["BinaryMediaTypes"] = ["explicit/type1", "explicit/type2"] + self.template["Resources"]["Api1"]["Properties"]["BinaryMediaTypes"] = [ + "explicit/type1", + "explicit/type2", + ] # Because of Globals, binary types will be concatenated on the explicit API - expected_explicit_binary_types = ["explicit/type1", "explicit/type2", "image/gif", "image/png"] + expected_explicit_binary_types = [ + "explicit/type1", + "explicit/type2", + "image/gif", + "image/png", + ] # expected_implicit_binary_types = ["image/gif", "image/png"] expected_routes = [ @@ -886,26 +820,32 @@ def test_binary_media_types_with_rest_api_id_reference(self): Route(path="/path1", methods=["GET"], function_name="explicitfunction"), Route(path="/path2", methods=["GET"], function_name="explicitfunction"), Route(path="/path3", methods=["GET"], function_name="explicitfunction"), - # Because of the RestApiId, Implicit APIs will also get the binary media types inherited from # the corresponding Explicit API - Route(path="/connected-to-explicit-path", methods=["POST"], function_name="ImplicitFunc"), - + Route( + path="/connected-to-explicit-path", + methods=["POST"], + function_name="ImplicitFunc", + ), # This is still just a true implicit API because it does not have RestApiId property - Route(path="/true-implicit-path", methods=["POST"], function_name="ImplicitFunc") + Route( + path="/true-implicit-path", + methods=["POST"], + function_name="ImplicitFunc", + ), ] provider = ApiProvider(self.template) assertCountEqual(self, expected_routes, provider.routes) - assertCountEqual(self, provider.api.binary_media_types, expected_explicit_binary_types) + assertCountEqual( + self, provider.api.binary_media_types, expected_explicit_binary_types + ) class TestSamStageValues(TestCase): - def test_provider_parse_stage_name(self): template = { "Resources": { - "TestApi": { "Type": "AWS::Serverless::Api", "Properties": { @@ -919,21 +859,22 @@ def test_provider_parse_stage_name(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } } - } - } - } + }, + }, } } } provider = ApiProvider(template) - route1 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') + route1 = Route( + path="/path", methods=["GET"], function_name="NoApiEventFunction" + ) self.assertIn(route1, provider.routes) self.assertEquals(provider.api.stage_name, "dev") @@ -942,16 +883,11 @@ def test_provider_parse_stage_name(self): def test_provider_stage_variables(self): template = { "Resources": { - "TestApi": { "Type": "AWS::Serverless::Api", "Properties": { "StageName": "dev", - "Variables": { - "vis": "data", - "random": "test", - "foo": "bar" - }, + "Variables": {"vis": "data", "random": "test", "foo": "bar"}, "DefinitionBody": { "paths": { "/path": { @@ -961,43 +897,37 @@ def test_provider_stage_variables(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } } - } - } - } + }, + }, } } } provider = ApiProvider(template) - route1 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') + route1 = Route( + path="/path", methods=["GET"], function_name="NoApiEventFunction" + ) self.assertIn(route1, provider.routes) self.assertEquals(provider.api.stage_name, "dev") - self.assertEquals(provider.api.stage_variables, { - "vis": "data", - "random": "test", - "foo": "bar" - }) + self.assertEquals( + provider.api.stage_variables, + {"vis": "data", "random": "test", "foo": "bar"}, + ) def test_multi_stage_get_all(self): - template = OrderedDict({ - "Resources": {} - }) + template = OrderedDict({"Resources": {}}) template["Resources"]["TestApi"] = { "Type": "AWS::Serverless::Api", "Properties": { "StageName": "dev", - "Variables": { - "vis": "data", - "random": "test", - "foo": "bar" - }, + "Variables": {"vis": "data", "random": "test", "foo": "bar"}, "DefinitionBody": { "paths": { "/path2": { @@ -1007,26 +937,22 @@ def test_multi_stage_get_all(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } } } - } - } + }, + }, } template["Resources"]["ProductionApi"] = { "Type": "AWS::Serverless::Api", "Properties": { "StageName": "Production", - "Variables": { - "vis": "prod data", - "random": "test", - "foo": "bar" - }, + "Variables": {"vis": "prod data", "random": "test", "foo": "bar"}, "DefinitionBody": { "paths": { "/path": { @@ -1036,10 +962,10 @@ def test_multi_stage_get_all(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } }, "/anotherpath": { @@ -1049,38 +975,41 @@ def test_multi_stage_get_all(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } - } - + }, } - } - } + }, + }, } - provider = ApiProvider(template) result = [f for f in provider.get_all()] routes = result[0].routes - route1 = Route(path='/path2', methods=['GET'], function_name='NoApiEventFunction') - route2 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') - route3 = Route(path='/anotherpath', methods=['POST'], function_name='NoApiEventFunction') + route1 = Route( + path="/path2", methods=["GET"], function_name="NoApiEventFunction" + ) + route2 = Route( + path="/path", methods=["GET"], function_name="NoApiEventFunction" + ) + route3 = Route( + path="/anotherpath", methods=["POST"], function_name="NoApiEventFunction" + ) self.assertEquals(len(routes), 3) self.assertIn(route1, routes) self.assertIn(route2, routes) self.assertIn(route3, routes) self.assertEquals(provider.api.stage_name, "Production") - self.assertEquals(provider.api.stage_variables, { - "vis": "prod data", - "random": "test", - "foo": "bar" - }) + self.assertEquals( + provider.api.stage_variables, + {"vis": "prod data", "random": "test", "foo": "bar"}, + ) class TestSamCors(TestCase): @@ -1100,10 +1029,10 @@ def test_provider_parse_cors_string(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } }, "/path": { @@ -1112,15 +1041,15 @@ def test_provider_parse_cors_string(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } - } + }, } - } - } + }, + }, } } } @@ -1128,15 +1057,20 @@ def test_provider_parse_cors_string(self): provider = ApiProvider(template) routes = provider.routes - cors = Cors(allow_origin="*", allow_methods=','.join(sorted(["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"]))) - route1 = Route(path='/path2', methods=['POST', 'OPTIONS'], function_name='NoApiEventFunction') - route2 = Route(path='/path', methods=['GET', 'OPTIONS'], function_name='NoApiEventFunction') + cors = Cors( + allow_origin="*", + allow_methods=",".join( + sorted(["GET", "DELETE", "PUT", "POST", "HEAD", "OPTIONS", "PATCH"]) + ), + ) + route1 = Route( + path="/path2", + methods=["POST", "OPTIONS"], + function_name="NoApiEventFunction", + ) + route2 = Route( + path="/path", methods=["GET", "OPTIONS"], function_name="NoApiEventFunction" + ) self.assertEquals(len(routes), 2) self.assertIn(route1, routes) @@ -1154,7 +1088,7 @@ def test_provider_parse_cors_dict(self): "AllowMethods": "POST, GET", "AllowOrigin": "*", "AllowHeaders": "Upgrade-Insecure-Requests", - "MaxAge": 600 + "MaxAge": 600, }, "DefinitionBody": { "paths": { @@ -1164,10 +1098,10 @@ def test_provider_parse_cors_dict(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } }, "/path": { @@ -1176,15 +1110,15 @@ def test_provider_parse_cors_dict(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } - } + }, } - } - } + }, + }, } } } @@ -1192,12 +1126,22 @@ def test_provider_parse_cors_dict(self): provider = ApiProvider(template) routes = provider.routes - cors = Cors(allow_origin="*", - allow_methods=','.join(sorted(["POST", "GET", "OPTIONS"])), - allow_headers="Upgrade-Insecure-Requests", - max_age=600) - route1 = Route(path='/path2', methods=['POST', 'OPTIONS'], function_name='NoApiEventFunction') - route2 = Route(path='/path', methods=['POST', 'OPTIONS'], function_name='NoApiEventFunction') + cors = Cors( + allow_origin="*", + allow_methods=",".join(sorted(["POST", "GET", "OPTIONS"])), + allow_headers="Upgrade-Insecure-Requests", + max_age=600, + ) + route1 = Route( + path="/path2", + methods=["POST", "OPTIONS"], + function_name="NoApiEventFunction", + ) + route2 = Route( + path="/path", + methods=["POST", "OPTIONS"], + function_name="NoApiEventFunction", + ) self.assertEquals(len(routes), 2) self.assertIn(route1, routes) @@ -1215,7 +1159,7 @@ def test_provider_parse_cors_dict_star_allow(self): "AllowMethods": "*", "AllowOrigin": "*", "AllowHeaders": "Upgrade-Insecure-Requests", - "MaxAge": 600 + "MaxAge": 600, }, "DefinitionBody": { "paths": { @@ -1225,10 +1169,10 @@ def test_provider_parse_cors_dict_star_allow(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } }, "/path": { @@ -1237,15 +1181,15 @@ def test_provider_parse_cors_dict_star_allow(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } - } + }, } - } - } + }, + }, } } } @@ -1253,12 +1197,22 @@ def test_provider_parse_cors_dict_star_allow(self): provider = ApiProvider(template) routes = provider.routes - cors = Cors(allow_origin="*", - allow_methods=','.join(sorted(Route.ANY_HTTP_METHODS)), - allow_headers="Upgrade-Insecure-Requests", - max_age=600) - route1 = Route(path='/path2', methods=['POST', 'OPTIONS'], function_name='NoApiEventFunction') - route2 = Route(path='/path', methods=['POST', 'OPTIONS'], function_name='NoApiEventFunction') + cors = Cors( + allow_origin="*", + allow_methods=",".join(sorted(Route.ANY_HTTP_METHODS)), + allow_headers="Upgrade-Insecure-Requests", + max_age=600, + ) + route1 = Route( + path="/path2", + methods=["POST", "OPTIONS"], + function_name="NoApiEventFunction", + ) + route2 = Route( + path="/path", + methods=["POST", "OPTIONS"], + function_name="NoApiEventFunction", + ) self.assertEquals(len(routes), 2) self.assertIn(route1, routes) @@ -1276,7 +1230,7 @@ def test_invalid_cors_dict_allow_methods(self): "AllowMethods": "GET, INVALID_METHOD", "AllowOrigin": "*", "AllowHeaders": "Upgrade-Insecure-Requests", - "MaxAge": 600 + "MaxAge": 600, }, "DefinitionBody": { "paths": { @@ -1286,10 +1240,10 @@ def test_invalid_cors_dict_allow_methods(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } }, "/path": { @@ -1298,20 +1252,22 @@ def test_invalid_cors_dict_allow_methods(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } - } + }, } - } - } + }, + }, } } } - with self.assertRaises(InvalidSamDocumentException, - msg="ApiProvider should fail for Invalid Cors Allow method"): + with self.assertRaises( + InvalidSamDocumentException, + msg="ApiProvider should fail for Invalid Cors Allow method", + ): ApiProvider(template) def test_default_cors_dict_prop(self): @@ -1321,9 +1277,7 @@ def test_default_cors_dict_prop(self): "Type": "AWS::Serverless::Api", "Properties": { "StageName": "Prod", - "Cors": { - "AllowOrigin": "www.domain.com", - }, + "Cors": {"AllowOrigin": "www.domain.com"}, "DefinitionBody": { "paths": { "/path2": { @@ -1333,15 +1287,15 @@ def test_default_cors_dict_prop(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } } } - } - } + }, + }, } } } @@ -1349,8 +1303,15 @@ def test_default_cors_dict_prop(self): provider = ApiProvider(template) routes = provider.routes - cors = Cors(allow_origin="www.domain.com", allow_methods=','.join(sorted(Route.ANY_HTTP_METHODS))) - route1 = Route(path='/path2', methods=['GET', 'OPTIONS'], function_name='NoApiEventFunction') + cors = Cors( + allow_origin="www.domain.com", + allow_methods=",".join(sorted(Route.ANY_HTTP_METHODS)), + ) + route1 = Route( + path="/path2", + methods=["GET", "OPTIONS"], + function_name="NoApiEventFunction", + ) self.assertEquals(len(routes), 1) self.assertIn(route1, routes) self.assertEquals(provider.api.cors, cors) @@ -1363,8 +1324,8 @@ def test_global_cors(self): "AllowMethods": "GET", "AllowOrigin": "*", "AllowHeaders": "Upgrade-Insecure-Requests", - "MaxAge": 600 - }, + "MaxAge": 600, + } } }, "Resources": { @@ -1380,10 +1341,10 @@ def test_global_cors(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } }, "/path": { @@ -1392,29 +1353,36 @@ def test_global_cors(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } - } + }, } - } - - } + }, + }, } - } + }, } provider = ApiProvider(template) routes = provider.routes - cors = Cors(allow_origin="*", - allow_headers="Upgrade-Insecure-Requests", - allow_methods=','.join(["GET", "OPTIONS"]), - max_age=600) - route1 = Route(path='/path2', methods=['GET', 'OPTIONS'], function_name='NoApiEventFunction') - route2 = Route(path='/path', methods=['GET', 'OPTIONS'], function_name='NoApiEventFunction') + cors = Cors( + allow_origin="*", + allow_headers="Upgrade-Insecure-Requests", + allow_methods=",".join(["GET", "OPTIONS"]), + max_age=600, + ) + route1 = Route( + path="/path2", + methods=["GET", "OPTIONS"], + function_name="NoApiEventFunction", + ) + route2 = Route( + path="/path", methods=["GET", "OPTIONS"], function_name="NoApiEventFunction" + ) self.assertEquals(len(routes), 2) self.assertIn(route1, routes) @@ -1437,10 +1405,7 @@ def make_swagger(routes, binary_media_types=None): Swagger document """ - swagger = { - "paths": { - } - } + swagger = {"paths": {}} for api in routes: swagger["paths"].setdefault(api.path, {}) @@ -1449,8 +1414,9 @@ def make_swagger(routes, binary_media_types=None): "x-amazon-apigateway-integration": { "type": "aws_proxy", "uri": "arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/arn:aws:lambda:us-east-1" - ":123456789012:function:{}/invocations".format( - api.function_name) # NOQA + ":123456789012:function:{}/invocations".format( + api.function_name + ), # NOQA } } for method in api.methods: