Skip to content

Commit

Permalink
Replace Tornado with AnyIO
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Oct 26, 2023
1 parent 2a8adb9 commit 8ff47dc
Show file tree
Hide file tree
Showing 27 changed files with 879 additions and 871 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ jobs:
run: |
hatch run typing:test
hatch run lint:style
pipx run interrogate -vv .
pipx run interrogate -vv . --fail-under 90
pipx run doc8 --max-line-length=200
check_release:
Expand Down
40 changes: 4 additions & 36 deletions examples/embedding/inprocess_terminal.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""An in-process terminal example."""
import os
import sys

import tornado
from anyio import run
from jupyter_console.ptshell import ZMQTerminalInteractiveShell

from ipykernel.inprocess.manager import InProcessKernelManager
Expand All @@ -13,46 +12,15 @@ def print_process_id():
print("Process ID is:", os.getpid())


def init_asyncio_patch():
"""set default asyncio policy to be compatible with tornado
Tornado 6 (at least) is not compatible with the default
asyncio implementation on Windows
Pick the older SelectorEventLoopPolicy on Windows
if the known-incompatible default policy is in use.
do this as early as possible to make it a low priority and overridable
ref: https://github.com/tornadoweb/tornado/issues/2608
FIXME: if/when tornado supports the defaults in asyncio,
remove and bump tornado requirement for py38
"""
if (
sys.platform.startswith("win")
and sys.version_info >= (3, 8)
and tornado.version_info < (6, 1)
):
import asyncio

try:
from asyncio import WindowsProactorEventLoopPolicy, WindowsSelectorEventLoopPolicy
except ImportError:
pass
# not affected
else:
if type(asyncio.get_event_loop_policy()) is WindowsProactorEventLoopPolicy:
# WindowsProactorEventLoopPolicy is not compatible with tornado 6
# fallback to the pre-3.8 default of Selector
asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())


def main():
async def main():
"""The main function."""
print_process_id()

# Create an in-process kernel
# >>> print_process_id()
# will print the same process ID as the main process
init_asyncio_patch()
kernel_manager = InProcessKernelManager()
kernel_manager.start_kernel()
await kernel_manager.start_kernel()
kernel = kernel_manager.kernel
kernel.gui = "qt4"
kernel.shell.push({"foo": 43, "print_process_id": print_process_id})
Expand All @@ -64,4 +32,4 @@ def main():


if __name__ == "__main__":
main()
run(main)
24 changes: 16 additions & 8 deletions ipykernel/control.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""A thread for a control channel."""
from threading import Thread
from threading import Event, Thread

from tornado.ioloop import IOLoop
from anyio import create_task_group, run, to_thread

CONTROL_THREAD_NAME = "Control"

Expand All @@ -12,21 +12,29 @@ class ControlThread(Thread):
def __init__(self, **kwargs):
"""Initialize the thread."""
Thread.__init__(self, name=CONTROL_THREAD_NAME, **kwargs)
self.io_loop = IOLoop(make_current=False)
self.pydev_do_not_trace = True
self.is_pydev_daemon_thread = True
self.__stop = Event()
self._task = None

def set_task(self, task):
self._task = task

def run(self):
"""Run the thread."""
self.name = CONTROL_THREAD_NAME
try:
self.io_loop.start()
finally:
self.io_loop.close()
run(self._main)

async def _main(self):
async with create_task_group() as tg:
if self._task is not None:
tg.start_soon(self._task)
await to_thread.run_sync(self.__stop.wait)
tg.cancel_scope.cancel()

def stop(self):
"""Stop the thread.
This method is threadsafe.
"""
self.io_loop.add_callback(self.io_loop.stop)
self.__stop.set()
67 changes: 41 additions & 26 deletions ipykernel/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import re
import sys
import typing as t
from math import inf
from typing import Any

import zmq
from anyio import Event, create_memory_object_stream
from IPython.core.getipython import get_ipython
from IPython.core.inputtransformer2 import leading_empty_lines
from tornado.locks import Event
from tornado.queues import Queue
from zmq.utils import jsonapi

try:
Expand Down Expand Up @@ -116,7 +117,9 @@ def __init__(self, event_callback, log):
self.tcp_buffer = ""
self._reset_tcp_pos()
self.event_callback = event_callback
self.message_queue: Queue[t.Any] = Queue()
self.message_send_stream, self.message_receive_stream = create_memory_object_stream[
dict
](max_buffer_size=inf)
self.log = log

