diff --git a/app/__init__.py b/app/__init__.py index ae177bd..bf9c70b 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -82,19 +82,27 @@ def validate_origin(): origin = request.headers.get('Origin', None) referrer = request.headers.get('Referer', None) - if origin is None and referrer is None and sec_fetch_site is None: - logger.error('Referer and/or Origin and/or Sec-Fetch-Site headers not set') - abort(403, 'Permission denied') - if origin is not None and not is_domain_allowed(origin): + if origin is not None: + if is_domain_allowed(origin): + return logger.error('Origin=%s is not allowed', origin) abort(403, 'Permission denied') - if referrer is not None and not is_domain_allowed(referrer): - logger.error('Referer=%s is not allowed', referrer) - abort(403, 'Permission denied') - if sec_fetch_site is not None and sec_fetch_site != 'same-origin': + + if sec_fetch_site is not None: + if sec_fetch_site in ['same-origin', 'same-site']: + return logger.error('Sec-Fetch-Site=%s is not allowed', sec_fetch_site) abort(403, 'Permission denied') + if referrer is not None: + if is_domain_allowed(referrer): + return + logger.error('Referer=%s is not allowed', referrer) + abort(403, 'Permission denied') + + logger.error('Referer and/or Origin and/or Sec-Fetch-Site headers not set') + abort(403, 'Permission denied') + @app.after_request def log_response(response): diff --git a/tests/unit_tests/test_routes.py b/tests/unit_tests/test_routes.py index a7cdcf8..51bb9f5 100644 --- a/tests/unit_tests/test_routes.py +++ b/tests/unit_tests/test_routes.py @@ -5,6 +5,8 @@ from datetime import timedelta from unittest.mock import patch +from nose2.tools import params + from flask import url_for from app.settings import AWS_DB_TABLE_NAME @@ -230,6 +232,50 @@ def test_get_metadata_non_allowed_origin(self): self.assertEqual(response.content_type, "application/json") self.assertEqual(response.json["error"]["message"], "Permission denied") + @params( + None, + {'Origin': 'www.example'}, + { + 'Origin': 'www.example', 'Sec-Fetch-Site': 'cross-site' + }, + { + 'Origin': 'www.example', 'Sec-Fetch-Site': 'same-site' + }, + { + 'Origin': 'www.example', 'Sec-Fetch-Site': 'same-origin' + }, + { + 'Referer': 'http://www.example', + }, + ) + def test_get_metadata_origin_not_allowed(self, headers): + id_to_fetch = self.sample_kml['id'] + response = self.app.get(url_for('get_kml_metadata', kml_id=id_to_fetch), headers=headers) + self.assertEqual(response.status_code, 403) + self.assertCors(response, ['DELETE', 'GET', 'HEAD', 'OPTIONS', 'PUT']) + + @params( + {'Origin': 'map.geo.admin.ch'}, + { + 'Origin': 'map.geo.admin.ch', 'Sec-Fetch-Site': 'same-site' + }, + { + 'Origin': 'public.geo.admin.ch', 'Sec-Fetch-Site': 'same-origin' + }, + { + 'Origin': 'http://localhost', 'Sec-Fetch-Site': 'cross-site' + }, + {'Sec-Fetch-Site': 'same-origin'}, + { + 'Referer': 'https://map.geo.admin.ch', + }, + ) + def test_get_metadata_origin_allowed(self, headers): + id_to_fetch = self.sample_kml['id'] + response = self.app.get(url_for('get_kml_metadata', kml_id=id_to_fetch), headers=headers) + self.assertEqual(response.status_code, 200) + self.assertCors(response, ['DELETE', 'GET', 'HEAD', 'OPTIONS', 'PUT']) + class TestPutEndpoint(BaseRouteTestCase):