diff --git a/homeassistant/components/http/cors.py b/homeassistant/components/http/cors.py index b01e68f701d95..555f302f8e196 100644 --- a/homeassistant/components/http/cors.py +++ b/homeassistant/components/http/cors.py @@ -27,30 +27,36 @@ def setup_cors(app, origins): ) for host in origins }) - def allow_cors(route, methods): + cors_added = set() + + def _allow_cors(route, config=None): """Allow cors on a route.""" - cors.add(route, { - '*': aiohttp_cors.ResourceOptions( - allow_headers=ALLOWED_CORS_HEADERS, - allow_methods=methods, - ) - }) + if hasattr(route, 'resource'): + path = route.resource + else: + path = route + + path = path.canonical + + if path in cors_added: + return + + cors.add(route, config) + cors_added.add(path) - app['allow_cors'] = allow_cors + app['allow_cors'] = lambda route: _allow_cors(route, { + '*': aiohttp_cors.ResourceOptions( + allow_headers=ALLOWED_CORS_HEADERS, + allow_methods='*', + ) + }) if not origins: return async def cors_startup(app): """Initialize cors when app starts up.""" - cors_added = set() - for route in list(app.router.routes()): - if hasattr(route, 'resource'): - route = route.resource - if route in cors_added: - continue - cors.add(route) - cors_added.add(route) + _allow_cors(route) app.on_startup.append(cors_startup) diff --git a/homeassistant/components/http/view.py b/homeassistant/components/http/view.py index f3d3cd06e22d5..2b6c2a113c4dd 100644 --- a/homeassistant/components/http/view.py +++ b/homeassistant/components/http/view.py @@ -69,15 +69,13 @@ def register(self, app, router): handler = request_handler_factory(self, handler) for url in urls: - routes.append( - (method, router.add_route(method, url, handler)) - ) + routes.append(router.add_route(method, url, handler)) if not self.cors_allowed: return - for method, route in routes: - app['allow_cors'](route, [method.upper()]) + for route in routes: + app['allow_cors'](route) def request_handler_factory(view, handler): diff --git a/tests/components/http/test_cors.py b/tests/components/http/test_cors.py index 523d4943ba04a..a510d2b3829e7 100644 --- a/tests/components/http/test_cors.py +++ b/tests/components/http/test_cors.py @@ -14,6 +14,7 @@ from homeassistant.const import HTTP_HEADER_HA_AUTH from homeassistant.setup import async_setup_component from homeassistant.components.http.cors import setup_cors +from homeassistant.components.http.view import HomeAssistantView TRUSTED_ORIGIN = 'https://home-assistant.io' @@ -96,3 +97,34 @@ async def test_cors_preflight_allowed(client): assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == TRUSTED_ORIGIN assert req.headers[ACCESS_CONTROL_ALLOW_HEADERS] == \ HTTP_HEADER_HA_AUTH.upper() + + +async def test_cors_middleware_with_cors_allowed_view(hass): + """Test that we can configure cors and have a cors_allowed view.""" + class MyView(HomeAssistantView): + """Test view that allows CORS.""" + + requires_auth = False + cors_allowed = True + + def __init__(self, url, name): + """Initialize test view.""" + self.url = url + self.name = name + + async def get(self, request): + """Test response.""" + return "test" + + assert await async_setup_component(hass, 'http', { + 'http': { + 'cors_allowed_origins': ['http://home-assistant.io'] + } + }) + + hass.http.register_view(MyView('/api/test', 'api:test')) + hass.http.register_view(MyView('/api/test', 'api:test2')) + hass.http.register_view(MyView('/api/test2', 'api:test')) + + hass.http.app._on_startup.freeze() + await hass.http.app.startup()