Skip to content

Commit

Permalink
feat(start-api): Cors Support (aws#1242)
Browse files Browse the repository at this point in the history
  • Loading branch information
viksrivat committed Aug 14, 2019
1 parent 03b0cd9 commit 824f3a7
Show file tree
Hide file tree
Showing 9 changed files with 923 additions and 363 deletions.
31 changes: 30 additions & 1 deletion samcli/commands/local/lib/api_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self):
self.binary_media_types_set = set()
self.stage_name = None
self.stage_variables = None
self.cors = None

def __iter__(self):
"""
Expand Down Expand Up @@ -103,12 +104,40 @@ def get_api(self):
An Api object with all the properties
"""
api = Api()
api.routes = self.dedupe_function_routes(self.routes)
routes = self.dedupe_function_routes(self.routes)
routes = self.normalize_cors_methods(routes, self.cors)
api.routes = routes
api.binary_media_types_set = self.binary_media_types_set
api.stage_name = self.stage_name
api.stage_variables = self.stage_variables
api.cors = self.cors
return api

@staticmethod
def normalize_cors_methods(routes, cors):
"""
Adds OPTIONS method to all the route methods if cors exists
Parameters
-----------
routes: list(samcli.local.apigw.local_apigw_service.Route)
List of Routes
cors: samcli.commands.local.lib.provider.Cors
the cors object for the api
Return
-------
A list of routes without duplicate routes with the same function_name and method
"""

def add_options_to_route(route):
if "OPTIONS" not in route.methods:
route.methods.append("OPTIONS")
return route

return routes if not cors else [add_options_to_route(route) for route in routes]

