Skip to content

Commit

Permalink
Implement sub-shells
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jan 5, 2023
1 parent 817258d commit 83e3502
Show file tree
Hide file tree
Showing 16 changed files with 637 additions and 547 deletions.
54 changes: 54 additions & 0 deletions ipykernel/athread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import asyncio
import threading

import janus


class AThread(threading.Thread):
"""A thread that can run async tasks."""

def __init__(self, name, awaitables=None):
super().__init__(name=name, daemon=True)
self._aws = list(awaitables) if awaitables is not None else []
self._lock = threading.Lock()
self.__initialized = False
self._stopped = False

def run(self):
asyncio.run(self._main())

async def _main(self):
with self._lock:
if self._stopped:
return
self._queue = janus.Queue()
self.__initialized = True
self._tasks = [asyncio.create_task(aw) for aw in self._aws]

while True:
try:
aw = await self._queue.async_q.get()
except BaseException:
break
if aw is None:
break
self._tasks.append(asyncio.create_task(aw))

for task in self._tasks:
task.cancel()

def create_task(self, awaitable):
"""Create a task in the thread (thread-safe)."""
with self._lock:
if self.__initialized:
self._queue.sync_q.put(awaitable)
else:
self._aws.append(awaitable)

def stop(self):
"""Stop the thread (thread-safe)."""
with self._lock:
if self.__initialized:
self._queue.sync_q.put(None)
else:
self._stopped = True
27 changes: 4 additions & 23 deletions ipykernel/control.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,11 @@
"""A thread for a control channel."""
from threading import Thread
from .athread import AThread

from tornado.ioloop import IOLoop


class ControlThread(Thread):
class ControlThread(AThread):
"""A thread for a control channel."""

def __init__(self, **kwargs):
def __init__(self):
"""Initialize the thread."""
Thread.__init__(self, name="Control", **kwargs)
self.io_loop = IOLoop(make_current=False)
super().__init__(name="Control")
self.pydev_do_not_trace = True
self.is_pydev_daemon_thread = True

def run(self):
"""Run the thread."""
self.name = "Control"
try:
self.io_loop.start()
finally:
self.io_loop.close()