def _reset_tcp_pos(self):
Expand All @@ -135,7 +138,7 @@ def _put_message(self, raw_msg):
else:
self.log.debug("QUEUE - put message:")
self.log.debug(msg)
self.message_queue.put_nowait(msg)
self.message_send_stream.send_nowait(msg)

def put_tcp_frame(self, frame):
"""Put a tcp frame in the queue."""
Expand Down Expand Up @@ -186,25 +189,31 @@ def put_tcp_frame(self, frame):

async def get_message(self):
"""Get a message from the queue."""
return await self.message_queue.get()
return await self.message_receive_stream.receive()


class DebugpyClient:
"""A client for debugpy."""

def __init__(self, log, debugpy_stream, event_callback):
def __init__(self, log, debugpy_socket, event_callback):
"""Initialize the client."""
self.log = log
self.debugpy_stream = debugpy_stream
self.debugpy_socket = debugpy_socket
self.event_callback = event_callback
self.message_queue = DebugpyMessageQueue(self._forward_event, self.log)
self.debugpy_host = "127.0.0.1"
self.debugpy_port = -1
self.routing_id = None
self.wait_for_attach = True
self.init_event = Event()
self._init_event = None
self.init_event_seq = -1

@property
def init_event(self):
if self._init_event is None:
self._init_event = Event()
return self._init_event

def _get_endpoint(self):
host, port = self.get_host_port()
return "tcp://" + host + ":" + str(port)
Expand All @@ -215,9 +224,9 @@ def _forward_event(self, msg):
self.init_event_seq = msg["seq"]
self.event_callback(msg)

def _send_request(self, msg):
async def _send_request(self, msg):
if self.routing_id is None:
self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)
self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID)
content = jsonapi.dumps(
msg,
default=json_default,
Expand All @@ -232,7 +241,7 @@ def _send_request(self, msg):
self.log.debug("DEBUGPYCLIENT:")
self.log.debug(self.routing_id)
self.log.debug(buf)
self.debugpy_stream.send_multipart((self.routing_id, buf))
await self.debugpy_socket.send_multipart((self.routing_id, buf))

async def _wait_for_response(self):
# Since events are never pushed to the message_queue
Expand All @@ -250,7 +259,7 @@ async def _handle_init_sequence(self):
"seq": int(self.init_event_seq) + 1,
"command": "configurationDone",
}
self._send_request(configurationDone)
await self._send_request(configurationDone)

# 3] Waits for configurationDone response
await self._wait_for_response()
Expand All @@ -262,7 +271,7 @@ async def _handle_init_sequence(self):
def get_host_port(self):
"""Get the host debugpy port."""
if self.debugpy_port == -1:
socket = self.debugpy_stream.socket
socket = self.debugpy_socket
socket.bind_to_random_port("tcp://" + self.debugpy_host)
self.endpoint = socket.getsockopt(zmq.LAST_ENDPOINT).decode("utf-8")
socket.unbind(self.endpoint)
Expand All @@ -272,14 +281,13 @@ def get_host_port(self):

def connect_tcp_socket(self):
"""Connect to the tcp socket."""
self.debugpy_stream.socket.connect(self._get_endpoint())
self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)
self.debugpy_socket.connect(self._get_endpoint())
self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID)

def disconnect_tcp_socket(self):
"""Disconnect from the tcp socket."""
self.debugpy_stream.socket.disconnect(self._get_endpoint())
self.debugpy_socket.disconnect(self._get_endpoint())
self.routing_id = None
self.init_event = Event()
self.init_event_seq = -1
self.wait_for_attach = True

Expand All @@ -289,7 +297,7 @@ def receive_dap_frame(self, frame):

async def send_dap_request(self, msg):
"""Send a dap request."""
self._send_request(msg)
await self._send_request(msg)
if self.wait_for_attach and msg["command"] == "attach":
rep = await self._handle_init_sequence()
self.wait_for_attach = False
Expand Down Expand Up @@ -325,17 +333,19 @@ class Debugger:
]

