diff --git a/README.md b/README.md index 70c6a0c..39adf6d 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,18 @@ For these cases, the server needs to inspect the Origin header from the client a Note that the `Access-Control-Allow-Origin` header can only return a single value. This means that if you want to allow requests from multiple origins you need to dynamically whitelist those origins and return a different header value depending on the incoming request. +Additionally if specific HTTP methods should be allowed an application should add: + + Access-Control-Allow-Methods: GET, OPTIONS + +Here `GET` and `OPTIONS` are allowed. + +Similarly specific headers can be allowed: + + Access-Control-Allow-Headers: content-type, Authorization + +In this case `content-type` and `Authorization` headers are allowed to be sent to the server in a CORS request. + ## How to use this middleware We will assume you have an existing ASGI app, in a variable called `app`. @@ -70,6 +82,15 @@ If you need to do something more complicated that cannot be expressed using the Your callback function will be passed the `Origin` header that was passed in by the browser. +To add specific allowed headers or methods you can specify them with the `headers=` and `methods=` parameters: + + app = asgi_cors(app, methods=[ + "GET", "OPTIONS" + ], headers=[ + "Authorization","content-type" + ]) + + ## Using the middleware as a decorator If you are defining your ASGI application directly as a function, you can use the `asgi_cors_decorator` function decorator like so: diff --git a/asgi_cors.py b/asgi_cors.py index 2345c9a..e0971f3 100644 --- a/asgi_cors.py +++ b/asgi_cors.py @@ -3,7 +3,7 @@ def asgi_cors_decorator( - allow_all=False, hosts=None, host_wildcards=None, callback=None + allow_all=False, hosts=None, host_wildcards=None, callback=None, headers=None, methods=None ): hosts = hosts or [] host_wildcards = host_wildcards or [] @@ -13,6 +13,12 @@ def asgi_cors_decorator( host_wildcards = [ h.encode("utf8") if isinstance(h, str) else h for h in host_wildcards ] + headers = [ + h.encode("utf8") if isinstance(h, str) else h for h in headers + ] + methods = [ + h.encode("utf8") if isinstance(h, str) else h for h in methods + ] if any(h.endswith(b"/") for h in (hosts or [])) or any( h.endswith(b"/") for h in (host_wildcards or []) @@ -43,6 +49,13 @@ async def wrapped_send(event): matches_callback = callback(incoming_origin) if matches_hosts or matches_wildcards or matches_callback: access_control_allow_origin = incoming_origin + access_control_allow_headers = ['content-type'] + if headers: + access_control_allow_headers = headers + access_control_allow_methods = ['GET'] + if methods: + access_control_allow_methods = methods + if access_control_allow_origin is not None: # Construct a new event with new headers event = { @@ -51,13 +64,25 @@ async def wrapped_send(event): "headers": [ p for p in original_headers - if p[0] != b"access-control-allow-origin" + if p[0] != b"access-control-allow-origin" and p[0] != b"access-control-allow-headers" and p[0] != b"access-control-allow-methods" ] + [ [ b"access-control-allow-origin", access_control_allow_origin, ] + ] + + [ + [ + b"access-control-allow-headers", + b", ".join(access_control_allow_headers), + ] + ] + + [ + [ + b"access-control-allow-methods", + b", ".join(access_control_allow_methods), + ] ], } await send(event) @@ -69,5 +94,5 @@ async def wrapped_send(event): return _asgi_cors_decorator -def asgi_cors(app, allow_all=False, hosts=None, host_wildcards=None, callback=None): - return asgi_cors_decorator(allow_all, hosts, host_wildcards, callback)(app) +def asgi_cors(app, allow_all=False, hosts=None, host_wildcards=None, callback=None, headers=None, methods=None): + return asgi_cors_decorator(allow_all, hosts, host_wildcards, callback, headers, methods)(app)