Skip to content

Commit

Permalink
Merge pull request #737 from KeepSafe/keepalive_refactoring
Browse files Browse the repository at this point in the history
Keepalive refactoring
  • Loading branch information
asvetlov committed Jan 13, 2016
2 parents e8b47f7 + e23f85b commit 1711040
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 22 deletions.
12 changes: 8 additions & 4 deletions aiohttp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,14 +675,18 @@ def send_headers(self, _sep=': ', _end='\r\n'):

def _add_default_headers(self):
# set the connection header
connection = None
if self.upgrade:
connection = 'upgrade'
elif not self.closing if self.keepalive is None else self.keepalive:
connection = 'keep-alive'
if self.version == HttpVersion10:
connection = 'keep-alive'
else:
connection = 'close'
if self.version == HttpVersion11:
connection = 'close'

self.headers[hdrs.CONNECTION] = connection
if connection is not None:
self.headers[hdrs.CONNECTION] = connection

def write(self, chunk, *,
drain=False, EOF_MARKER=EOF_MARKER, EOL_MARKER=EOL_MARKER):
Expand Down Expand Up @@ -882,7 +886,7 @@ def __init__(self, transport, method, path,
http_version=HttpVersion11, close=False):
# set the default for HTTP 1.0 to be different
# will only be overwritten with keep-alive header
if http_version < HttpVersion11:
if http_version < HttpVersion10:
close = True

super().__init__(transport, http_version, close)
Expand Down
1 change: 0 additions & 1 deletion tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,6 @@ def handler(request):
resp = yield from client.get('/')
assert resp.status == 200
assert resp.raw_headers == ((b'CONTENT-LENGTH', b'0'),
(b'CONNECTION', b'keep-alive'),
(b'DATE', mock.ANY),
(b'SERVER', mock.ANY))
resp.close()
31 changes: 25 additions & 6 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,21 @@ def test_add_headers_hop_headers(transport):
assert [] == list(msg.headers)


def test_default_headers(transport):
def test_default_headers_http_10(transport):
msg = protocol.Response(transport, 200,
http_version=protocol.HttpVersion10)
msg._add_default_headers()

assert 'DATE' in msg.headers
assert 'keep-alive' == msg.headers['CONNECTION']


def test_default_headers_http_11(transport):
msg = protocol.Response(transport, 200)
msg._add_default_headers()

headers = [r for r, _ in msg.headers.items()]
assert 'DATE' in headers
assert 'CONNECTION' in headers
assert 'DATE' in msg.headers
assert 'CONNECTION' not in msg.headers


def test_default_headers_server(transport):
Expand Down Expand Up @@ -222,15 +230,26 @@ def test_default_headers_connection_close(transport):
assert [('CONNECTION', 'close')] == headers


def test_default_headers_connection_keep_alive(transport):
msg = protocol.Response(transport, 200)
def test_default_headers_connection_keep_alive_http_10(transport):
msg = protocol.Response(transport, 200,
http_version=protocol.HttpVersion10)
msg.keepalive = True
msg._add_default_headers()

headers = [r for r in msg.headers.items() if r[0] == 'CONNECTION']
assert [('CONNECTION', 'keep-alive')] == headers


def test_default_headers_connection_keep_alive_11(transport):
msg = protocol.Response(transport, 200,
http_version=protocol.HttpVersion11)
msg.keepalive = True
msg._add_default_headers()

headers = [r for r in msg.headers.items() if r[0] == 'CONNECTION']
assert 'CONNECTION' not in headers


def test_send_headers(transport):
write = transport.write = mock.Mock()