@staticmethod
def dedupe_function_routes(routes):
"""
Expand Down
36 changes: 35 additions & 1 deletion samcli/commands/local/lib/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,41 @@ def binary_media_types(self):
return list(self.binary_media_types_set)


Cors = namedtuple("Cors", ["AllowOrigin", "AllowMethods", "AllowHeaders"])
_CorsTuple = namedtuple("Cors", ["allow_origin", "allow_methods", "allow_headers", "max_age"])


_CorsTuple.__new__.__defaults__ = (None, # Allow Origin defaults to None
None, # Allow Methods is optional and defaults to empty
None, # Allow Headers is optional and defaults to empty
None # MaxAge is optional and defaults to empty
)


class Cors(_CorsTuple):

@staticmethod
def cors_to_headers(cors):
"""
Convert CORS object to headers dictionary
Parameters
----------
cors list(samcli.commands.local.lib.provider.Cors)
CORS configuration objcet
Returns
-------
Dictionary with CORS headers
"""
if not cors:
return {}
headers = {
'Access-Control-Allow-Origin': cors.allow_origin,
'Access-Control-Allow-Methods': cors.allow_methods,
'Access-Control-Allow-Headers': cors.allow_headers,
'Access-Control-Max-Age': cors.max_age
}
# Filters out items in the headers dictionary that isn't empty.
# This is required because the flask Headers dict will send an invalid 'None' string
return {h_key: h_value for h_key, h_value in headers.items() if h_value is not None}


class AbstractApiProvider(object):
Expand Down
66 changes: 64 additions & 2 deletions samcli/commands/local/lib/sam_api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import logging

from six import string_types

from samcli.commands.local.lib.provider import Cors
from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider
from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException
from samcli.local.apigw.local_apigw_service import Route
Expand Down Expand Up @@ -77,9 +80,9 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd=
body = properties.get("DefinitionBody")
uri = properties.get("DefinitionUri")
binary_media = properties.get("BinaryMediaTypes", [])
cors = self.extract_cors(properties.get("Cors", {}))
stage_name = properties.get("StageName")
stage_variables = properties.get("Variables")

if not body and not uri:
# Swagger is not found anywhere.
LOG.debug("Skipping resource '%s'. Swagger document not found in DefinitionBody and DefinitionUri",
Expand All @@ -88,6 +91,65 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd=
self.extract_swagger_route(logical_id, body, uri, binary_media, collector, cwd=cwd)
collector.stage_name = stage_name
collector.stage_variables = stage_variables
collector.cors = cors

def extract_cors(self, cors_prop):
"""
Extract Cors property from AWS::Serverless::Api resource by reading and parsing Swagger documents. The result
is added to the Api.
Parameters
----------
cors_prop : dict
Resource properties for Cors
"""
cors = None
if cors_prop and isinstance(cors_prop, dict):
allow_methods = cors_prop.get("AllowMethods", ','.join(sorted(Route.ANY_HTTP_METHODS)))
allow_methods = self.normalize_cors_allow_methods(allow_methods)
cors = Cors(
allow_origin=cors_prop.get("AllowOrigin"),
allow_methods=allow_methods,
allow_headers=cors_prop.get("AllowHeaders"),
max_age=cors_prop.get("MaxAge")
)
elif cors_prop and isinstance(cors_prop, string_types):
cors = Cors(
allow_origin=cors_prop,
allow_methods=','.join(sorted(Route.ANY_HTTP_METHODS)),
allow_headers=None,
max_age=None
)
return cors

@staticmethod
def normalize_cors_allow_methods(allow_methods):
"""
Normalize cors AllowMethods and Options to the methods if it's missing.
Parameters
----------
allow_methods : str
The allow_methods string provided in the query
Return
-------
A string with normalized route
"""
if allow_methods == "*":
return ','.join(sorted(Route.ANY_HTTP_METHODS))
methods = allow_methods.split(",")
normalized_methods = []
for method in methods:
normalized_method = method.strip().upper()
if normalized_method not in Route.ANY_HTTP_METHODS:
raise InvalidSamDocumentException("The method {} is not a valid CORS method".format(normalized_method))
normalized_methods.append(normalized_method)

if "OPTIONS" not in normalized_methods:
normalized_methods.append("OPTIONS")

return ','.join(sorted(normalized_methods))

def _extract_routes_from_function(self, logical_id, function_resource, collector):
"""
Expand All @@ -96,7 +158,7 @@ def _extract_routes_from_function(self, logical_id, function_resource, collector
Parameters
----------
logical_id : str
Logical ID of the resource
Logical ID of the resourc
function_resource : dict
Contents of the function resource including its properties
Expand Down
23 changes: 20 additions & 3 deletions samcli/local/apigw/local_apigw_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from flask import Flask, request
from werkzeug.datastructures import Headers

from samcli.commands.local.lib.provider import Cors
from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser
from samcli.lib.utils.stream_writer import StreamWriter
from samcli.local.lambdafn.exceptions import FunctionNotFound
Expand Down Expand Up @@ -170,6 +171,12 @@ def _request_handler(self, **kwargs):
"""

route = self._get_current_route(request)
cors_headers = Cors.cors_to_headers(self.api.cors)

method, _ = self.get_request_methods_endpoints(request)
if method == 'OPTIONS':
headers = Headers(cors_headers)
return self.service_response('', headers, 200)

try:
event = self._construct_event(request, self.port, self.api.binary_media_types, self.api.stage_name,
Expand Down Expand Up @@ -209,8 +216,7 @@ def _get_current_route(self, flask_request):
:param request flask_request: Flask Request
:return: Route matching the endpoint and method of the request
"""
endpoint = flask_request.endpoint
method = flask_request.method
method, endpoint = self.get_request_methods_endpoints(flask_request)

route_key = self._route_key(method, endpoint)
route = self._dict_of_routes.get(route_key, None)
Expand All @@ -223,6 +229,16 @@ def _get_current_route(self, flask_request):

return route

def get_request_methods_endpoints(self, flask_request):
"""
Separated out for testing requests in request handler
:param request flask_request: Flask Request
:return: the request's endpoint and method
"""
endpoint = flask_request.endpoint
method = flask_request.method
return method, endpoint

# Consider moving this out to its own class. Logic is started to get dense and looks messy @jfuss
@staticmethod
def _parse_lambda_output(lambda_output, binary_types, flask_request):
Expand Down Expand Up @@ -451,6 +467,8 @@ def _event_headers(flask_request, port):
Request from Flask
int port
Forwarded Port
cors_headers dict
Dict of the Cors properties
Returns dict (str: str), dict (str: list of str)
-------
Expand All @@ -471,7 +489,6 @@ def _event_headers(flask_request, port):

headers_dict["X-Forwarded-Port"] = str(port)
multi_value_headers_dict["X-Forwarded-Port"] = [str(port)]

return headers_dict, multi_value_headers_dict

@staticmethod
Expand Down
62 changes: 62 additions & 0 deletions tests/integration/local/start_api/test_start_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from time import time

from samcli.local.apigw.local_apigw_service import Route
from .start_api_integ_base import StartApiIntegBaseClass


Expand Down Expand Up @@ -664,6 +665,67 @@ def test_swagger_stage_variable(self):
self.assertEquals(response_data.get("stageVariables"), {'VarName': 'varValue'})


class TestServiceCorsSwaggerRequests(StartApiIntegBaseClass):
"""
Test to check that the correct headers are being added with Cors with swagger code
"""
template_path = "/testdata/start_api/swagger-template.yaml"
binary_data_file = "testdata/start_api/binarydata.gif"

def setUp(self):
self.url = "http://127.0.0.1:{}".format(self.port)

def test_cors_swagger_options(self):
"""
This tests that the Cors are added to option requests in the swagger template
"""
response = requests.options(self.url + '/echobase64eventbody')

self.assertEquals(response.status_code, 200)

self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*")
self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), "origin, x-requested-with")
self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), "GET,OPTIONS")
self.assertEquals(response.headers.get("Access-Control-Max-Age"), '510')