def stop(self):
"""Stop the thread.
This method is threadsafe.
"""
self.io_loop.add_callback(self.io_loop.stop)
48 changes: 26 additions & 22 deletions ipykernel/debugger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Debugger implementation for the IPython kernel."""
import asyncio
import os
import re
import sys
Expand All @@ -7,8 +8,6 @@
import zmq
from IPython.core.getipython import get_ipython
from IPython.core.inputtransformer2 import leading_empty_lines
from tornado.locks import Event
from tornado.queues import Queue
from zmq.utils import jsonapi

try:
Expand Down Expand Up @@ -116,7 +115,7 @@ def __init__(self, event_callback, log):
self.tcp_buffer = ""
self._reset_tcp_pos()
self.event_callback = event_callback
self.message_queue: Queue[t.Any] = Queue()
self.message_queue: asyncio.Queue[t.Any] = asyncio.Queue()
self.log = log

def _reset_tcp_pos(self):
Expand Down Expand Up @@ -192,17 +191,17 @@ async def get_message(self):
class DebugpyClient:
"""A client for debugpy."""

def __init__(self, log, debugpy_stream, event_callback):
def __init__(self, log, debugpy_socket, event_callback):
"""Initialize the client."""
self.log = log
self.debugpy_stream = debugpy_stream
self.debugpy_socket = debugpy_socket
self.event_callback = event_callback
self.message_queue = DebugpyMessageQueue(self._forward_event, self.log)
self.debugpy_host = "127.0.0.1"
self.debugpy_port = -1
self.routing_id = None
self.wait_for_attach = True
self.init_event = Event()
self.init_event = asyncio.Event()
self.init_event_seq = -1

def _get_endpoint(self):
Expand All @@ -215,9 +214,9 @@ def _forward_event(self, msg):
self.init_event_seq = msg["seq"]
self.event_callback(msg)

def _send_request(self, msg):
async def _send_request(self, msg):
if self.routing_id is None:
self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)
self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID)
content = jsonapi.dumps(
msg,
default=json_default,
Expand All @@ -232,7 +231,7 @@ def _send_request(self, msg):
self.log.debug("DEBUGPYCLIENT:")
self.log.debug(self.routing_id)
self.log.debug(buf)
self.debugpy_stream.send_multipart((self.routing_id, buf))
await self.debugpy_socket.send_multipart((self.routing_id, buf))

async def _wait_for_response(self):
# Since events are never pushed to the message_queue
Expand All @@ -250,7 +249,7 @@ async def _handle_init_sequence(self):
"seq": int(self.init_event_seq) + 1,
"command": "configurationDone",
}
self._send_request(configurationDone)
await self._send_request(configurationDone)

# 3] Waits for configurationDone response
await self._wait_for_response()
Expand All @@ -262,7 +261,7 @@ async def _handle_init_sequence(self):
def get_host_port(self):
"""Get the host debugpy port."""
if self.debugpy_port == -1:
socket = self.debugpy_stream.socket
socket = self.debugpy_socket
socket.bind_to_random_port("tcp://" + self.debugpy_host)
self.endpoint = socket.getsockopt(zmq.LAST_ENDPOINT).decode("utf-8")
socket.unbind(self.endpoint)
Expand All @@ -272,14 +271,14 @@ def get_host_port(self):

def connect_tcp_socket(self):
"""Connect to the tcp socket."""
self.debugpy_stream.socket.connect(self._get_endpoint())
self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)
self.debugpy_socket.connect(self._get_endpoint())
self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID)

def disconnect_tcp_socket(self):
"""Disconnect from the tcp socket."""
self.debugpy_stream.socket.disconnect(self._get_endpoint())
self.debugpy_socket.disconnect(self._get_endpoint())
self.routing_id = None
self.init_event = Event()
self.init_event = asyncio.Event()
self.init_event_seq = -1
self.wait_for_attach = True

Expand All @@ -289,7 +288,7 @@ def receive_dap_frame(self, frame):

async def send_dap_request(self, msg):
"""Send a dap request."""
self._send_request(msg)
await self._send_request(msg)
if self.wait_for_attach and msg["command"] == "attach":
rep = await self._handle_init_sequence()
self.wait_for_attach = False
Expand Down Expand Up @@ -319,17 +318,17 @@ class Debugger:
static_debug_msg_types = ["debugInfo", "inspectVariables", "richInspectVariables", "modules"]

def __init__(
self, log, debugpy_stream, event_callback, shell_socket, session, just_my_code=True
self, log, debugpy_socket, event_callback, shell_socket, session, just_my_code=True
):
"""Initialize the debugger."""
self.log = log
self.debugpy_client = DebugpyClient(log, debugpy_stream, self._handle_event)
self.debugpy_client = DebugpyClient(log, debugpy_socket, self._handle_event)
self.shell_socket = shell_socket
self.session = session
self.is_started = False
self.event_callback = event_callback
self.just_my_code = just_my_code
self.stopped_queue: Queue[t.Any] = Queue()
self.stopped_queue: asyncio.Queue[t.Any] = asyncio.Queue()

self.started_debug_handlers = {}
for msg_type in Debugger.started_debug_msg_types:
Expand Down Expand Up @@ -406,7 +405,7 @@ async def handle_stopped_event(self):
def tcp_client(self):
return self.debugpy_client

def start(self):
async def start(self):
"""Start the debugger."""
if not self.debugpy_initialized:
tmp_dir = get_tmp_directory()
Expand All @@ -424,7 +423,12 @@ def start(self):
(self.shell_socket.getsockopt(ROUTING_ID)),
)

ident, msg = self.session.recv(self.shell_socket, mode=0)
msg = await self.shell_socket.recv_multipart()
idents, msg = self.session.feed_identities(msg, copy=True)
try:
msg = self.session.deserialize(msg, content=True, copy=True)
except BaseException:
self.log.error("Invalid Message", exc_info=True)
self.debugpy_initialized = msg["content"]["status"] == "ok"

# Don't remove leading empty lines when debugging so the breakpoints are correctly positioned
Expand Down Expand Up @@ -685,7 +689,7 @@ async def process_request(self, message):
if self.is_started:
self.log.info("The debugger has already started")
else:
self.is_started = self.start()
self.is_started = await self.start()
if self.is_started:
self.log.info("The debugger has started")
else:
Expand Down
2 changes: 1 addition & 1 deletion ipykernel/inprocess/ipkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class InProcessKernel(IPythonKernel):
_underlying_iopub_socket = Instance(DummySocket, ())
iopub_thread: IOPubThread = Instance(IOPubThread) # type:ignore[assignment]

shell_stream = Instance(DummySocket, ())
# shell_stream = Instance(DummySocket, ())

@default("iopub_thread")
def _default_iopub_thread(self):
Expand Down
Loading

0 comments on commit 83e3502

Please sign in to comment.