Skip to content

Commit

Permalink
Backport contextvars support (#4271)
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov authored Oct 26, 2019
1 parent 8694981 commit 29eccad
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 51 deletions.
1 change: 1 addition & 0 deletions CHANGES/3380.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix failed websocket handshake leaving connection hanging.
1 change: 1 addition & 0 deletions CHANGES/3557.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Call ``AccessLogger.log`` with the current exception available from sys.exc_info().
3 changes: 3 additions & 0 deletions aiohttp/_http_parser.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,9 @@ cdef class HttpParser:
else:
return messages, False, b''

def set_upgraded(self, val):
self._upgraded = val


cdef class HttpRequestParser(HttpParser):

Expand Down
6 changes: 6 additions & 0 deletions aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,12 @@ def parse_headers(

return (headers, raw_headers, close_conn, encoding, upgrade, chunked)

def set_upgraded(self, val: bool) -> None:
"""Set connection upgraded (to websocket) mode.
:param bool val: new state.
"""
self._upgraded = val


class HttpRequestParser(HttpParser):
"""Read request status line. Exception .http_exceptions.BadStatusLine
Expand Down
120 changes: 76 additions & 44 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Awaitable,
Callable,
Optional,
Tuple,
Type,
cast,
)
Expand Down Expand Up @@ -371,6 +372,33 @@ def _process_keepalive(self) -> None:
self._keepalive_handle = self._loop.call_later(
self.KEEPALIVE_RESCHEDULE_DELAY, self._process_keepalive)

async def _handle_request(self,
request: BaseRequest,
start_time: float,
) -> Tuple[StreamResponse, bool]:
assert self._request_handler is not None
try:
resp = await self._request_handler(request)
except HTTPException as exc:
resp = Response(status=exc.status,
reason=exc.reason,
text=exc.text,
headers=exc.headers)
reset = await self.finish_response(request, resp, start_time)
except asyncio.CancelledError:
raise
except asyncio.TimeoutError as exc:
self.log_debug('Request handler timed out.', exc_info=exc)
resp = self.handle_error(request, 504)
reset = await self.finish_response(request, resp, start_time)
except Exception as exc:
resp = self.handle_error(request, 500, exc)
reset = await self.finish_response(request, resp, start_time)
else:
reset = await self.finish_response(request, resp, start_time)

return resp, reset

async def start(self) -> None:
"""Process incoming request.
Expand Down Expand Up @@ -403,8 +431,7 @@ async def start(self) -> None:

message, payload = self._messages.popleft()

if self.access_log:
now = loop.time()
start = loop.time()

manager.requests_count += 1
writer = StreamWriter(self, loop)
Expand All @@ -413,54 +440,23 @@ async def start(self) -> None:
try:
# a new task is used for copy context vars (#3406)
task = self._loop.create_task(
self._request_handler(request))
self._handle_request(request, start))
try:
resp = await task
except HTTPException as exc:
resp = exc
resp, reset = await task
except (asyncio.CancelledError, ConnectionError):
self.log_debug('Ignored premature client disconnection')
break
except asyncio.TimeoutError as exc:
self.log_debug('Request handler timed out.', exc_info=exc)
resp = self.handle_error(request, 504)
except Exception as exc:
resp = self.handle_error(request, 500, exc)
else:
# Deprecation warning (See #2415)
if getattr(resp, '__http_exception__', False):
warnings.warn(
"returning HTTPException object is deprecated "
"(#2415) and will be removed, "
"please raise the exception instead",
DeprecationWarning)
# Deprecation warning (See #2415)
if getattr(resp, '__http_exception__', False):
warnings.warn(
"returning HTTPException object is deprecated "
"(#2415) and will be removed, "
"please raise the exception instead",
DeprecationWarning)

# Drop the processed task from asyncio.Task.all_tasks() early
del task

if self.debug:
if not isinstance(resp, StreamResponse):
if resp is None:
raise RuntimeError("Missing return "
"statement on request handler")
else:
raise RuntimeError("Web-handler should return "
"a response instance, "
"got {!r}".format(resp))
try:
prepare_meth = resp.prepare
except AttributeError:
if resp is None:
raise RuntimeError("Missing return "
"statement on request handler")
else:
raise RuntimeError("Web-handler should return "
"a response instance, "
"got {!r}".format(resp))
try:
await prepare_meth(request)
await resp.write_eof()
except ConnectionError:
if reset:
self.log_debug('Ignored premature client disconnection 2')
break

Expand All @@ -469,7 +465,7 @@ async def start(self) -> None:

# log access
if self.access_log:
self.log_access(request, resp, loop.time() - now)
self.log_access(request, resp, loop.time() - start)

# check payload
if not payload.is_eof():
Expand Down Expand Up @@ -530,6 +526,42 @@ async def start(self) -> None:
if self.transport is not None and self._error_handler is None:
self.transport.close()

async def finish_response(self,
request: BaseRequest,
resp: StreamResponse,
start_time: float) -> bool:
"""
Prepare the response and write_eof, then log access. This has to
be called within the context of any exception so the access logger
can get exception information. Returns True if the client disconnects
prematurely.
"""
if self._request_parser is not None:
self._request_parser.set_upgraded(False)
self._upgrade = False
if self._message_tail:
self._request_parser.feed_data(self._message_tail)
self._message_tail = b''
try:
prepare_meth = resp.prepare
except AttributeError:
if resp is None:
raise RuntimeError("Missing return "
"statement on request handler")
else:
raise RuntimeError("Web-handler should return "
"a response instance, "
"got {!r}".format(resp))
try:
await prepare_meth(request)
await resp.write_eof()
except ConnectionError:
self.log_access(request, resp, start_time)
return True
else:
self.log_access(request, resp, start_time)
return False

def handle_error(self,
request: BaseRequest,
status: int=500,
Expand Down
37 changes: 37 additions & 0 deletions tests/test_web_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,17 @@
import pytest

import aiohttp
from aiohttp import web
from aiohttp.abc import AbstractAccessLogger
from aiohttp.helpers import PY_37
from aiohttp.web_log import AccessLogger

try:
from contextvars import ContextVar
except ImportError:
ContextVar = None


IS_PYPY = platform.python_implementation() == 'PyPy'


Expand Down Expand Up @@ -157,3 +165,32 @@ def log(self, request, response, time):
access_logger = Logger(mock_logger, '{request} {response} {time}')
access_logger.log('request', 'response', 1)
mock_logger.info.assert_called_with('request response 1')


@pytest.mark.skipif(not PY_37,
reason="contextvars support is required")
async def test_contextvars_logger(aiohttp_server, aiohttp_client):
VAR = ContextVar('VAR')

async def handler(request):
return web.Response()

@web.middleware
async def middleware(request, handler):
VAR.set("uuid")
return await handler(request)

msg = None

class Logger(AbstractAccessLogger):
def log(self, request, response, time):
nonlocal msg
msg = 'contextvars: {}'.format(VAR.get())

app = web.Application(middlewares=[middleware])
app.router.add_get('/', handler)
server = await aiohttp_server(app, access_log_class=Logger)
client = await aiohttp_client(server)
resp = await client.get('/')
assert 200 == resp.status
assert msg == 'contextvars: uuid'
12 changes: 6 additions & 6 deletions tests/test_web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ async def handle(request):
b'GET / HTTP/1.1\r\n'
b'Host: example.com\r\n'
b'Content-Length: 0\r\n\r\n')
await asyncio.sleep(0)
await asyncio.sleep(0.01)

# with exception
request_handler.side_effect = handle_with_error()
Expand All @@ -384,7 +384,7 @@ async def handle(request):

assert srv._task_handler

await asyncio.sleep(0)
await asyncio.sleep(0.01)

await srv._task_handler
assert normal_completed
Expand Down Expand Up @@ -600,7 +600,7 @@ async def test_content_length_0(srv, request_handler) -> None:
b'GET / HTTP/1.1\r\n'
b'Host: example.org\r\n'
b'Content-Length: 0\r\n\r\n')
await asyncio.sleep(0)
await asyncio.sleep(0.01)