class TestServiceCorsGlobalRequests(StartApiIntegBaseClass):
"""
Test to check that the correct headers are being added with Cors with the global property
"""
template_path = "/testdata/start_api/template.yaml"

def setUp(self):
self.url = "http://127.0.0.1:{}".format(self.port)

def test_cors_global(self):
"""
This tests that the Cors are added to options requests when the global property is set
"""
response = requests.options(self.url + '/echobase64eventbody')

self.assertEquals(response.status_code, 200)
self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*")
self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), None)
self.assertEquals(response.headers.get("Access-Control-Allow-Methods"),
','.join(sorted(Route.ANY_HTTP_METHODS)))
self.assertEquals(response.headers.get("Access-Control-Max-Age"), None)

def test_cors_global_get(self):
"""
This tests that the Cors are added to post requests when the global property is set
"""
response = requests.get(self.url + "/onlysetstatuscode")

self.assertEquals(response.status_code, 200)
self.assertEquals(response.content.decode('utf-8'), "no data")
self.assertEquals(response.headers.get("Content-Type"), "application/json")
self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), None)
self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), None)
self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), None)
self.assertEquals(response.headers.get("Access-Control-Max-Age"), None)


class TestStartApiWithCloudFormationStage(StartApiIntegBaseClass):
"""
Test Class centered around the different responses that can happen in Lambda and pass through start-api
Expand Down
7 changes: 6 additions & 1 deletion tests/integration/testdata/start_api/swagger-template.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
AWSTemplateFormatVersion : '2010-09-09'
AWSTemplateFormatVersion: '2010-09-09'
Transform: AWS::Serverless-2016-10-31

Globals:
Expand All @@ -14,6 +14,11 @@ Resources:
StageName: dev
Variables:
VarName: varValue
Cors:
AllowOrigin: "*"
AllowMethods: "GET"
AllowHeaders: "origin, x-requested-with"
MaxAge: 510
DefinitionBody:
swagger: "2.0"
info:
Expand Down
1 change: 1 addition & 0 deletions tests/integration/testdata/start_api/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Globals:
- image~1png
Variables:
VarName: varValue
Cors: "*"
Resources:
HelloWorldFunction:
Type: AWS::Serverless::Function
Expand Down
Loading

0 comments on commit 824f3a7

Please sign in to comment.