def __init__(
self, log, debugpy_stream, event_callback, shell_socket, session, just_my_code=True
self, log, debugpy_socket, event_callback, shell_socket, session, just_my_code=True
):
"""Initialize the debugger."""
self.log = log
self.debugpy_client = DebugpyClient(log, debugpy_stream, self._handle_event)
self.debugpy_client = DebugpyClient(log, debugpy_socket, self._handle_event)
self.shell_socket = shell_socket
self.session = session
self.is_started = False
self.event_callback = event_callback
self.just_my_code = just_my_code
self.stopped_queue: Queue[t.Any] = Queue()
self.stopped_send_stream, self.stopped_receive_stream = create_memory_object_stream[
dict
](max_buffer_size=inf)

self.started_debug_handlers = {}
for msg_type in Debugger.started_debug_msg_types:
Expand All @@ -360,7 +370,7 @@ def __init__(
def _handle_event(self, msg):
if msg["event"] == "stopped":
if msg["body"]["allThreadsStopped"]:
self.stopped_queue.put_nowait(msg)
self.stopped_send_stream.send_nowait(msg)
# Do not forward the event now, will be done in the handle_stopped_event
return
else:
Expand Down Expand Up @@ -400,7 +410,7 @@ async def handle_stopped_event(self):
"""Handle a stopped event."""
# Wait for a stopped event message in the stopped queue
# This message is used for triggering the 'threads' request
event = await self.stopped_queue.get()
event = await self.stopped_receive_stream.receive()
req = {"seq": event["seq"] + 1, "type": "request", "command": "threads"}
rep = await self._forward_message(req)
for thread in rep["body"]["threads"]:
Expand All @@ -412,7 +422,7 @@ async def handle_stopped_event(self):
def tcp_client(self):
return self.debugpy_client

def start(self):
async def start(self):
"""Start the debugger."""
if not self.debugpy_initialized:
tmp_dir = get_tmp_directory()
Expand All @@ -430,7 +440,12 @@ def start(self):
(self.shell_socket.getsockopt(ROUTING_ID)),
)

ident, msg = self.session.recv(self.shell_socket, mode=0)
msg = await self.shell_socket.recv_multipart()
ident, msg = self.session.feed_identities(msg, copy=True)
try:
msg = self.session.deserialize(msg, content=True, copy=True)
except Exception:
self.log.error("Invalid message", exc_info=True)
self.debugpy_initialized = msg["content"]["status"] == "ok"

# Don't remove leading empty lines when debugging so the breakpoints are correctly positioned
Expand Down Expand Up @@ -719,7 +734,7 @@ async def process_request(self, message):
if self.is_started:
self.log.info("The debugger has already started")
else:
self.is_started = self.start()
self.is_started = await self.start()
if self.is_started:
self.log.info("The debugger has started")
else:
Expand Down
9 changes: 4 additions & 5 deletions ipykernel/eventloops.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,13 +388,12 @@ def loop_asyncio(kernel):
loop._should_close = False # type:ignore[attr-defined]

# pause eventloop when there's an event on a zmq socket
def process_stream_events(stream):
def process_stream_events(socket):
"""fall back to main loop when there's a socket event"""
if stream.flush(limit=1):
loop.stop()
loop.stop()

notifier = partial(process_stream_events, kernel.shell_stream)
loop.add_reader(kernel.shell_stream.getsockopt(zmq.FD), notifier)
notifier = partial(process_stream_events, kernel.shell_socket)
loop.add_reader(kernel.shell_socket.getsockopt(zmq.FD), notifier)
loop.call_soon(notifier)

while True:
Expand Down
5 changes: 2 additions & 3 deletions ipykernel/inprocess/blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ class BlockingInProcessKernelClient(InProcessKernelClient):
iopub_channel_class = Type(BlockingInProcessChannel) # type:ignore[arg-type]
stdin_channel_class = Type(BlockingInProcessStdInChannel) # type:ignore[arg-type]

def wait_for_ready(self):
async def wait_for_ready(self):
"""Wait for kernel info reply on shell channel."""
while True:
self.kernel_info()
await self.kernel_info()
try:
msg = self.shell_channel.get_msg(block=True, timeout=1)
except Empty:
Expand All @@ -103,6 +103,5 @@ def wait_for_ready(self):
while True:
try:
msg = self.iopub_channel.get_msg(block=True, timeout=0.2)
print(msg["msg_type"])
except Empty:
break
Loading

0 comments on commit 8ff47dc

Please sign in to comment.