Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a new loop in run_app() #5572

Merged
merged 18 commits into from
Jun 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES/5572.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Always create a new event loop in ``aiohttp.web.run_app()``.
This adds better compatibility with ``asyncio.run()`` or if trying to run multiple apps in sequence.
46 changes: 25 additions & 21 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,9 +477,11 @@ def run_app(
handle_signals: bool = True,
reuse_address: Optional[bool] = None,
reuse_port: Optional[bool] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
"""Run an app locally"""
loop = asyncio.get_event_loop()
if loop is None:
loop = asyncio.new_event_loop()
loop.set_debug(debug)

# Configure if and only if in debugging mode and using the default logger
Expand All @@ -489,27 +491,29 @@ def run_app(
if not access_log.hasHandlers():
access_log.addHandler(logging.StreamHandler())

try:
main_task = loop.create_task(
_run_app(
app,
host=host,
port=port,
path=path,
sock=sock,
shutdown_timeout=shutdown_timeout,
keepalive_timeout=keepalive_timeout,
ssl_context=ssl_context,
print=print,
backlog=backlog,
access_log_class=access_log_class,
access_log_format=access_log_format,
access_log=access_log,
handle_signals=handle_signals,
reuse_address=reuse_address,
reuse_port=reuse_port,
)
main_task = loop.create_task(
_run_app(
app,
host=host,
port=port,
path=path,
sock=sock,
shutdown_timeout=shutdown_timeout,
keepalive_timeout=keepalive_timeout,
ssl_context=ssl_context,
print=print,
backlog=backlog,
access_log_class=access_log_class,
access_log_format=access_log_format,
access_log=access_log,
handle_signals=handle_signals,
reuse_address=reuse_address,
reuse_port=reuse_port,
)
)

try:
asyncio.set_event_loop(loop)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's a good idea to have it here. At least, put it right before try:. The reason is that it's an antipattern to have more than one instruction in a try block (it tends to shadow exceptions sometimes and causes hard-to-debug situations).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I just noticed that it was inside the try block in asyncio.run(), figured there was probably a good reason for it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, you could add a separate try/except if you want to call loop.close() and re-raise. There's no need to asyncio.set_event_loop(None) because setting it didn't happen anyway.

Copy link
Member Author

@Dreamsorcerer Dreamsorcerer Jun 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've kept these 2 lines together in the try to match https://github.com/python/cpython/blob/main/Lib/asyncio/runners.py#L40-L44

But, I've moved the create_task() to before the try statement, as it doesn't seem to have any business being there.

loop.run_until_complete(main_task)
except (GracefulExit, KeyboardInterrupt): # pragma: no cover
pass
Expand Down
96 changes: 70 additions & 26 deletions tests/test_run_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_run_app_http(patched_loop: Any) -> None:
cleanup_handler = make_mocked_coro()
app.on_cleanup.append(cleanup_handler)

web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None
Expand All @@ -105,7 +105,7 @@ def test_run_app_http(patched_loop: Any) -> None:

def test_run_app_close_loop(patched_loop: Any) -> None:
app = web.Application()
web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None
Expand Down Expand Up @@ -428,7 +428,7 @@ def test_run_app_mixed_bindings(
patched_loop: Any,
) -> None:
app = web.Application()
web.run_app(app, print=stopper(patched_loop), **run_app_kwargs)
web.run_app(app, print=stopper(patched_loop), **run_app_kwargs, loop=patched_loop)

assert patched_loop.create_unix_server.mock_calls == expected_unix_server_calls
assert patched_loop.create_server.mock_calls == expected_server_calls
Expand All @@ -438,7 +438,9 @@ def test_run_app_https(patched_loop: Any) -> None:
app = web.Application()

ssl_context = ssl.create_default_context()
web.run_app(app, ssl_context=ssl_context, print=stopper(patched_loop))
web.run_app(
app, ssl_context=ssl_context, print=stopper(patched_loop), loop=patched_loop
)

patched_loop.create_server.assert_called_with(
mock.ANY,
Expand All @@ -458,7 +460,9 @@ def test_run_app_nondefault_host_port(
host = "127.0.0.1"

app = web.Application()
web.run_app(app, host=host, port=port, print=stopper(patched_loop))
web.run_app(
app, host=host, port=port, print=stopper(patched_loop), loop=patched_loop
)

patched_loop.create_server.assert_called_with(
mock.ANY, host, port, ssl=None, backlog=128, reuse_address=None, reuse_port=None
Expand All @@ -469,7 +473,7 @@ def test_run_app_multiple_hosts(patched_loop: Any) -> None:
hosts = ("127.0.0.1", "127.0.0.2")

app = web.Application()
web.run_app(app, host=hosts, print=stopper(patched_loop))
web.run_app(app, host=hosts, print=stopper(patched_loop), loop=patched_loop)

calls = map(
lambda h: mock.call(
Expand All @@ -488,7 +492,7 @@ def test_run_app_multiple_hosts(patched_loop: Any) -> None:

def test_run_app_custom_backlog(patched_loop: Any) -> None:
app = web.Application()
web.run_app(app, backlog=10, print=stopper(patched_loop))
web.run_app(app, backlog=10, print=stopper(patched_loop), loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, None, 8080, ssl=None, backlog=10, reuse_address=None, reuse_port=None
Expand All @@ -497,7 +501,13 @@ def test_run_app_custom_backlog(patched_loop: Any) -> None:

def test_run_app_custom_backlog_unix(patched_loop: Any) -> None:
app = web.Application()
web.run_app(app, path="/tmp/tmpsock.sock", backlog=10, print=stopper(patched_loop))
web.run_app(
app,
path="/tmp/tmpsock.sock",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Side note: ouch, this looks awfully hardcoded. It's probably a good idea to fix such hardcoded socket paths across tests (in a dedicated PR, of course).

backlog=10,
print=stopper(patched_loop),
loop=patched_loop,
)

patched_loop.create_unix_server.assert_called_with(
mock.ANY, "/tmp/tmpsock.sock", ssl=None, backlog=10
Expand All @@ -510,7 +520,7 @@ def test_run_app_http_unix_socket(patched_loop: Any, tmp_path: Any) -> None:

sock_path = str(tmp_path / "socket.sock")
printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, path=sock_path, print=printer)
web.run_app(app, path=sock_path, print=printer, loop=patched_loop)

patched_loop.create_unix_server.assert_called_with(
mock.ANY, sock_path, ssl=None, backlog=128
Expand All @@ -525,7 +535,9 @@ def test_run_app_https_unix_socket(patched_loop: Any, tmp_path: Any) -> None:
sock_path = str(tmp_path / "socket.sock")
ssl_context = ssl.create_default_context()
printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, path=sock_path, ssl_context=ssl_context, print=printer)
web.run_app(
app, path=sock_path, ssl_context=ssl_context, print=printer, loop=patched_loop
)

patched_loop.create_unix_server.assert_called_with(
mock.ANY, sock_path, ssl=ssl_context, backlog=128
Expand All @@ -539,7 +551,10 @@ def test_run_app_abstract_linux_socket(patched_loop: Any) -> None:
sock_path = b"\x00" + uuid4().hex.encode("ascii")
app = web.Application()
web.run_app(
app, path=sock_path.decode("ascii", "ignore"), print=stopper(patched_loop)
app,
path=sock_path.decode("ascii", "ignore"),
print=stopper(patched_loop),
loop=patched_loop,
)

patched_loop.create_unix_server.assert_called_with(
Expand All @@ -556,7 +571,7 @@ def test_run_app_preexisting_inet_socket(patched_loop: Any, mocker: Any) -> None
_, port = sock.getsockname()

printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, sock=sock, print=printer)
web.run_app(app, sock=sock, print=printer, loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, sock=sock, backlog=128, ssl=None
Expand All @@ -574,7 +589,7 @@ def test_run_app_preexisting_inet6_socket(patched_loop: Any) -> None:
port = sock.getsockname()[1]

printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, sock=sock, print=printer)
web.run_app(app, sock=sock, print=printer, loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, sock=sock, backlog=128, ssl=None
Expand All @@ -593,7 +608,7 @@ def test_run_app_preexisting_unix_socket(patched_loop: Any, mocker: Any) -> None
os.unlink(sock_path)

printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, sock=sock, print=printer)
web.run_app(app, sock=sock, print=printer, loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, sock=sock, backlog=128, ssl=None
Expand All @@ -613,7 +628,7 @@ def test_run_app_multiple_preexisting_sockets(patched_loop: Any) -> None:
_, port2 = sock2.getsockname()

printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, sock=(sock1, sock2), print=printer)
web.run_app(app, sock=(sock1, sock2), print=printer, loop=patched_loop)

patched_loop.create_server.assert_has_calls(
[
Expand Down Expand Up @@ -671,7 +686,7 @@ def test_startup_cleanup_signals_even_on_failure(patched_loop: Any) -> None:
app.on_cleanup.append(cleanup_handler)

with pytest.raises(RuntimeError):
web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)

startup_handler.assert_called_once_with(app)
cleanup_handler.assert_called_once_with(app)
Expand All @@ -689,7 +704,7 @@ async def make_app():
app.on_cleanup.append(cleanup_handler)
return app

web.run_app(make_app(), print=stopper(patched_loop))
web.run_app(make_app(), print=stopper(patched_loop), loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None
Expand All @@ -709,7 +724,13 @@ def test_run_app_default_logger(monkeypatch: Any, patched_loop: Any) -> None:
mock_logger.configure_mock(**attrs)

app = web.Application()
web.run_app(app, debug=True, print=stopper(patched_loop), access_log=mock_logger)
web.run_app(
app,
debug=True,
print=stopper(patched_loop),
access_log=mock_logger,
loop=patched_loop,
)
mock_logger.setLevel.assert_any_call(logging.DEBUG)
mock_logger.hasHandlers.assert_called_with()
assert isinstance(mock_logger.addHandler.call_args[0][0], logging.StreamHandler)
Expand All @@ -726,7 +747,13 @@ def test_run_app_default_logger_setup_requires_debug(patched_loop: Any) -> None:
mock_logger.configure_mock(**attrs)

app = web.Application()
web.run_app(app, debug=False, print=stopper(patched_loop), access_log=mock_logger)
web.run_app(
app,
debug=False,
print=stopper(patched_loop),
access_log=mock_logger,
loop=patched_loop,
)
mock_logger.setLevel.assert_not_called()
mock_logger.hasHandlers.assert_not_called()
mock_logger.addHandler.assert_not_called()
Expand All @@ -745,7 +772,13 @@ def test_run_app_default_logger_setup_requires_default_logger(
mock_logger.configure_mock(**attrs)

app = web.Application()
web.run_app(app, debug=True, print=stopper(patched_loop), access_log=mock_logger)
web.run_app(
app,
debug=True,
print=stopper(patched_loop),
access_log=mock_logger,
loop=patched_loop,
)
mock_logger.setLevel.assert_not_called()
mock_logger.hasHandlers.assert_not_called()
mock_logger.addHandler.assert_not_called()
Expand All @@ -762,7 +795,13 @@ def test_run_app_default_logger_setup_only_if_unconfigured(patched_loop: Any) ->
mock_logger.configure_mock(**attrs)

app = web.Application()
web.run_app(app, debug=True, print=stopper(patched_loop), access_log=mock_logger)
web.run_app(
app,
debug=True,
print=stopper(patched_loop),
access_log=mock_logger,
loop=patched_loop,
)
mock_logger.setLevel.assert_not_called()
mock_logger.hasHandlers.assert_called_with()
mock_logger.addHandler.assert_not_called()
Expand All @@ -779,7 +818,7 @@ async def on_startup(app):

app.on_startup.append(on_startup)

web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)
assert task.cancelled()


Expand All @@ -797,7 +836,7 @@ async def on_startup(app):

app.on_startup.append(on_startup)

web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)
assert task.done()


Expand All @@ -823,7 +862,7 @@ async def on_startup(app):

exc_handler = mock.Mock()
patched_loop.set_exception_handler(exc_handler)
web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)
assert task.done()

msg = {
Expand All @@ -846,7 +885,12 @@ def base_runner_init_spy(self, *args, **kwargs):

app = web.Application()
monkeypatch.setattr(BaseRunner, "__init__", base_runner_init_spy)
web.run_app(app, keepalive_timeout=new_timeout, print=stopper(patched_loop))
web.run_app(
app,
keepalive_timeout=new_timeout,
print=stopper(patched_loop),
loop=patched_loop,
)


def test_run_app_context_vars(patched_loop: Any):
Expand Down Expand Up @@ -877,5 +921,5 @@ async def init():
count += 1
return app

web.run_app(init(), print=stopper(patched_loop))
web.run_app(init(), print=stopper(patched_loop), loop=patched_loop)
assert count == 3
24 changes: 24 additions & 0 deletions tests/test_web_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,27 @@ async def mock_create_server(*args, **kwargs):
assert server is runner.server
assert host is None
assert port == 8080


def test_run_after_asyncio_run() -> None:
async def nothing():
pass

def spy():
spy.called = True

spy.called = False

async def shutdown():
spy()
raise web.GracefulExit()

# asyncio.run() creates a new loop and closes it.
asyncio.run(nothing())

app = web.Application()
# create_task() will delay the function until app is run.
app.on_startup.append(lambda a: asyncio.create_task(shutdown()))

web.run_app(app)
assert spy.called, "run_app() should work after asyncio.run()."