diff --git a/aiohttp/protocol.py b/aiohttp/protocol.py index 4663d1edd5e..6eb63b651ad 100644 --- a/aiohttp/protocol.py +++ b/aiohttp/protocol.py @@ -521,8 +521,6 @@ class HttpMessage: # this is useful for wsgi's start_response implementation. _send_headers = False - _has_user_agent = False - def __init__(self, transport, version, close): self.transport = transport self.version = version @@ -595,9 +593,6 @@ def add_header(self, name, value): self.chunked = value.lower().strip() == 'chunked' elif name not in self.HOP_HEADERS: - if name == 'USER-AGENT': - self._has_user_agent = True - # ignore hop-by-hop headers self.headers.add(name, value) @@ -805,8 +800,6 @@ class Response(HttpMessage): 'TRAILERS', 'TRANSFER-ENCODING', 'UPGRADE', - 'SERVER', - 'DATE', } @staticmethod @@ -833,8 +826,10 @@ def __init__(self, transport, status, def _add_default_headers(self): super()._add_default_headers() - self.headers.extend((('DATE', format_date_time(None)), - ('SERVER', self.SERVER_SOFTWARE),)) + if 'DATE' not in self.headers: + # format_date_time(None) is quite expensive + self.headers.setdefault('DATE', format_date_time(None)) + self.headers.setdefault('SERVER', self.SERVER_SOFTWARE) class Request(HttpMessage): @@ -853,5 +848,4 @@ def __init__(self, transport, method, path, def _add_default_headers(self): super()._add_default_headers() - if not self._has_user_agent: - self.headers['USER-AGENT'] = self.SERVER_SOFTWARE + self.headers.setdefault('USER-AGENT', self.SERVER_SOFTWARE) diff --git a/tests/test_http_protocol.py b/tests/test_http_protocol.py index 96c04c8d1bb..c4bfca4b2f4 100644 --- a/tests/test_http_protocol.py +++ b/tests/test_http_protocol.py @@ -479,3 +479,18 @@ def test_write_drain(self): res = msg.write(b'1') self.assertEqual(res, ()) + + def test_dont_override_request_headers_with_default_values(self): + msg = protocol.Request( + self.transport, 'GET', '/index.html', close=True) + msg.add_header('USER-AGENT', 'custom') + msg._add_default_headers() + self.assertEqual('custom', msg.headers['USER-AGENT']) + + def test_dont_override_response_headers_with_default_values(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg.add_header('DATE', 'now') + msg.add_header('SERVER', 'custom') + msg._add_default_headers() + self.assertEqual('custom', msg.headers['SERVER']) + self.assertEqual('now', msg.headers['DATE'])