Expand Down
3 changes: 0 additions & 3 deletions tests/test_web_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def test_HTTPOk(buf, request):
assert re.match(('HTTP/1.1 200 OK\r\n'
'CONTENT-TYPE: text/plain; charset=utf-8\r\n'
'CONTENT-LENGTH: 7\r\n'
'CONNECTION: keep-alive\r\n'
'DATE: .+\r\n'
'SERVER: .+\r\n\r\n'
'200: OK'), txt)
Expand Down Expand Up @@ -95,7 +94,6 @@ def test_HTTPFound(buf, request):
'CONTENT-TYPE: text/plain; charset=utf-8\r\n'
'CONTENT-LENGTH: 10\r\n'
'LOCATION: /redirect\r\n'
'CONNECTION: keep-alive\r\n'
'DATE: .+\r\n'
'SERVER: .+\r\n\r\n'
'302: Found', txt)
Expand All @@ -122,7 +120,6 @@ def test_HTTPMethodNotAllowed(buf, request):
'CONTENT-TYPE: text/plain; charset=utf-8\r\n'
'CONTENT-LENGTH: 23\r\n'
'ALLOW: POST,PUT\r\n'
'CONNECTION: keep-alive\r\n'
'DATE: .+\r\n'
'SERVER: .+\r\n\r\n'
'405: Method Not Allowed', txt)
Expand Down
29 changes: 24 additions & 5 deletions tests/test_web_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def go():

self.loop.run_until_complete(go())

def test_http10_keep_alive_default(self):
def test_http11_keep_alive_default(self):

@asyncio.coroutine
def handler(request):
Expand All @@ -531,12 +531,31 @@ def handler(request):
def go():
_, _, url = yield from self.create_server('GET', '/', handler)
resp = yield from request('GET', url, loop=self.loop,
version=HttpVersion10)
self.assertEqual('close', resp.headers['CONNECTION'])
version=HttpVersion11)
self.assertNotIn('CONNECTION', resp.headers)
resp.close()

self.loop.run_until_complete(go())

def test_http10_keep_alive_default(self):

@asyncio.coroutine
def handler(request):
yield from request.read()
return web.Response(body=b'OK')

@asyncio.coroutine
def go():
_, _, url = yield from self.create_server('GET', '/', handler)
with ClientSession(loop=self.loop) as session:
resp = yield from session.get(url,
version=HttpVersion10)
self.assertEqual(resp.version, HttpVersion10)
self.assertEqual('keep-alive', resp.headers['CONNECTION'])
resp.close()

self.loop.run_until_complete(go())

def test_http09_keep_alive_default(self):

@asyncio.coroutine
Expand All @@ -551,7 +570,7 @@ def go():
resp = yield from request('GET', url, loop=self.loop,
headers=headers,
version=HttpVersion(0, 9))
self.assertEqual('close', resp.headers['CONNECTION'])
self.assertNotIn('CONNECTION', resp.headers)
resp.close()

self.loop.run_until_complete(go())
Expand All @@ -569,7 +588,7 @@ def go():
headers = {'Connection': 'close'}
resp = yield from request('GET', url, loop=self.loop,
headers=headers, version=HttpVersion10)
self.assertEqual('close', resp.headers['CONNECTION'])
self.assertNotIn('CONNECTION', resp.headers)
resp.close()

self.loop.run_until_complete(go())
Expand Down
3 changes: 0 additions & 3 deletions tests/test_web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,6 @@ def append(data):
yield from resp.write_eof()
txt = buf.decode('utf8')
assert re.match('HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 0\r\n'
'CONNECTION: keep-alive\r\n'
'DATE: .+\r\nSERVER: .+\r\n\r\n', txt)


Expand All @@ -836,7 +835,6 @@ def append(data):
yield from resp.write_eof()
txt = buf.decode('utf8')
assert re.match('HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 4\r\n'
'CONNECTION: keep-alive\r\n'
'DATE: .+\r\nSERVER: .+\r\n\r\ndata', txt)


Expand All @@ -861,7 +859,6 @@ def append(data):
txt = buf.decode('utf8')
assert re.match('HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 0\r\n'
'SET-COOKIE: name=value\r\n'
'CONNECTION: keep-alive\r\n'
'DATE: .+\r\nSERVER: .+\r\n\r\n', txt)


Expand Down

0 comments on commit 1711040

Please sign in to comment.