diff --git a/examples/telnet/hello-world.py b/examples/telnet/hello-world.py index f73533356..68b6918af 100755 --- a/examples/telnet/hello-world.py +++ b/examples/telnet/hello-world.py @@ -32,10 +32,7 @@ async def interact(connection): async def main(): server = TelnetServer(interact=interact, port=2323) - server.start() - - # Run forever. - await Future() + await server.run() if __name__ == "__main__": diff --git a/examples/telnet/toolbar.py b/examples/telnet/toolbar.py index d8d1ed59a..d73a4db11 100755 --- a/examples/telnet/toolbar.py +++ b/examples/telnet/toolbar.py @@ -37,10 +37,7 @@ def get_toolbar(): async def main(): server = TelnetServer(interact=interact, port=2323) - server.start() - - # Run forever. - await Future() + await server.run() if __name__ == "__main__": diff --git a/src/prompt_toolkit/contrib/telnet/server.py b/src/prompt_toolkit/contrib/telnet/server.py index e72550c5a..ddd7ca785 100644 --- a/src/prompt_toolkit/contrib/telnet/server.py +++ b/src/prompt_toolkit/contrib/telnet/server.py @@ -283,10 +283,11 @@ def __init__( self.encoding = encoding self.style = style self.enable_cpr = enable_cpr + + self._run_task: asyncio.Task[None] | None = None self._application_tasks: list[asyncio.Task[None]] = [] self.connections: set[TelnetConnection] = set() - self._listen_socket: socket.socket | None = None @classmethod def _create_socket(cls, host: str, port: int) -> socket.socket: @@ -298,44 +299,74 @@ def _create_socket(cls, host: str, port: int) -> socket.socket: s.listen(4) return s - def start(self) -> None: + async def run(self, ready_cb: Callable[[], None] | None = None) -> None: """ - Start the telnet server. - Don't forget to call `loop.run_forever()` after doing this. + Run the telnet server, until this gets cancelled. + + :param ready_cb: Callback that will be called at the point that we're + actually listening. """ - self._listen_socket = self._create_socket(self.host, self.port) + socket = self._create_socket(self.host, self.port) logger.info( "Listening for telnet connections on %s port %r", self.host, self.port ) - get_running_loop().add_reader(self._listen_socket, self._accept) + get_running_loop().add_reader(socket, lambda: self._accept(socket)) + + if ready_cb: + ready_cb() + + try: + # Run forever, until cancelled. + await asyncio.Future() + finally: + get_running_loop().remove_reader(socket) + socket.close() + + # Wait for all applications to finish. + for t in self._application_tasks: + t.cancel() + + # (This is similar to + # `Application.cancel_and_wait_for_background_tasks`. We wait for the + # background tasks to complete, but don't propagate exceptions, because + # we can't use `ExceptionGroup` yet.) + if len(self._application_tasks) > 0: + await asyncio.wait( + self._application_tasks, + timeout=None, + return_when=asyncio.ALL_COMPLETED, + ) + + def start(self) -> None: + """ + Start the telnet server (stop by calling and awaiting `stop()`). + + Note: When possible, it's better to call `.run()` instead. + """ + if self._run_task is not None: + # Already running. + return + + self._run_task = get_running_loop().create_task(self.run()) async def stop(self) -> None: - if self._listen_socket: - get_running_loop().remove_reader(self._listen_socket) - self._listen_socket.close() - - # Wait for all applications to finish. - for t in self._application_tasks: - t.cancel() - - # (This is similar to - # `Application.cancel_and_wait_for_background_tasks`. We wait for the - # background tasks to complete, but don't propagate exceptions, because - # we can't use `ExceptionGroup` yet.) - if len(self._application_tasks) > 0: - await asyncio.wait( - self._application_tasks, timeout=None, return_when=asyncio.ALL_COMPLETED - ) + """ + Stop a telnet server that was started using `.start()` and wait for the + cancellation to complete. + """ + if self._run_task is not None: + self._run_task.cancel() + try: + await self._run_task + except asyncio.CancelledError: + pass - def _accept(self) -> None: + def _accept(self, listen_socket: socket.socket) -> None: """ Accept new incoming connection. """ - if self._listen_socket is None: - return # Should not happen. `_accept` is called after `start`. - - conn, addr = self._listen_socket.accept() + conn, addr = listen_socket.accept() logger.info("New connection %r %r", *addr) # Run application for this connection.