diff --git a/src/azul/chalice.py b/src/azul/chalice.py index f5365e2a49..26ca043b44 100644 --- a/src/azul/chalice.py +++ b/src/azul/chalice.py @@ -1,6 +1,7 @@ from collections.abc import ( Iterable, ) +import html import json from json import ( JSONEncoder, @@ -28,6 +29,9 @@ from furl import ( furl, ) +from werkzeug.http import ( + parse_accept_header, +) from azul import ( config, @@ -95,6 +99,7 @@ def __init__(self, self._specs: Optional[MutableJSON] = None super().__init__(app_name, debug=config.debug > 0, configure_logs=False) # Middleware is invoked in order of registration + self.register_middleware(self._wrapping_middleware, 'http') self.register_middleware(self._logging_middleware, 'http') self.register_middleware(self._lambda_context_middleware, 'all') self.register_middleware(self._authentication_middleware, 'http') @@ -116,6 +121,21 @@ def patched_event_source_handler(self_, event, context): if old_handler.__code__ != patched_event_source_handler.__code__: chalice.app.EventSourceHandler.__call__ = patched_event_source_handler + def _wrapping_middleware(self, event, get_response): + response = get_response(event) + if response.status_code >= 400: + parsed = parse_accept_header(event.headers.get('accept')) + text_html = parsed.find('text/html') + star_star = parsed.find('*/*') + if -1 < text_html and (star_star == -1 or text_html < star_star): + response.body = ( + '' + f'Status {response.status_code}' + f'
{html.escape(str(response.body), quote=False)}
' + '' + ) + return response + def _logging_middleware(self, event, get_response): self._log_request() response = get_response(event) diff --git a/test/service/test_response.py b/test/service/test_response.py index 2fb4f60033..f9f498511b 100644 --- a/test/service/test_response.py +++ b/test/service/test_response.py @@ -2204,6 +2204,34 @@ def test_version(self): } self.assertEqual(expected_json, response.json()['git']) + def test_response_error_escaping(self): + expected = { + 'unescaped': json.dumps({ + 'Code': 'NotFoundError', + 'Message': "Unable to find file 'foo', version None in catalog 'test'" + }, separators=(',', ':')), + 'escaped': ( + "Status 404
{"
+                "'Code': 'NotFoundError', "
+                "'Message': \"Unable to find file 'foo', version None in catalog 'test'\""
+                "}
" + ) + } + test_data = [ + (None, 'unescaped'), + ('*/*', 'unescaped'), + ('*/*,text/html', 'unescaped'), + ('text/html', 'escaped'), + ('text/html,*/*', 'escaped'), + ] + url = self.base_url.set(path='repository/files/foo', + args=dict(catalog=self.catalog)) + for accept, response_type in test_data: + headers = {'accept': accept} + with self.subTest(headers=headers): + response = requests.get(str(url), headers=headers) + self.assertEqual(expected[response_type], response.text) + class TestFileTypeSummaries(DCP1TestCase, WebServiceTestCase):