Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat) Cors Support #1242

Merged
merged 39 commits into from
Aug 12, 2019
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
14c51e3
Start adding cors
viksrivat Jun 26, 2019
64ba5d4
Merge branch 'develop' of https://github.com/awslabs/aws-sam-cli into…
viksrivat Jun 26, 2019
d376c4f
Reorganize cors header
viksrivat Jun 26, 2019
7de2882
Update cors
viksrivat Jun 26, 2019
3c58a95
Add cors tests
viksrivat Jun 26, 2019
0094f29
Update tests to check make pr passes
viksrivat Jun 26, 2019
d2fba5f
Fix headers so that it is returned with post/get requests
viksrivat Jun 27, 2019
7d9f270
Cleanup Tests and style
viksrivat Jun 27, 2019
087c152
Update code with style comments
viksrivat Jun 27, 2019
21f7683
Merge branch 'develop' into feature/cors_support
viksrivat Jul 16, 2019
0f50751
Run make pr and Fix merge errors
viksrivat Jul 16, 2019
aa04f5e
Fix Merge Issue with ApiGateway RestApi
viksrivat Jul 16, 2019
24d16ed
Cleanup Cors class in provider
viksrivat Jul 18, 2019
49d16b4
feat(start-api): CloudFormation AWS::ApiGateway::RestApi support (#1238)
viksrivat Jul 23, 2019
aeba546
feat(start-api): CloudFormation AWS::ApiGateway::Stage Support (#1239)
viksrivat Jul 26, 2019
f0d64c0
Merge to cfn branch and update with comments
viksrivat Jul 26, 2019
fa5dad3
Update cors tests
viksrivat Jul 26, 2019
5853a33
Update cors tests
viksrivat Jul 26, 2019
d0c93e3
Merge branch 'feature/cors_support' of github.com:viksrivat/aws-sam-c…
viksrivat Jul 26, 2019
5b89f27
update test
viksrivat Jul 26, 2019
41da0eb
Update cors with comments
viksrivat Jul 26, 2019
161de22
Fix rebase error
viksrivat Jul 29, 2019
e14e84f
Remove multi value headers
viksrivat Jul 29, 2019
a4a8ee6
Update cors allow methods
viksrivat Jul 29, 2019
e7ab5f2
Update cors allow methods tests
viksrivat Jul 29, 2019
107d2da
Update cors integ test
viksrivat Jul 29, 2019
49ae556
Update * Allow Methods
viksrivat Jul 30, 2019
b7292fe
Update * Allow Methods
viksrivat Jul 30, 2019
1b1fc73
Update start_api import
viksrivat Jul 30, 2019
6873f54
Update tests to pass
viksrivat Jul 31, 2019
9ed5224
Update bad unit test
viksrivat Jul 31, 2019
9f820f9
Trigger
viksrivat Aug 1, 2019
66662f4
Merge branch 'develop' into feature/cors_support
jfuss Aug 7, 2019
7354760
Merge branch 'develop' into feature/cors_support
viksrivat Aug 9, 2019
a355959
Update Style for cors pr
viksrivat Aug 9, 2019
6e0d664
Merge branch 'develop' into feature/cors_support
jfuss Aug 9, 2019
c7d4563
Merge branch 'develop' into feature/cors_support
jfuss Aug 10, 2019
3718fcf
Fix style for flake8
viksrivat Aug 12, 2019
4cdd407
Merge branch 'develop' into feature/cors_support
viksrivat Aug 12, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions samcli/commands/local/lib/api_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ def add_stage_variables(self, logical_id, stage_variables):
properties = properties._replace(stage_variables=stage_variables)
self._set_properties(logical_id, properties)

def add_cors(self, logical_id, cors):
properties = self._get_properties(logical_id)
properties = properties._replace(cors=cors)
self._set_properties(logical_id, properties)

def _get_apis_with_config(self, logical_id):
"""
Returns the list of APIs in this resource along with other extra configuration such as binary media types,
Expand Down
2 changes: 1 addition & 1 deletion samcli/commands/local/lib/local_api_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _make_routing_list(api_provider):
for api in api_provider.get_all():
route = Route(methods=[api.method], function_name=api.function_name, path=api.path,
binary_types=api.binary_media_types, stage_name=api.stage_name,
stage_variables=api.stage_variables)
stage_variables=api.stage_variables, cors=api.cors)
routes.append(route)
return routes

Expand Down
40 changes: 26 additions & 14 deletions samcli/commands/local/lib/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,20 +235,30 @@ def __hash__(self):
return hash(self.path) * hash(self.method) * hash(self.function_name)


Cors = namedtuple("Cors", ["AllowOrigin", "AllowMethods", "AllowHeaders"])
_CorsTuple = namedtuple("Cors", ["allow_origin", "allow_methods", "allow_headers", "max_age"])

_CorsTuple.__new__.__defaults__ = (None, # Allow Origin defaults to None
None, # Allow Methods is optional and defaults to empty
None, # Allow Headers is optional and defaults to empty
None # MaxAge is optional and defaults to empty
)


class Cors(_CorsTuple):
pass


class AbstractApiProvider(object):
"""
Abstract base class to return APIs and the functions they route to
"""
_ANY_HTTP_METHODS = ["GET",
"DELETE",
"PUT",
"POST",
"HEAD",
"OPTIONS",
"PATCH"]
ANY_HTTP_METHODS = ["GET",
"DELETE",
"PUT",
"POST",
"HEAD",
"OPTIONS",
"PATCH"]

def get_all(self):
"""
Expand All @@ -259,21 +269,23 @@ def get_all(self):
raise NotImplementedError("not implemented")

@staticmethod
def normalize_http_methods(http_method):
def normalize_http_methods(api):
viksrivat marked this conversation as resolved.
Show resolved Hide resolved
"""
Normalizes Http Methods. Api Gateway allows a Http Methods of ANY. This is a special verb to denote all
supported Http Methods on Api Gateway.

:param str http_method: Http method
:param api api: Api
:yield str: Either the input http_method or one of the _ANY_HTTP_METHODS (normalized Http Methods)
"""

http_method = api.method
if http_method.upper() == 'ANY':
for method in AbstractApiProvider._ANY_HTTP_METHODS:
for method in AbstractApiProvider.ANY_HTTP_METHODS:
yield method.upper()
else:
yield http_method.upper()

if api.cors and http_method.upper() != "OPTIONS":
yield "OPTIONS"

@staticmethod
def normalize_apis(apis):
"""
Expand All @@ -292,7 +304,7 @@ def normalize_apis(apis):

result = list()
for api in apis:
for normalized_method in AbstractApiProvider.normalize_http_methods(api.method):
for normalized_method in AbstractApiProvider.normalize_http_methods(api):
# _replace returns a copy of the namedtuple. This is the official way of creating copies of namedtuple
result.append(api._replace(method=normalized_method))

Expand Down
44 changes: 41 additions & 3 deletions samcli/commands/local/lib/sam_api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

import logging

from six import string_types

from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider
from samcli.commands.local.lib.provider import Api, AbstractApiProvider
from samcli.commands.local.lib.provider import Cors
from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException
from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -75,9 +78,9 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd=
body = properties.get("DefinitionBody")
uri = properties.get("DefinitionUri")
binary_media = properties.get("BinaryMediaTypes", [])
cors = self._extract_cors(properties)
stage_name = properties.get("StageName")
stage_variables = properties.get("Variables")

if not body and not uri:
# Swagger is not found anywhere.
LOG.debug("Skipping resource '%s'. Swagger document not found in DefinitionBody and DefinitionUri",
Expand All @@ -86,6 +89,41 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd=
self.extract_swagger_api(logical_id, body, uri, binary_media, collector, cwd)
collector.add_stage_name(logical_id, stage_name)
collector.add_stage_variables(logical_id, stage_variables)
if cors:
viksrivat marked this conversation as resolved.
Show resolved Hide resolved
collector.add_cors(logical_id, cors)

def _extract_cors(self, properties):
"""
Extract Cors property from AWS::Serverless::Api resource by reading and parsing Swagger documents. The result
is added to the Api.

Parameters
----------
properties : dict
Resource properties
"""
cors_prop = properties.get("Cors")
viksrivat marked this conversation as resolved.
Show resolved Hide resolved
cors = None
if cors_prop and isinstance(cors_prop, dict):
allow_methods = cors_prop.get("AllowMethods", ','.join(AbstractApiProvider.ANY_HTTP_METHODS))

if allow_methods and "OPTIONS" not in allow_methods:
viksrivat marked this conversation as resolved.
Show resolved Hide resolved
viksrivat marked this conversation as resolved.
Show resolved Hide resolved
allow_methods += ",OPTIONS"

cors = Cors(
allow_origin=cors_prop.get("AllowOrigin"),
allow_methods=allow_methods,
allow_headers=cors_prop.get("AllowHeaders"),
max_age=cors_prop.get("MaxAge")
)
elif cors_prop and isinstance(cors_prop, string_types):
viksrivat marked this conversation as resolved.
Show resolved Hide resolved
cors = Cors(
allow_origin=cors_prop,
allow_methods=','.join(AbstractApiProvider.ANY_HTTP_METHODS),
allow_headers=None,
max_age=None
)
return cors

def _extract_apis_from_function(self, logical_id, function_resource, collector):
"""
Expand Down Expand Up @@ -202,7 +240,7 @@ def merge_apis(collector):
for config in all_configs:
# Normalize the methods before de-duping to allow an ANY method in implicit API to override a regular HTTP
# method on explicit API.
for normalized_method in AbstractApiProvider.normalize_http_methods(config.method):
for normalized_method in AbstractApiProvider.normalize_http_methods(config):
key = config.path + normalized_method
all_apis[key] = config

Expand Down
48 changes: 42 additions & 6 deletions samcli/local/apigw/local_apigw_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

class Route(object):

def __init__(self, methods, function_name, path, binary_types=None, stage_name=None, stage_variables=None):
def __init__(self, methods, function_name, path, binary_types=None, stage_name=None, stage_variables=None,
cors=None):
"""
Creates an ApiGatewayRoute

Expand All @@ -33,6 +34,7 @@ def __init__(self, methods, function_name, path, binary_types=None, stage_name=N
self.binary_types = binary_types or []
self.stage_name = stage_name
self.stage_variables = stage_variables
self.cors = cors


class LocalApigwService(BaseLocalService):
Expand Down Expand Up @@ -142,10 +144,15 @@ def _request_handler(self, **kwargs):
Response object
"""
route = self._get_current_route(request)
cors_headers = LocalApigwService.cors_to_headers(route.cors)
viksrivat marked this conversation as resolved.
Show resolved Hide resolved

if 'OPTIONS' in route.methods:
viksrivat marked this conversation as resolved.
Show resolved Hide resolved
headers = Headers(cors_headers)
return self.service_response('', headers, 200)

try:
event = self._construct_event(request, self.port, route.binary_types, route.stage_name,
route.stage_variables)
route.stage_variables, cors_headers)
except UnicodeDecodeError:
return ServiceErrorResponses.lambda_failure_response()

Expand All @@ -167,6 +174,8 @@ def _request_handler(self, **kwargs):
(status_code, headers, body) = self._parse_lambda_output(lambda_response,
route.binary_types,
request)
if cors_headers:
headers.extend(cors_headers)
except (KeyError, TypeError, ValueError):
LOG.error("Function returned an invalid response (must include one of: body, headers, multiValueHeaders or "
"statusCode in the response object). Response received: %s", lambda_response)
Expand Down Expand Up @@ -315,7 +324,7 @@ def _merge_response_headers(headers, multi_headers):
return processed_headers

@staticmethod
def _construct_event(flask_request, port, binary_types, stage_name=None, stage_variables=None):
def _construct_event(flask_request, port, binary_types, stage_name=None, stage_variables=None, cors_headers=None):
"""
Helper method that constructs the Event to be passed to Lambda

Expand Down Expand Up @@ -349,7 +358,7 @@ def _construct_event(flask_request, port, binary_types, stage_name=None, stage_v
identity=identity,
path=endpoint)

headers_dict, multi_value_headers_dict = LocalApigwService._event_headers(flask_request, port)
headers_dict, multi_value_headers_dict = LocalApigwService._event_headers(flask_request, port, cors_headers)

query_string_dict, multi_value_query_string_dict = LocalApigwService._query_string_params(flask_request)

Expand Down Expand Up @@ -404,7 +413,7 @@ def _query_string_params(flask_request):
return query_string_dict, multi_value_query_string_dict

@staticmethod
def _event_headers(flask_request, port):
def _event_headers(flask_request, port, cors_headers=None):
viksrivat marked this conversation as resolved.
Show resolved Hide resolved
"""
Constructs an APIGW equivalent headers dictionary

Expand All @@ -414,6 +423,8 @@ def _event_headers(flask_request, port):
Request from Flask
int port
Forwarded Port
cors_headers dict
Dict of the Cors properties

Returns dict (str: str), dict (str: list of str)
-------
Expand All @@ -434,7 +445,9 @@ def _event_headers(flask_request, port):

headers_dict["X-Forwarded-Port"] = str(port)
multi_value_headers_dict["X-Forwarded-Port"] = [str(port)]

if cors_headers:
headers_dict.update(cors_headers)
multi_value_headers_dict.update(cors_headers)
viksrivat marked this conversation as resolved.
Show resolved Hide resolved
return headers_dict, multi_value_headers_dict

@staticmethod
Expand All @@ -455,3 +468,26 @@ def _should_base64_encode(binary_types, request_mimetype):

"""
return request_mimetype in binary_types or "*/*" in binary_types

@staticmethod
def cors_to_headers(cors):
"""
Convert CORS object to headers dictionary
Parameters
----------
cors list(samcli.commands.local.lib.provider.Cors)
CORS configuration objcet
Returns
-------
Dictionary with CORS headers
"""
if not cors:
return {}
headers = {
viksrivat marked this conversation as resolved.
Show resolved Hide resolved
'Access-Control-Allow-Origin': cors.allow_origin,
'Access-Control-Allow-Methods': cors.allow_methods,
'Access-Control-Allow-Headers': cors.allow_headers,
'Access-Control-Max-Age': cors.max_age
}

return {h_key: h_value for h_key, h_value in headers.items() if h_value is not None}
viksrivat marked this conversation as resolved.
Show resolved Hide resolved
79 changes: 79 additions & 0 deletions tests/integration/local/start_api/test_start_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,3 +656,82 @@ def test_swagger_stage_variable(self):

response_data = response.json()
self.assertEquals(response_data.get("stageVariables"), {'VarName': 'varValue'})


class TestServiceCorsSwaggerRequests(StartApiIntegBaseClass):
"""
Test to check that the correct headers are being added with Cors with swagger code
"""
template_path = "/testdata/start_api/swagger-template.yaml"
binary_data_file = "testdata/start_api/binarydata.gif"

def setUp(self):
self.url = "http://127.0.0.1:{}".format(self.port)

def test_cors_swagger_options(self):
"""
This tests that the Cors are added to option requests in the swagger template
"""
response = requests.options(self.url + '/echobase64eventbody')

self.assertEquals(response.status_code, 200)

self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*")
self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), "origin, x-requested-with")
self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), "GET,OPTIONS")
self.assertEquals(response.headers.get("Access-Control-Max-Age"), '510')

def test_cors_swagger_post(self):
"""
This tests that the Cors are added to post requests in the swagger template
"""
input_data = self.get_binary_data(self.binary_data_file)
response = requests.post(self.url + '/echobase64eventbody',
headers={"Content-Type": "image/gif"},
data=input_data)

self.assertEquals(response.status_code, 200)
self.assertEquals(response.headers.get("Content-Type"), "image/gif")
self.assertEquals(response.content, input_data)
self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*")
self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), "origin, x-requested-with")
self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), "GET,OPTIONS")
self.assertEquals(response.headers.get("Access-Control-Max-Age"), '510')


class TestServiceCorsGlobalRequests(StartApiIntegBaseClass):
"""
Test to check that the correct headers are being added with Cors with the global property
"""
template_path = "/testdata/start_api/template.yaml"

def setUp(self):
self.url = "http://127.0.0.1:{}".format(self.port)

def test_cors_global(self):
"""
This tests that the Cors are added to options requests when the global property is set
"""
response = requests.options(self.url + '/echobase64eventbody')

self.assertEquals(response.status_code, 200)
self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*")
self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), None)
self.assertEquals(response.headers.get("Access-Control-Allow-Methods"),
"GET,DELETE,PUT,POST,HEAD,OPTIONS,PATCH")
self.assertEquals(response.headers.get("Access-Control-Max-Age"), None)

def test_cors_global_get(self):
"""
This tests that the Cors are added to post requests when the global property is set
"""
response = requests.get(self.url + "/onlysetstatuscode")

self.assertEquals(response.status_code, 200)
self.assertEquals(response.content.decode('utf-8'), "no data")
self.assertEquals(response.headers.get("Content-Type"), "application/json")
self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*")
self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), None)
self.assertEquals(response.headers.get("Access-Control-Allow-Methods"),
"GET,DELETE,PUT,POST,HEAD,OPTIONS,PATCH")
self.assertEquals(response.headers.get("Access-Control-Max-Age"), None)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
AWSTemplateFormatVersion : '2010-09-09'
AWSTemplateFormatVersion: '2010-09-09'
Transform: AWS::Serverless-2016-10-31

Globals:
Expand All @@ -14,6 +14,11 @@ Resources:
StageName: dev
Variables:
VarName: varValue
Cors:
AllowOrigin: "*"
AllowMethods: "GET"
AllowHeaders: "origin, x-requested-with"
MaxAge: 510
DefinitionBody:
swagger: "2.0"
info:
Expand Down
Loading