assert request_handler.called
assert request_handler.call_args[0][0].content == streams.EMPTY_PAYLOAD
Expand Down Expand Up @@ -722,7 +722,7 @@ async def handle1(request):
b'GET / HTTP/1.1\r\n'
b'Host: example.com\r\n'
b'Content-Length: 0\r\n\r\n')
await asyncio.sleep(0)
await asyncio.sleep(0.01)

# second

Expand All @@ -740,7 +740,7 @@ async def handle2(request):
b'GET / HTTP/1.1\r\n'
b'Host: example.com\r\n'
b'Content-Length: 0\r\n\r\n')
await asyncio.sleep(0)
await asyncio.sleep(0.01)

assert srv._task_handler is not None

Expand Down Expand Up @@ -855,4 +855,4 @@ async def handler(request):
writer.write(b"x")
writer.close()
await asyncio.sleep(0.1)
logger.debug.assert_called_with('Ignored premature client disconnection 2')
logger.debug.assert_called_with('Ignored premature client disconnection.')
2 changes: 1 addition & 1 deletion tests/test_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def handler(request):
with pytest.raises(client.ClientPayloadError):
await resp.read()

logger.debug.assert_called_with('Ignored premature client disconnection ')
logger.debug.assert_called_with('Ignored premature client disconnection')


async def test_raw_server_not_http_exception_debug(aiohttp_raw_server,
Expand Down
26 changes: 26 additions & 0 deletions tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,3 +788,29 @@ async def handler(request):
ws = await client.ws_connect('/')
data = await ws.receive_str()
assert data == 'OK'


async def test_bug3380(loop, aiohttp_client) -> None:

async def handle_null(request):
return aiohttp.web.json_response({'err': None})

async def ws_handler(request):
return web.Response(status=401)

app = web.Application()
app.router.add_route('GET', '/ws', ws_handler)
app.router.add_route('GET', '/api/null', handle_null)

client = await aiohttp_client(app)

resp = await client.get('/api/null')
assert (await resp.json()) == {'err': None}
resp.close()

with pytest.raises(aiohttp.WSServerHandshakeError):
await client.ws_connect('/ws')

resp = await client.get('/api/null', timeout=1)
assert (await resp.json()) == {'err': None}
resp.close()

0 comments on commit 29eccad

Please sign in to comment.