Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat) Cors Support #1242

Merged
merged 39 commits into from
Aug 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
14c51e3
Start adding cors
viksrivat Jun 26, 2019
64ba5d4
Merge branch 'develop' of https://github.com/awslabs/aws-sam-cli into…
viksrivat Jun 26, 2019
d376c4f
Reorganize cors header
viksrivat Jun 26, 2019
7de2882
Update cors
viksrivat Jun 26, 2019
3c58a95
Add cors tests
viksrivat Jun 26, 2019
0094f29
Update tests to check make pr passes
viksrivat Jun 26, 2019
d2fba5f
Fix headers so that it is returned with post/get requests
viksrivat Jun 27, 2019
7d9f270
Cleanup Tests and style
viksrivat Jun 27, 2019
087c152
Update code with style comments
viksrivat Jun 27, 2019
21f7683
Merge branch 'develop' into feature/cors_support
viksrivat Jul 16, 2019
0f50751
Run make pr and Fix merge errors
viksrivat Jul 16, 2019
aa04f5e
Fix Merge Issue with ApiGateway RestApi
viksrivat Jul 16, 2019
24d16ed
Cleanup Cors class in provider
viksrivat Jul 18, 2019
49d16b4
feat(start-api): CloudFormation AWS::ApiGateway::RestApi support (#1238)
viksrivat Jul 23, 2019
aeba546
feat(start-api): CloudFormation AWS::ApiGateway::Stage Support (#1239)
viksrivat Jul 26, 2019
f0d64c0
Merge to cfn branch and update with comments
viksrivat Jul 26, 2019
fa5dad3
Update cors tests
viksrivat Jul 26, 2019
5853a33
Update cors tests
viksrivat Jul 26, 2019
d0c93e3
Merge branch 'feature/cors_support' of github.com:viksrivat/aws-sam-c…
viksrivat Jul 26, 2019
5b89f27
update test
viksrivat Jul 26, 2019
41da0eb
Update cors with comments
viksrivat Jul 26, 2019
161de22
Fix rebase error
viksrivat Jul 29, 2019
e14e84f
Remove multi value headers
viksrivat Jul 29, 2019
a4a8ee6
Update cors allow methods
viksrivat Jul 29, 2019
e7ab5f2
Update cors allow methods tests
viksrivat Jul 29, 2019
107d2da
Update cors integ test
viksrivat Jul 29, 2019
49ae556
Update * Allow Methods
viksrivat Jul 30, 2019
b7292fe
Update * Allow Methods
viksrivat Jul 30, 2019
1b1fc73
Update start_api import
viksrivat Jul 30, 2019
6873f54
Update tests to pass
viksrivat Jul 31, 2019
9ed5224
Update bad unit test
viksrivat Jul 31, 2019
9f820f9
Trigger
viksrivat Aug 1, 2019
66662f4
Merge branch 'develop' into feature/cors_support
jfuss Aug 7, 2019
7354760
Merge branch 'develop' into feature/cors_support
viksrivat Aug 9, 2019
a355959
Update Style for cors pr
viksrivat Aug 9, 2019
6e0d664
Merge branch 'develop' into feature/cors_support
jfuss Aug 9, 2019
c7d4563
Merge branch 'develop' into feature/cors_support
jfuss Aug 10, 2019
3718fcf
Fix style for flake8
viksrivat Aug 12, 2019
4cdd407
Merge branch 'develop' into feature/cors_support
viksrivat Aug 12, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion samcli/commands/local/lib/api_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self):
self.binary_media_types_set = set()
self.stage_name = None
self.stage_variables = None
self.cors = None

def __iter__(self):
"""
Expand Down Expand Up @@ -103,12 +104,40 @@ def get_api(self):
An Api object with all the properties
"""
api = Api()
api.routes = self.dedupe_function_routes(self.routes)
routes = self.dedupe_function_routes(self.routes)
routes = self.normalize_cors_methods(routes, self.cors)
api.routes = routes
api.binary_media_types_set = self.binary_media_types_set
api.stage_name = self.stage_name
api.stage_variables = self.stage_variables
api.cors = self.cors
return api

@staticmethod
def normalize_cors_methods(routes, cors):
"""
Adds OPTIONS method to all the route methods if cors exists

Parameters
-----------
routes: list(samcli.local.apigw.local_apigw_service.Route)
List of Routes

cors: samcli.commands.local.lib.provider.Cors
the cors object for the api

Return
-------
A list of routes without duplicate routes with the same function_name and method
"""

def add_options_to_route(route):
if "OPTIONS" not in route.methods:
route.methods.append("OPTIONS")
return route

return routes if not cors else [add_options_to_route(route) for route in routes]

@staticmethod
def dedupe_function_routes(routes):
"""
Expand Down
36 changes: 35 additions & 1 deletion samcli/commands/local/lib/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,41 @@ def binary_media_types(self):
return list(self.binary_media_types_set)


Cors = namedtuple("Cors", ["AllowOrigin", "AllowMethods", "AllowHeaders"])
_CorsTuple = namedtuple("Cors", ["allow_origin", "allow_methods", "allow_headers", "max_age"])


_CorsTuple.__new__.__defaults__ = (None, # Allow Origin defaults to None
None, # Allow Methods is optional and defaults to empty
None, # Allow Headers is optional and defaults to empty
None # MaxAge is optional and defaults to empty
)


class Cors(_CorsTuple):

@staticmethod
def cors_to_headers(cors):
"""
Convert CORS object to headers dictionary
Parameters
----------
cors list(samcli.commands.local.lib.provider.Cors)
CORS configuration objcet
Returns
-------
Dictionary with CORS headers
"""
if not cors:
return {}
headers = {
'Access-Control-Allow-Origin': cors.allow_origin,
'Access-Control-Allow-Methods': cors.allow_methods,
'Access-Control-Allow-Headers': cors.allow_headers,
'Access-Control-Max-Age': cors.max_age
}
# Filters out items in the headers dictionary that isn't empty.
# This is required because the flask Headers dict will send an invalid 'None' string
return {h_key: h_value for h_key, h_value in headers.items() if h_value is not None}


class AbstractApiProvider(object):
Expand Down
66 changes: 64 additions & 2 deletions samcli/commands/local/lib/sam_api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import logging

from six import string_types

from samcli.commands.local.lib.provider import Cors
from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider
from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException
from samcli.local.apigw.local_apigw_service import Route
Expand Down Expand Up @@ -77,9 +80,9 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd=
body = properties.get("DefinitionBody")
uri = properties.get("DefinitionUri")
binary_media = properties.get("BinaryMediaTypes", [])
cors = self.extract_cors(properties.get("Cors", {}))
stage_name = properties.get("StageName")
stage_variables = properties.get("Variables")

if not body and not uri:
# Swagger is not found anywhere.
LOG.debug("Skipping resource '%s'. Swagger document not found in DefinitionBody and DefinitionUri",
Expand All @@ -88,6 +91,65 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd=
self.extract_swagger_route(logical_id, body, uri, binary_media, collector, cwd=cwd)
collector.stage_name = stage_name
collector.stage_variables = stage_variables
collector.cors = cors

def extract_cors(self, cors_prop):
"""
Extract Cors property from AWS::Serverless::Api resource by reading and parsing Swagger documents. The result
is added to the Api.

Parameters
----------
cors_prop : dict
Resource properties for Cors
"""
cors = None
if cors_prop and isinstance(cors_prop, dict):
allow_methods = cors_prop.get("AllowMethods", ','.join(sorted(Route.ANY_HTTP_METHODS)))
allow_methods = self.normalize_cors_allow_methods(allow_methods)
cors = Cors(
allow_origin=cors_prop.get("AllowOrigin"),
allow_methods=allow_methods,
allow_headers=cors_prop.get("AllowHeaders"),
max_age=cors_prop.get("MaxAge")
)
elif cors_prop and isinstance(cors_prop, string_types):
viksrivat marked this conversation as resolved.
Show resolved Hide resolved
cors = Cors(
allow_origin=cors_prop,
allow_methods=','.join(sorted(Route.ANY_HTTP_METHODS)),
allow_headers=None,
max_age=None
)
return cors

@staticmethod
def normalize_cors_allow_methods(allow_methods):
"""
Normalize cors AllowMethods and Options to the methods if it's missing.

Parameters
----------
allow_methods : str
The allow_methods string provided in the query

Return
-------
A string with normalized route
"""
if allow_methods == "*":
return ','.join(sorted(Route.ANY_HTTP_METHODS))
methods = allow_methods.split(",")
normalized_methods = []
for method in methods:
normalized_method = method.strip().upper()
if normalized_method not in Route.ANY_HTTP_METHODS:
raise InvalidSamDocumentException("The method {} is not a valid CORS method".format(normalized_method))
normalized_methods.append(normalized_method)

if "OPTIONS" not in normalized_methods:
normalized_methods.append("OPTIONS")

return ','.join(sorted(normalized_methods))

def _extract_routes_from_function(self, logical_id, function_resource, collector):
"""
Expand All @@ -96,7 +158,7 @@ def _extract_routes_from_function(self, logical_id, function_resource, collector
Parameters
----------
logical_id : str
Logical ID of the resource
Logical ID of the resourc

function_resource : dict
Contents of the function resource including its properties
Expand Down
23 changes: 20 additions & 3 deletions samcli/local/apigw/local_apigw_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from flask import Flask, request
from werkzeug.datastructures import Headers

from samcli.commands.local.lib.provider import Cors
from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser
from samcli.lib.utils.stream_writer import StreamWriter
from samcli.local.lambdafn.exceptions import FunctionNotFound
Expand Down Expand Up @@ -170,6 +171,12 @@ def _request_handler(self, **kwargs):
"""

route = self._get_current_route(request)
cors_headers = Cors.cors_to_headers(self.api.cors)

method, _ = self.get_request_methods_endpoints(request)
if method == 'OPTIONS':
headers = Headers(cors_headers)
return self.service_response('', headers, 200)

try:
event = self._construct_event(request, self.port, self.api.binary_media_types, self.api.stage_name,
Expand Down Expand Up @@ -209,8 +216,7 @@ def _get_current_route(self, flask_request):
:param request flask_request: Flask Request
:return: Route matching the endpoint and method of the request
"""
endpoint = flask_request.endpoint
method = flask_request.method
method, endpoint = self.get_request_methods_endpoints(flask_request)

route_key = self._route_key(method, endpoint)
route = self._dict_of_routes.get(route_key, None)
Expand All @@ -223,6 +229,16 @@ def _get_current_route(self, flask_request):

return route

def get_request_methods_endpoints(self, flask_request):
"""
Separated out for testing requests in request handler
:param request flask_request: Flask Request
:return: the request's endpoint and method
"""
endpoint = flask_request.endpoint
method = flask_request.method
return method, endpoint

# Consider moving this out to its own class. Logic is started to get dense and looks messy @jfuss
@staticmethod
def _parse_lambda_output(lambda_output, binary_types, flask_request):
Expand Down Expand Up @@ -451,6 +467,8 @@ def _event_headers(flask_request, port):
Request from Flask
int port
Forwarded Port
cors_headers dict
Dict of the Cors properties

Returns dict (str: str), dict (str: list of str)
-------
Expand All @@ -471,7 +489,6 @@ def _event_headers(flask_request, port):

headers_dict["X-Forwarded-Port"] = str(port)
multi_value_headers_dict["X-Forwarded-Port"] = [str(port)]

return headers_dict, multi_value_headers_dict

@staticmethod
Expand Down
62 changes: 62 additions & 0 deletions tests/integration/local/start_api/test_start_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from time import time

from samcli.local.apigw.local_apigw_service import Route
from .start_api_integ_base import StartApiIntegBaseClass


Expand Down Expand Up @@ -664,6 +665,67 @@ def test_swagger_stage_variable(self):
self.assertEquals(response_data.get("stageVariables"), {'VarName': 'varValue'})


class TestServiceCorsSwaggerRequests(StartApiIntegBaseClass):
"""
Test to check that the correct headers are being added with Cors with swagger code
"""
template_path = "/testdata/start_api/swagger-template.yaml"
binary_data_file = "testdata/start_api/binarydata.gif"

def setUp(self):
self.url = "http://127.0.0.1:{}".format(self.port)

def test_cors_swagger_options(self):
"""
This tests that the Cors are added to option requests in the swagger template
"""
response = requests.options(self.url + '/echobase64eventbody')

self.assertEquals(response.status_code, 200)

self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*")
self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), "origin, x-requested-with")
self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), "GET,OPTIONS")
self.assertEquals(response.headers.get("Access-Control-Max-Age"), '510')


