Skip to content

Commit

Permalink
Use a new loop in run_app() (#5572)
Browse files Browse the repository at this point in the history
Co-authored-by: Sviatoslav Sydorenko <[email protected]>
  • Loading branch information
Dreamsorcerer and webknjaz authored Jun 21, 2021
1 parent 01ba0dd commit 88d1f80
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 47 deletions.
2 changes: 2 additions & 0 deletions CHANGES/5572.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Always create a new event loop in ``aiohttp.web.run_app()``.
This adds better compatibility with ``asyncio.run()`` or if trying to run multiple apps in sequence.
46 changes: 25 additions & 21 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,9 +477,11 @@ def run_app(
handle_signals: bool = True,
reuse_address: Optional[bool] = None,
reuse_port: Optional[bool] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
"""Run an app locally"""
loop = asyncio.get_event_loop()
if loop is None:
loop = asyncio.new_event_loop()
loop.set_debug(debug)

# Configure if and only if in debugging mode and using the default logger
Expand All @@ -489,27 +491,29 @@ def run_app(
if not access_log.hasHandlers():
access_log.addHandler(logging.StreamHandler())

try:
main_task = loop.create_task(
_run_app(
app,
host=host,
port=port,
path=path,
sock=sock,
shutdown_timeout=shutdown_timeout,
keepalive_timeout=keepalive_timeout,
ssl_context=ssl_context,
print=print,
backlog=backlog,
access_log_class=access_log_class,
access_log_format=access_log_format,
access_log=access_log,
handle_signals=handle_signals,
reuse_address=reuse_address,
reuse_port=reuse_port,
)
main_task = loop.create_task(
_run_app(
app,
host=host,
port=port,
path=path,
sock=sock,
shutdown_timeout=shutdown_timeout,
keepalive_timeout=keepalive_timeout,
ssl_context=ssl_context,
print=print,
backlog=backlog,
access_log_class=access_log_class,
access_log_format=access_log_format,
access_log=access_log,
handle_signals=handle_signals,
reuse_address=reuse_address,
reuse_port=reuse_port,
)
)

try:
asyncio.set_event_loop(loop)
loop.run_until_complete(main_task)
except (GracefulExit, KeyboardInterrupt): # pragma: no cover
pass
Expand Down
96 changes: 70 additions & 26 deletions tests/test_run_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_run_app_http(patched_loop: Any) -> None:
cleanup_handler = make_mocked_coro()
app.on_cleanup.append(cleanup_handler)

web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None
Expand All @@ -105,7 +105,7 @@ def test_run_app_http(patched_loop: Any) -> None:

def test_run_app_close_loop(patched_loop: Any) -> None:
app = web.Application()
web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None
Expand Down Expand Up @@ -428,7 +428,7 @@ def test_run_app_mixed_bindings(
patched_loop: Any,
) -> None:
app = web.Application()
web.run_app(app, print=stopper(patched_loop), **run_app_kwargs)
web.run_app(app, print=stopper(patched_loop), **run_app_kwargs, loop=patched_loop)

assert patched_loop.create_unix_server.mock_calls == expected_unix_server_calls
assert patched_loop.create_server.mock_calls == expected_server_calls
Expand All @@ -438,7 +438,9 @@ def test_run_app_https(patched_loop: Any) -> None:
app = web.Application()

ssl_context = ssl.create_default_context()
web.run_app(app, ssl_context=ssl_context, print=stopper(patched_loop))
web.run_app(
app, ssl_context=ssl_context, print=stopper(patched_loop), loop=patched_loop
)

patched_loop.create_server.assert_called_with(
mock.ANY,
Expand All @@ -458,7 +460,9 @@ def test_run_app_nondefault_host_port(
host = "127.0.0.1"

app = web.Application()
web.run_app(app, host=host, port=port, print=stopper(patched_loop))
web.run_app(
app, host=host, port=port, print=stopper(patched_loop), loop=patched_loop
)

patched_loop.create_server.assert_called_with(
mock.ANY, host, port, ssl=None, backlog=128, reuse_address=None, reuse_port=None
Expand All @@ -469,7 +473,7 @@ def test_run_app_multiple_hosts(patched_loop: Any) -> None:
hosts = ("127.0.0.1", "127.0.0.2")

app = web.Application()
web.run_app(app, host=hosts, print=stopper(patched_loop))
web.run_app(app, host=hosts, print=stopper(patched_loop), loop=patched_loop)

calls = map(
lambda h: mock.call(
Expand All @@ -488,7 +492,7 @@ def test_run_app_multiple_hosts(patched_loop: Any) -> None:

def test_run_app_custom_backlog(patched_loop: Any) -> None:
app = web.Application()
web.run_app(app, backlog=10, print=stopper(patched_loop))
web.run_app(app, backlog=10, print=stopper(patched_loop), loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, None, 8080, ssl=None, backlog=10, reuse_address=None, reuse_port=None
Expand All @@ -497,7 +501,13 @@ def test_run_app_custom_backlog(patched_loop: Any) -> None:

def test_run_app_custom_backlog_unix(patched_loop: Any) -> None:
app = web.Application()
web.run_app(app, path="/tmp/tmpsock.sock", backlog=10, print=stopper(patched_loop))
web.run_app(
app,
path="/tmp/tmpsock.sock",
backlog=10,
print=stopper(patched_loop),
loop=patched_loop,
)

patched_loop.create_unix_server.assert_called_with(
mock.ANY, "/tmp/tmpsock.sock", ssl=None, backlog=10
Expand All @@ -510,7 +520,7 @@ def test_run_app_http_unix_socket(patched_loop: Any, tmp_path: Any) -> None:

sock_path = str(tmp_path / "socket.sock")
printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, path=sock_path, print=printer)
web.run_app(app, path=sock_path, print=printer, loop=patched_loop)

patched_loop.create_unix_server.assert_called_with(
mock.ANY, sock_path, ssl=None, backlog=128
Expand All @@ -525,7 +535,9 @@ def test_run_app_https_unix_socket(patched_loop: Any, tmp_path: Any) -> None:
sock_path = str(tmp_path / "socket.sock")
ssl_context = ssl.create_default_context()
printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, path=sock_path, ssl_context=ssl_context, print=printer)
web.run_app(
app, path=sock_path, ssl_context=ssl_context, print=printer, loop=patched_loop
)

patched_loop.create_unix_server.assert_called_with(
mock.ANY, sock_path, ssl=ssl_context, backlog=128
Expand All @@ -539,7 +551,10 @@ def test_run_app_abstract_linux_socket(patched_loop: Any) -> None:
sock_path = b"\x00" + uuid4().hex.encode("ascii")
app = web.Application()
web.run_app(
app, path=sock_path.decode("ascii", "ignore"), print=stopper(patched_loop)
app,
path=sock_path.decode("ascii", "ignore"),
print=stopper(patched_loop),
loop=patched_loop,
)

patched_loop.create_unix_server.assert_called_with(
Expand All @@ -556,7 +571,7 @@ def test_run_app_preexisting_inet_socket(patched_loop: Any, mocker: Any) -> None
_, port = sock.getsockname()

printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, sock=sock, print=printer)
web.run_app(app, sock=sock, print=printer, loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, sock=sock, backlog=128, ssl=None
Expand All @@ -574,7 +589,7 @@ def test_run_app_preexisting_inet6_socket(patched_loop: Any) -> None:
port = sock.getsockname()[1]

printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, sock=sock, print=printer)
web.run_app(app, sock=sock, print=printer, loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, sock=sock, backlog=128, ssl=None
Expand All @@ -593,7 +608,7 @@ def test_run_app_preexisting_unix_socket(patched_loop: Any, mocker: Any) -> None
os.unlink(sock_path)

printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, sock=sock, print=printer)
web.run_app(app, sock=sock, print=printer, loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, sock=sock, backlog=128, ssl=None
Expand All @@ -613,7 +628,7 @@ def test_run_app_multiple_preexisting_sockets(patched_loop: Any) -> None:
_, port2 = sock2.getsockname()

printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, sock=(sock1, sock2), print=printer)
web.run_app(app, sock=(sock1, sock2), print=printer, loop=patched_loop)

patched_loop.create_server.assert_has_calls(
[
Expand Down Expand Up @@ -671,7 +686,7 @@ def test_startup_cleanup_signals_even_on_failure(patched_loop: Any) -> None:
app.on_cleanup.append(cleanup_handler)

with pytest.raises(RuntimeError):
web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)

startup_handler.assert_called_once_with(app)
cleanup_handler.assert_called_once_with(app)
Expand All @@ -689,7 +704,7 @@ async def make_app():
app.on_cleanup.append(cleanup_handler)
return app

web.run_app(make_app(), print=stopper(patched_loop))
web.run_app(make_app(), print=stopper(patched_loop), loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None
Expand All @@ -709,7 +724,13 @@ def test_run_app_default_logger(monkeypatch: Any, patched_loop: Any) -> None:
mock_logger.configure_mock(**attrs)

app = web.Application()
web.run_app(app, debug=True, print=stopper(patched_loop), access_log=mock_logger)
web.run_app(
app,
debug=True,
print=stopper(patched_loop),
access_log=mock_logger,
loop=patched_loop,
)
mock_logger.setLevel.assert_any_call(logging.DEBUG)
mock_logger.hasHandlers.assert_called_with()
assert isinstance(mock_logger.addHandler.call_args[0][0], logging.StreamHandler)
Expand All @@ -726,7 +747,13 @@ def test_run_app_default_logger_setup_requires_debug(patched_loop: Any) -> None:
mock_logger.configure_mock(**attrs)

app = web.Application()
web.run_app(app, debug=False, print=stopper(patched_loop), access_log=mock_logger)
web.run_app(
app,
debug=False,
print=stopper(patched_loop),
access_log=mock_logger,
loop=patched_loop,
)
mock_logger.setLevel.assert_not_called()
mock_logger.hasHandlers.assert_not_called()
mock_logger.addHandler.assert_not_called()
Expand All @@ -745,7 +772,13 @@ def test_run_app_default_logger_setup_requires_default_logger(
mock_logger.configure_mock(**attrs)

app = web.Application()
web.run_app(app, debug=True, print=stopper(patched_loop), access_log=mock_logger)
web.run_app(
app,
debug=True,
print=stopper(patched_loop),
access_log=mock_logger,
loop=patched_loop,
)
mock_logger.setLevel.assert_not_called()
mock_logger.hasHandlers.assert_not_called()
mock_logger.addHandler.assert_not_called()
Expand All @@ -762,7 +795,13 @@ def test_run_app_default_logger_setup_only_if_unconfigured(patched_loop: Any) ->
mock_logger.configure_mock(**attrs)

app = web.Application()
web.run_app(app, debug=True, print=stopper(patched_loop), access_log=mock_logger)
web.run_app(
app,
debug=True,
print=stopper(patched_loop),
access_log=mock_logger,
loop=patched_loop,
)
mock_logger.setLevel.assert_not_called()
mock_logger.hasHandlers.assert_called_with()
mock_logger.addHandler.assert_not_called()
Expand All @@ -779,7 +818,7 @@ async def on_startup(app):

app.on_startup.append(on_startup)

web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)
assert task.cancelled()


Expand All @@ -797,7 +836,7 @@ async def on_startup(app):

app.on_startup.append(on_startup)

web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)
assert task.done()


Expand All @@ -823,7 +862,7 @@ async def on_startup(app):

exc_handler = mock.Mock()
patched_loop.set_exception_handler(exc_handler)
web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)
assert task.done()

msg = {
Expand All @@ -846,7 +885,12 @@ def base_runner_init_spy(self, *args, **kwargs):

app = web.Application()
monkeypatch.setattr(BaseRunner, "__init__", base_runner_init_spy)
web.run_app(app, keepalive_timeout=new_timeout, print=stopper(patched_loop))
web.run_app(
app,
keepalive_timeout=new_timeout,
print=stopper(patched_loop),
loop=patched_loop,
)


def test_run_app_context_vars(patched_loop: Any):
Expand Down Expand Up @@ -877,5 +921,5 @@ async def init():
count += 1
return app

web.run_app(init(), print=stopper(patched_loop))
web.run_app(init(), print=stopper(patched_loop), loop=patched_loop)
assert count == 3
24 changes: 24 additions & 0 deletions tests/test_web_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,27 @@ async def mock_create_server(*args, **kwargs):
assert server is runner.server
assert host is None
assert port == 8080


def test_run_after_asyncio_run() -> None:
async def nothing():
pass

def spy():
spy.called = True

spy.called = False

async def shutdown():
spy()
raise web.GracefulExit()

# asyncio.run() creates a new loop and closes it.
asyncio.run(nothing())

app = web.Application()
# create_task() will delay the function until app is run.
app.on_startup.append(lambda a: asyncio.create_task(shutdown()))

web.run_app(app)
assert spy.called, "run_app() should work after asyncio.run()."

0 comments on commit 88d1f80

Please sign in to comment.