Skip to content

Commit

Permalink
Backport #5572: Use new loop for web.run_app(). (#5820)
Browse files Browse the repository at this point in the history
* Use new loop for web.run_app().

* Skip test on 3.6
  • Loading branch information
Dreamsorcerer authored Jun 21, 2021
1 parent 8b88c3d commit fa46667
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 47 deletions.
2 changes: 2 additions & 0 deletions CHANGES/5572.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Always create a new event loop in ``aiohttp.web.run_app()``.
This adds better compatibility with ``asyncio.run()`` or if trying to run multiple apps in sequence.
46 changes: 25 additions & 21 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,9 +477,11 @@ def run_app(
handle_signals: bool = True,
reuse_address: Optional[bool] = None,
reuse_port: Optional[bool] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
"""Run an app locally"""
loop = asyncio.get_event_loop()
if loop is None:
loop = asyncio.new_event_loop()

# Configure if and only if in debugging mode and using the default logger
if loop.get_debug() and access_log and access_log.name == "aiohttp.access":
Expand All @@ -488,27 +490,29 @@ def run_app(
if not access_log.hasHandlers():
access_log.addHandler(logging.StreamHandler())

try:
main_task = loop.create_task(
_run_app(
app,
host=host,
port=port,
path=path,
sock=sock,
shutdown_timeout=shutdown_timeout,
keepalive_timeout=keepalive_timeout,
ssl_context=ssl_context,
print=print,
backlog=backlog,
access_log_class=access_log_class,
access_log_format=access_log_format,
access_log=access_log,
handle_signals=handle_signals,
reuse_address=reuse_address,
reuse_port=reuse_port,
)
main_task = loop.create_task(
_run_app(
app,
host=host,
port=port,
path=path,
sock=sock,
shutdown_timeout=shutdown_timeout,
keepalive_timeout=keepalive_timeout,
ssl_context=ssl_context,
print=print,
backlog=backlog,
access_log_class=access_log_class,
access_log_format=access_log_format,
access_log=access_log,
handle_signals=handle_signals,
reuse_address=reuse_address,
reuse_port=reuse_port,
)
)

try:
asyncio.set_event_loop(loop)
loop.run_until_complete(main_task)
except (GracefulExit, KeyboardInterrupt): # pragma: no cover
pass
Expand Down
92 changes: 66 additions & 26 deletions tests/test_run_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_run_app_http(patched_loop) -> None:
cleanup_handler = make_mocked_coro()
app.on_cleanup.append(cleanup_handler)

web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None
Expand All @@ -105,7 +105,7 @@ def test_run_app_http(patched_loop) -> None:

def test_run_app_close_loop(patched_loop) -> None:
app = web.Application()
web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None
Expand Down Expand Up @@ -425,7 +425,7 @@ def test_run_app_mixed_bindings(
run_app_kwargs, expected_server_calls, expected_unix_server_calls, patched_loop
):
app = web.Application()
web.run_app(app, print=stopper(patched_loop), **run_app_kwargs)
web.run_app(app, print=stopper(patched_loop), **run_app_kwargs, loop=patched_loop)

assert patched_loop.create_unix_server.mock_calls == expected_unix_server_calls
assert patched_loop.create_server.mock_calls == expected_server_calls
Expand All @@ -435,7 +435,9 @@ def test_run_app_https(patched_loop) -> None:
app = web.Application()

ssl_context = ssl.create_default_context()
web.run_app(app, ssl_context=ssl_context, print=stopper(patched_loop))
web.run_app(
app, ssl_context=ssl_context, print=stopper(patched_loop), loop=patched_loop
)

patched_loop.create_server.assert_called_with(
mock.ANY,
Expand All @@ -453,7 +455,9 @@ def test_run_app_nondefault_host_port(patched_loop, aiohttp_unused_port) -> None
host = "127.0.0.1"

app = web.Application()
web.run_app(app, host=host, port=port, print=stopper(patched_loop))
web.run_app(
app, host=host, port=port, print=stopper(patched_loop), loop=patched_loop
)

patched_loop.create_server.assert_called_with(
mock.ANY, host, port, ssl=None, backlog=128, reuse_address=None, reuse_port=None
Expand All @@ -464,7 +468,7 @@ def test_run_app_multiple_hosts(patched_loop) -> None:
hosts = ("127.0.0.1", "127.0.0.2")

app = web.Application()
web.run_app(app, host=hosts, print=stopper(patched_loop))
web.run_app(app, host=hosts, print=stopper(patched_loop), loop=patched_loop)

calls = map(
lambda h: mock.call(
Expand All @@ -483,7 +487,7 @@ def test_run_app_multiple_hosts(patched_loop) -> None:

def test_run_app_custom_backlog(patched_loop) -> None:
app = web.Application()
web.run_app(app, backlog=10, print=stopper(patched_loop))
web.run_app(app, backlog=10, print=stopper(patched_loop), loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, None, 8080, ssl=None, backlog=10, reuse_address=None, reuse_port=None
Expand All @@ -492,7 +496,13 @@ def test_run_app_custom_backlog(patched_loop) -> None:

def test_run_app_custom_backlog_unix(patched_loop) -> None:
app = web.Application()
web.run_app(app, path="/tmp/tmpsock.sock", backlog=10, print=stopper(patched_loop))
web.run_app(
app,
path="/tmp/tmpsock.sock",
backlog=10,
print=stopper(patched_loop),
loop=patched_loop,
)

patched_loop.create_unix_server.assert_called_with(
mock.ANY, "/tmp/tmpsock.sock", ssl=None, backlog=10
Expand All @@ -505,7 +515,7 @@ def test_run_app_http_unix_socket(patched_loop, shorttmpdir) -> None:

sock_path = str(shorttmpdir / "socket.sock")
printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, path=sock_path, print=printer)
web.run_app(app, path=sock_path, print=printer, loop=patched_loop)

patched_loop.create_unix_server.assert_called_with(
mock.ANY, sock_path, ssl=None, backlog=128
Expand All @@ -520,7 +530,9 @@ def test_run_app_https_unix_socket(patched_loop, shorttmpdir) -> None:
sock_path = str(shorttmpdir / "socket.sock")
ssl_context = ssl.create_default_context()
printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, path=sock_path, ssl_context=ssl_context, print=printer)
web.run_app(
app, path=sock_path, ssl_context=ssl_context, print=printer, loop=patched_loop
)

patched_loop.create_unix_server.assert_called_with(
mock.ANY, sock_path, ssl=ssl_context, backlog=128
Expand All @@ -534,7 +546,10 @@ def test_run_app_abstract_linux_socket(patched_loop) -> None:
sock_path = b"\x00" + uuid4().hex.encode("ascii")
app = web.Application()
web.run_app(
app, path=sock_path.decode("ascii", "ignore"), print=stopper(patched_loop)
app,
path=sock_path.decode("ascii", "ignore"),
print=stopper(patched_loop),
loop=patched_loop,
)

patched_loop.create_unix_server.assert_called_with(
Expand All @@ -551,7 +566,7 @@ def test_run_app_preexisting_inet_socket(patched_loop, mocker) -> None:
_, port = sock.getsockname()

printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, sock=sock, print=printer)
web.run_app(app, sock=sock, print=printer, loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, sock=sock, backlog=128, ssl=None
Expand All @@ -569,7 +584,7 @@ def test_run_app_preexisting_inet6_socket(patched_loop) -> None:
port = sock.getsockname()[1]

printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, sock=sock, print=printer)
web.run_app(app, sock=sock, print=printer, loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, sock=sock, backlog=128, ssl=None
Expand All @@ -588,7 +603,7 @@ def test_run_app_preexisting_unix_socket(patched_loop, mocker) -> None:
os.unlink(sock_path)

printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, sock=sock, print=printer)
web.run_app(app, sock=sock, print=printer, loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, sock=sock, backlog=128, ssl=None
Expand All @@ -608,7 +623,7 @@ def test_run_app_multiple_preexisting_sockets(patched_loop) -> None:
_, port2 = sock2.getsockname()

printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, sock=(sock1, sock2), print=printer)
web.run_app(app, sock=(sock1, sock2), print=printer, loop=patched_loop)

patched_loop.create_server.assert_has_calls(
[
Expand Down Expand Up @@ -664,7 +679,7 @@ def test_startup_cleanup_signals_even_on_failure(patched_loop) -> None:
app.on_cleanup.append(cleanup_handler)

with pytest.raises(RuntimeError):
web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)

startup_handler.assert_called_once_with(app)
cleanup_handler.assert_called_once_with(app)
Expand All @@ -682,7 +697,7 @@ async def make_app():
app.on_cleanup.append(cleanup_handler)
return app

web.run_app(make_app(), print=stopper(patched_loop))
web.run_app(make_app(), print=stopper(patched_loop), loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None
Expand All @@ -703,7 +718,12 @@ def test_run_app_default_logger(monkeypatch, patched_loop):
mock_logger.configure_mock(**attrs)

app = web.Application()
web.run_app(app, print=stopper(patched_loop), access_log=mock_logger)
web.run_app(
app,
print=stopper(patched_loop),
access_log=mock_logger,
loop=patched_loop,
)
mock_logger.setLevel.assert_any_call(logging.DEBUG)
mock_logger.hasHandlers.assert_called_with()
assert isinstance(mock_logger.addHandler.call_args[0][0], logging.StreamHandler)
Expand All @@ -721,7 +741,12 @@ def test_run_app_default_logger_setup_requires_debug(patched_loop):
mock_logger.configure_mock(**attrs)

app = web.Application()
web.run_app(app, print=stopper(patched_loop), access_log=mock_logger)
web.run_app(
app,
print=stopper(patched_loop),
access_log=mock_logger,
loop=patched_loop,
)
mock_logger.setLevel.assert_not_called()
mock_logger.hasHandlers.assert_not_called()
mock_logger.addHandler.assert_not_called()
Expand All @@ -739,7 +764,12 @@ def test_run_app_default_logger_setup_requires_default_logger(patched_loop):
mock_logger.configure_mock(**attrs)

app = web.Application()
web.run_app(app, print=stopper(patched_loop), access_log=mock_logger)
web.run_app(
app,
print=stopper(patched_loop),
access_log=mock_logger,
loop=patched_loop,
)
mock_logger.setLevel.assert_not_called()
mock_logger.hasHandlers.assert_not_called()
mock_logger.addHandler.assert_not_called()
Expand All @@ -757,7 +787,12 @@ def test_run_app_default_logger_setup_only_if_unconfigured(patched_loop):
mock_logger.configure_mock(**attrs)

app = web.Application()
web.run_app(app, print=stopper(patched_loop), access_log=mock_logger)
web.run_app(
app,
print=stopper(patched_loop),
access_log=mock_logger,
loop=patched_loop,
)
mock_logger.setLevel.assert_not_called()
mock_logger.hasHandlers.assert_called_with()
mock_logger.addHandler.assert_not_called()
Expand All @@ -774,7 +809,7 @@ async def on_startup(app):

app.on_startup.append(on_startup)

web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)
assert task.cancelled()


Expand All @@ -792,7 +827,7 @@ async def on_startup(app):

app.on_startup.append(on_startup)

web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)
assert task.done()


Expand All @@ -818,7 +853,7 @@ async def on_startup(app):

exc_handler = mock.Mock()
patched_loop.set_exception_handler(exc_handler)
web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)
assert task.done()

msg = {
Expand All @@ -839,7 +874,12 @@ def base_runner_init_spy(self, *args, **kwargs):

app = web.Application()
monkeypatch.setattr(BaseRunner, "__init__", base_runner_init_spy)
web.run_app(app, keepalive_timeout=new_timeout, print=stopper(patched_loop))
web.run_app(
app,
keepalive_timeout=new_timeout,
print=stopper(patched_loop),
loop=patched_loop,
)


@pytest.mark.skipif(not PY_37, reason="contextvars support is required")
Expand Down Expand Up @@ -871,5 +911,5 @@ async def init():
count += 1
return app

web.run_app(init(), print=stopper(patched_loop))
web.run_app(init(), print=stopper(patched_loop), loop=patched_loop)
assert count == 3
26 changes: 26 additions & 0 deletions tests/test_web_runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import platform
import signal
import sys
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -162,3 +163,28 @@ async def mock_create_server(*args, **kwargs):
assert server is runner.server
assert host is None
assert port == 8080


@pytest.mark.skipif(sys.version_info < (3, 7), reason="Requires asyncio.run()")
def test_run_after_asyncio_run() -> None:
async def nothing():
pass

def spy():
spy.called = True

spy.called = False

async def shutdown():
spy()
raise web.GracefulExit()

# asyncio.run() creates a new loop and closes it.
asyncio.run(nothing())

app = web.Application()
# create_task() will delay the function until app is run.
app.on_startup.append(lambda a: asyncio.create_task(shutdown()))

web.run_app(app)
assert spy.called, "run_app() should work after asyncio.run()."

0 comments on commit fa46667

Please sign in to comment.