class TestServiceCorsGlobalRequests(StartApiIntegBaseClass):
"""
Test to check that the correct headers are being added with Cors with the global property
"""
template_path = "/testdata/start_api/template.yaml"

def setUp(self):
self.url = "http://127.0.0.1:{}".format(self.port)

def test_cors_global(self):
"""
This tests that the Cors are added to options requests when the global property is set
"""
response = requests.options(self.url + '/echobase64eventbody')

self.assertEquals(response.status_code, 200)
self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*")
self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), None)
self.assertEquals(response.headers.get("Access-Control-Allow-Methods"),
','.join(sorted(Route.ANY_HTTP_METHODS)))
self.assertEquals(response.headers.get("Access-Control-Max-Age"), None)

def test_cors_global_get(self):
"""
This tests that the Cors are added to post requests when the global property is set
"""
response = requests.get(self.url + "/onlysetstatuscode")

self.assertEquals(response.status_code, 200)
self.assertEquals(response.content.decode('utf-8'), "no data")
self.assertEquals(response.headers.get("Content-Type"), "application/json")
self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), None)
self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), None)
self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), None)
self.assertEquals(response.headers.get("Access-Control-Max-Age"), None)


class TestStartApiWithCloudFormationStage(StartApiIntegBaseClass):
"""
Test Class centered around the different responses that can happen in Lambda and pass through start-api
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
AWSTemplateFormatVersion : '2010-09-09'
AWSTemplateFormatVersion: '2010-09-09'
Transform: AWS::Serverless-2016-10-31

Globals:
Expand All @@ -14,6 +14,11 @@ Resources:
StageName: dev
Variables:
VarName: varValue
Cors:
AllowOrigin: "*"
AllowMethods: "GET"
AllowHeaders: "origin, x-requested-with"
MaxAge: 510
DefinitionBody:
swagger: "2.0"
info:
Expand Down
1 change: 1 addition & 0 deletions tests/integration/testdata/start_api/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Globals:
- image~1png
Variables:
VarName: varValue
Cors: "*"
Resources:
HelloWorldFunction:
Type: AWS::Serverless::Function
Expand Down
Loading