diff --git a/docs/api/ipykernel.rst b/docs/api/ipykernel.rst index 2e1cf20d..dd46d084 100644 --- a/docs/api/ipykernel.rst +++ b/docs/api/ipykernel.rst @@ -110,6 +110,30 @@ Submodules :show-inheritance: +.. automodule:: ipykernel.shellchannel + :members: + :undoc-members: + :show-inheritance: + + +.. automodule:: ipykernel.subshell + :members: + :undoc-members: + :show-inheritance: + + +.. automodule:: ipykernel.subshell_manager + :members: + :undoc-members: + :show-inheritance: + + +.. automodule:: ipykernel.thread + :members: + :undoc-members: + :show-inheritance: + + .. automodule:: ipykernel.trio_runner :members: :undoc-members: diff --git a/ipykernel/control.py b/ipykernel/control.py index a70377c0..21d6d996 100644 --- a/ipykernel/control.py +++ b/ipykernel/control.py @@ -1,40 +1,11 @@ """A thread for a control channel.""" -from threading import Event, Thread -from anyio import create_task_group, run, to_thread +from .thread import CONTROL_THREAD_NAME, BaseThread -CONTROL_THREAD_NAME = "Control" - -class ControlThread(Thread): +class ControlThread(BaseThread): """A thread for a control channel.""" def __init__(self, **kwargs): """Initialize the thread.""" - Thread.__init__(self, name=CONTROL_THREAD_NAME, **kwargs) - self.pydev_do_not_trace = True - self.is_pydev_daemon_thread = True - self.__stop = Event() - self._task = None - - def set_task(self, task): - self._task = task - - def run(self): - """Run the thread.""" - self.name = CONTROL_THREAD_NAME - run(self._main) - - async def _main(self): - async with create_task_group() as tg: - if self._task is not None: - tg.start_soon(self._task) - await to_thread.run_sync(self.__stop.wait) - tg.cancel_scope.cancel() - - def stop(self): - """Stop the thread. - - This method is threadsafe. - """ - self.__stop.set() + super().__init__(name=CONTROL_THREAD_NAME, **kwargs) diff --git a/ipykernel/heartbeat.py b/ipykernel/heartbeat.py index d2890f67..9816959d 100644 --- a/ipykernel/heartbeat.py +++ b/ipykernel/heartbeat.py @@ -32,7 +32,7 @@ def __init__(self, context, addr=None): """Initialize the heartbeat thread.""" if addr is None: addr = ("tcp", localhost(), 0) - Thread.__init__(self, name="Heartbeat") + super().__init__(name="Heartbeat") self.context = context self.transport, self.ip, self.port = addr self.original_port = self.port diff --git a/ipykernel/iostream.py b/ipykernel/iostream.py index 6280905c..beca44b1 100644 --- a/ipykernel/iostream.py +++ b/ipykernel/iostream.py @@ -40,7 +40,7 @@ class _IOPubThread(Thread): def __init__(self, tasks, **kwargs): """Initialize the thread.""" - Thread.__init__(self, name="IOPub", **kwargs) + super().__init__(name="IOPub", **kwargs) self._tasks = tasks self.pydev_do_not_trace = True self.is_pydev_daemon_thread = True @@ -170,10 +170,10 @@ async def _handle_event(self): for _ in range(n_events): event_f = self._events.popleft() event_f() - except Exception as e: + except Exception: if self.thread.__stop.is_set(): return - raise e + raise def _setup_pipe_in(self): """setup listening pipe for IOPub from forked subprocesses""" @@ -202,10 +202,10 @@ async def _handle_pipe_msgs(self): try: while True: await self._handle_pipe_msg() - except Exception as e: + except Exception: if self.thread.__stop.is_set(): return - raise e + raise async def _handle_pipe_msg(self, msg=None): """handle a pipe message from a subprocess""" diff --git a/ipykernel/kernelapp.py b/ipykernel/kernelapp.py index c02c3cf3..2f462af4 100644 --- a/ipykernel/kernelapp.py +++ b/ipykernel/kernelapp.py @@ -53,6 +53,7 @@ from .iostream import IOPubThread from .ipkernel import IPythonKernel from .parentpoller import ParentPollerUnix, ParentPollerWindows +from .shellchannel import ShellChannelThread from .zmqshell import ZMQInteractiveShell # ----------------------------------------------------------------------------- @@ -143,6 +144,7 @@ class IPKernelApp(BaseIPythonApplication, InteractiveShellApp, ConnectionFileMix iopub_socket = Any() iopub_thread = Any() control_thread = Any() + shell_channel_thread = Any() _ports = Dict() @@ -367,6 +369,7 @@ def init_control(self, context): self.control_socket.router_handover = 1 self.control_thread = ControlThread(daemon=True) + self.shell_channel_thread = ShellChannelThread(context, self.shell_socket, daemon=True) def init_iopub(self, context): """Initialize the iopub channel.""" @@ -406,6 +409,10 @@ def close(self): self.log.debug("Closing control thread") self.control_thread.stop() self.control_thread.join() + if self.shell_channel_thread and self.shell_channel_thread.is_alive(): + self.log.debug("Closing shell channel thread") + self.shell_channel_thread.stop() + self.shell_channel_thread.join() if self.debugpy_socket and not self.debugpy_socket.closed: self.debugpy_socket.close() @@ -562,6 +569,7 @@ def init_kernel(self): debug_shell_socket=self.debug_shell_socket, shell_socket=self.shell_socket, control_thread=self.control_thread, + shell_channel_thread=self.shell_channel_thread, iopub_thread=self.iopub_thread, iopub_socket=self.iopub_socket, stdin_socket=self.stdin_socket, diff --git a/ipykernel/kernelbase.py b/ipykernel/kernelbase.py index 050f57be..99358f9b 100644 --- a/ipykernel/kernelbase.py +++ b/ipykernel/kernelbase.py @@ -18,7 +18,7 @@ from datetime import datetime from signal import SIGINT, SIGTERM, Signals -from .control import CONTROL_THREAD_NAME +from .thread import CONTROL_THREAD_NAME if sys.platform != "win32": from signal import SIGKILL @@ -103,6 +103,7 @@ class Kernel(SingletonConfigurable): debug_shell_socket = Any() control_thread = Any() + shell_channel_thread = Any() iopub_socket = Any() iopub_thread = Any() stdin_socket = Any() @@ -226,6 +227,9 @@ def _parent_header(self): "abort_request", "debug_request", "usage_request", + "create_subshell_request", + "delete_subshell_request", + "list_subshell_request", ] _eventloop_set: Event = Event() @@ -258,16 +262,17 @@ async def process_control(self): try: while True: await self.process_control_message() - except BaseException as e: - print("base exception") + except BaseException: if self.control_stop.is_set(): return - raise e + raise async def process_control_message(self, msg=None): """dispatch control requests""" assert self.control_socket is not None assert self.session is not None + assert self.control_thread is None or threading.current_thread() == self.control_thread + msg = msg or await self.control_socket.recv_multipart() copy = not isinstance(msg[0], zmq.Message) idents, msg = self.session.feed_identities(msg, copy=copy) @@ -356,28 +361,95 @@ async def advance_eventloop(): def _message_counter_default(self): return itertools.count() - async def shell_main(self): - async with create_task_group() as tg: - tg.start_soon(self.process_shell) - await to_thread.run_sync(self.shell_stop.wait) - tg.cancel_scope.cancel() + async def shell_channel_thread_main(self): + """Main loop for shell channel thread. + + Listen for incoming messages on kernel shell_socket. For each message + received, extract the subshell_id from the message header and forward the + message to the correct subshell via ZMQ inproc pair socket. + """ + assert self.shell_socket is not None + assert self.session is not None + assert self.shell_channel_thread is not None + assert threading.current_thread() == self.shell_channel_thread - async def process_shell(self): try: while True: - await self.process_shell_message() - except BaseException as e: + msg = await self.shell_socket.recv_multipart() + + # Deserialize whole message just to get subshell_id. + # Keep original message to send to subshell_id unmodified. + # Ideally only want to deserialize message once. + copy = not isinstance(msg[0], zmq.Message) + _, msg2 = self.session.feed_identities(msg, copy=copy) + try: + msg3 = self.session.deserialize(msg2, content=False, copy=copy) + subshell_id = msg3["header"].get("subshell_id") + + # Find inproc pair socket to use to send message to correct subshell. + socket = self.shell_channel_thread.manager.get_shell_channel_socket(subshell_id) + assert socket is not None + socket.send_multipart(msg, copy=False) + except Exception: + self.log.error("Invalid message", exc_info=True) # noqa: G201 + except BaseException: if self.shell_stop.is_set(): return - raise e + raise - async def process_shell_message(self, msg=None): - assert self.shell_socket is not None + async def shell_main(self, subshell_id: str | None): + """Main loop for a single subshell.""" + if self._supports_kernel_subshells: + if subshell_id is None: + assert threading.current_thread() == threading.main_thread() + else: + assert threading.current_thread() not in ( + self.shell_channel_thread, + threading.main_thread(), + ) + # Inproc pair socket that this subshell uses to talk to shell channel thread. + socket = self.shell_channel_thread.manager.get_other_socket(subshell_id) + else: + assert subshell_id is None + assert threading.current_thread() == threading.main_thread() + socket = self.shell_socket + + async with create_task_group() as tg: + tg.start_soon(self.process_shell, socket) + if subshell_id is None: + # Main subshell. + await to_thread.run_sync(self.shell_stop.wait) + tg.cancel_scope.cancel() + + async def process_shell(self, socket=None): + # socket=None is valid if kernel subshells are not supported. + try: + while True: + await self.process_shell_message(socket=socket) + except BaseException: + if self.shell_stop.is_set(): + return + raise + + async def process_shell_message(self, msg=None, socket=None): + # If socket is None kernel subshells are not supported so use socket=shell_socket. + # If msg is set, process that message. + # If msg is None, await the next message to arrive on the socket. assert self.session is not None + if self._supports_kernel_subshells: + assert threading.current_thread() not in ( + self.control_thread, + self.shell_channel_thread, + ) + assert socket is not None + else: + assert threading.current_thread() == threading.main_thread() + assert socket is None + socket = self.shell_socket - no_msg = msg is None if self._is_test else not await self.shell_socket.poll(0) + no_msg = msg is None if self._is_test else not await socket.poll(0) + msg = msg or await socket.recv_multipart(copy=False) - msg = msg or await self.shell_socket.recv_multipart() received_time = time.monotonic() copy = not isinstance(msg[0], zmq.Message) idents, msg = self.session.feed_identities(msg, copy=copy) @@ -401,7 +473,7 @@ async def process_shell_message(self, msg=None): elif received_time - self._aborted_time > self.stop_on_error_timeout: self._aborting = False if self._aborting: - await self._send_abort_reply(self.shell_socket, msg, idents) + await self._send_abort_reply(socket, msg, idents) self._publish_status("idle", "shell") return @@ -411,7 +483,7 @@ async def process_shell_message(self, msg=None): self.log.debug("\n*** MESSAGE TYPE:%s***", msg_type) self.log.debug(" Content: %s\n --->\n ", msg["content"]) - if not await self.should_handle(self.shell_socket, msg, idents): + if not await self.should_handle(socket, msg, idents): return handler = self.shell_handlers.get(msg_type) @@ -424,7 +496,7 @@ async def process_shell_message(self, msg=None): except Exception: self.log.debug("Unable to signal in pre_handler_hook:", exc_info=True) try: - result = handler(self.shell_socket, idents, msg) + result = handler(socket, idents, msg) if inspect.isawaitable(result): await result except Exception: @@ -465,7 +537,7 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: self.control_stop = threading.Event() if not self._is_test and self.control_socket is not None: if self.control_thread: - self.control_thread.set_task(self.control_main) + self.control_thread.add_task(self.control_main) self.control_thread.start() else: tg.start_soon(self.control_main) @@ -474,8 +546,19 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: self.shell_is_awaiting = False self.shell_is_blocking = False self.shell_stop = threading.Event() - if not self._is_test and self.shell_socket is not None: - tg.start_soon(self.shell_main) + + if self.shell_channel_thread: + tg.start_soon(self.shell_main, None) + + # Assign tasks to and start shell channel thread. + manager = self.shell_channel_thread.manager + self.shell_channel_thread.add_task(self.shell_channel_thread_main) + self.shell_channel_thread.add_task(manager.listen_from_control, self.shell_main) + self.shell_channel_thread.add_task(manager.listen_from_subshells) + self.shell_channel_thread.start() + else: + if not self._is_test and self.shell_socket is not None: + tg.start_soon(self.shell_main, None) def stop(self): if not self._eventloop_set.is_set(): @@ -635,8 +718,7 @@ async def execute_request(self, socket, ident, parent): cell_meta = parent.get("metadata", {}) cell_id = cell_meta.get("cellId") except Exception: - self.log.error("Got bad msg: ") - self.log.error("%s", parent) + self.log.error("Got bad msg from parent: %s", parent) return stop_on_error = content.get("stop_on_error", True) @@ -687,8 +769,8 @@ async def execute_request(self, socket, ident, parent): reply_msg = self.session.send( socket, "execute_reply", - reply_content, - parent, + content=reply_content, + parent=parent, metadata=metadata, ident=ident, ) @@ -806,14 +888,18 @@ async def connect_request(self, socket, ident, parent): @property def kernel_info(self): - return { + info = { "protocol_version": kernel_protocol_version, "implementation": self.implementation, "implementation_version": self.implementation_version, "language_info": self.language_info, "banner": self.banner, "help_links": self.help_links, + "supported_features": [], } + if self._supports_kernel_subshells: + info["supported_features"] = ["kernel subshells"] + return info async def kernel_info_request(self, socket, ident, parent): """Handle a kernel info request.""" @@ -984,6 +1070,62 @@ async def usage_request(self, socket, ident, parent): async def do_debug_request(self, msg): raise NotImplementedError + # --------------------------------------------------------------------------- + # Subshell control message handlers + # --------------------------------------------------------------------------- + + async def create_subshell_request(self, socket, ident, parent) -> None: + if not self.session: + return + if not self._supports_kernel_subshells: + self.log.error("Subshells are not supported by this kernel") + return + + # This should only be called in the control thread if it exists. + # Request is passed to shell channel thread to process. + other_socket = self.shell_channel_thread.manager.get_control_other_socket() + await other_socket.send_json({"type": "create"}) + reply = await other_socket.recv_json() + + self.session.send(socket, "create_subshell_reply", reply, parent, ident) + + async def delete_subshell_request(self, socket, ident, parent) -> None: + if not self.session: + return + if not self._supports_kernel_subshells: + self.log.error("KERNEL SUBSHELLS NOT SUPPORTED") + return + + try: + content = parent["content"] + subshell_id = content["subshell_id"] + except Exception: + self.log.error("Got bad msg from parent: %s", parent) + return + + # This should only be called in the control thread if it exists. + # Request is passed to shell channel thread to process. + other_socket = self.shell_channel_thread.manager.get_control_other_socket() + await other_socket.send_json({"type": "delete", "subshell_id": subshell_id}) + reply = await other_socket.recv_json() + + self.session.send(socket, "delete_subshell_reply", reply, parent, ident) + + async def list_subshell_request(self, socket, ident, parent) -> None: + if not self.session: + return + if not self._supports_kernel_subshells: + self.log.error("Subshells are not supported by this kernel") + return + + # This should only be called in the control thread if it exists. + # Request is passed to shell channel thread to process. + other_socket = self.shell_channel_thread.manager.get_control_other_socket() + await other_socket.send_json({"type": "list"}) + reply = await other_socket.recv_json() + + self.session.send(socket, "list_subshell_reply", reply, parent, ident) + # --------------------------------------------------------------------------- # Engine methods (DEPRECATED) # --------------------------------------------------------------------------- @@ -1274,3 +1416,7 @@ async def _at_shutdown(self): ident=self._topic("shutdown"), ) self.log.debug("%s", self._shutdown_message) + + @property + def _supports_kernel_subshells(self): + return self.shell_channel_thread is not None diff --git a/ipykernel/shellchannel.py b/ipykernel/shellchannel.py new file mode 100644 index 00000000..bc0459c4 --- /dev/null +++ b/ipykernel/shellchannel.py @@ -0,0 +1,34 @@ +"""A thread for a shell channel.""" +import zmq.asyncio + +from .subshell_manager import SubshellManager +from .thread import SHELL_CHANNEL_THREAD_NAME, BaseThread + + +class ShellChannelThread(BaseThread): + """A thread for a shell channel. + + Communicates with shell/subshell threads via pairs of ZMQ inproc sockets. + """ + + def __init__(self, context: zmq.asyncio.Context, shell_socket: zmq.asyncio.Socket, **kwargs): + """Initialize the thread.""" + super().__init__(name=SHELL_CHANNEL_THREAD_NAME, **kwargs) + self._manager: SubshellManager | None = None + self._context = context + self._shell_socket = shell_socket + + @property + def manager(self) -> SubshellManager: + # Lazy initialisation. + if self._manager is None: + self._manager = SubshellManager(self._context, self._shell_socket) + return self._manager + + def run(self) -> None: + """Run the thread.""" + try: + super().run() + finally: + if self._manager: + self._manager.close() diff --git a/ipykernel/subshell.py b/ipykernel/subshell.py new file mode 100644 index 00000000..18e15ab3 --- /dev/null +++ b/ipykernel/subshell.py @@ -0,0 +1,36 @@ +"""A thread for a subshell.""" + +from threading import current_thread + +import zmq.asyncio + +from .thread import BaseThread + + +class SubshellThread(BaseThread): + """A thread for a subshell.""" + + def __init__(self, subshell_id: str, **kwargs): + """Initialize the thread.""" + super().__init__(name=f"subshell-{subshell_id}", **kwargs) + + # Inproc PAIR socket, for communication with shell channel thread. + self._pair_socket: zmq.asyncio.Socket | None = None + + async def create_pair_socket(self, context: zmq.asyncio.Context, address: str) -> None: + """Create inproc PAIR socket, for communication with shell channel thread. + + Should be called from this thread, so usually via add_task before the + thread is started. + """ + assert current_thread() == self + self._pair_socket = context.socket(zmq.PAIR) + self._pair_socket.connect(address) + + def run(self) -> None: + try: + super().run() + finally: + if self._pair_socket is not None: + self._pair_socket.close() + self._pair_socket = None diff --git a/ipykernel/subshell_manager.py b/ipykernel/subshell_manager.py new file mode 100644 index 00000000..805d6f81 --- /dev/null +++ b/ipykernel/subshell_manager.py @@ -0,0 +1,283 @@ +"""Manager of subshells in a kernel.""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. +from __future__ import annotations + +import typing as t +import uuid +from dataclasses import dataclass +from threading import Lock, current_thread, main_thread + +import zmq +import zmq.asyncio +from anyio import create_memory_object_stream, create_task_group + +from .subshell import SubshellThread +from .thread import SHELL_CHANNEL_THREAD_NAME + + +@dataclass +class Subshell: + thread: SubshellThread + shell_channel_socket: zmq.asyncio.Socket + + +class SubshellManager: + """A manager of subshells. + + Controls the lifetimes of subshell threads and their associated ZMQ sockets. + Runs mostly in the shell channel thread. + + Care needed with threadsafe access here. All write access to the cache occurs in + the shell channel thread so there is only ever one write access at any one time. + Reading of cache information can be performed by other threads, so all reads are + protected by a lock so that they are atomic. + + Sending reply messages via the shell_socket is wrapped by another lock to protect + against multiple subshells attempting to send at the same time. + """ + + def __init__(self, context: zmq.asyncio.Context, shell_socket: zmq.asyncio.Socket): + assert current_thread() == main_thread() + + self._context: zmq.asyncio.Context = context + self._shell_socket = shell_socket + self._cache: dict[str, Subshell] = {} + self._lock_cache = Lock() + self._lock_shell_socket = Lock() + + # Inproc pair sockets for control channel and main shell (parent subshell). + # Each inproc pair has a "shell_channel" socket used in the shell channel + # thread, and an "other" socket used in the other thread. + self._control_shell_channel_socket = self._create_inproc_pair_socket("control", True) + self._control_other_socket = self._create_inproc_pair_socket("control", False) + self._parent_shell_channel_socket = self._create_inproc_pair_socket(None, True) + self._parent_other_socket = self._create_inproc_pair_socket(None, False) + + # anyio memory object stream for async queue-like communication between tasks. + # Used by _create_subshell to tell listen_from_subshells to spawn a new task. + self._send_stream, self._receive_stream = create_memory_object_stream[str]() + + def close(self) -> None: + """Stop all subshells and close all resources.""" + assert current_thread().name == SHELL_CHANNEL_THREAD_NAME + + self._send_stream.close() + self._receive_stream.close() + + for socket in ( + self._control_shell_channel_socket, + self._control_other_socket, + self._parent_shell_channel_socket, + self._parent_other_socket, + ): + if socket is not None: + socket.close() + + with self._lock_cache: + while True: + try: + _, subshell = self._cache.popitem() + except KeyError: + break + self._stop_subshell(subshell) + + def get_control_other_socket(self) -> zmq.asyncio.Socket: + return self._control_other_socket + + def get_other_socket(self, subshell_id: str | None) -> zmq.asyncio.Socket: + """Return the other inproc pair socket for a subshell. + + This socket is accessed from the subshell thread. + """ + if subshell_id is None: + return self._parent_other_socket + with self._lock_cache: + socket = self._cache[subshell_id].thread._pair_socket + assert socket is not None + return socket + + def get_shell_channel_socket(self, subshell_id: str | None) -> zmq.asyncio.Socket: + """Return the shell channel inproc pair socket for a subshell. + + This socket is accessed from the shell channel thread. + """ + if subshell_id is None: + return self._parent_shell_channel_socket + with self._lock_cache: + return self._cache[subshell_id].shell_channel_socket + + def list_subshell(self) -> list[str]: + """Return list of current subshell ids. + + Can be called by any subshell using %subshell magic. + """ + with self._lock_cache: + return list(self._cache) + + async def listen_from_control(self, subshell_task: t.Any) -> None: + """Listen for messages on the control inproc socket, handle those messages and + return replies on the same socket. Runs in the shell channel thread. + """ + assert current_thread().name == SHELL_CHANNEL_THREAD_NAME + + socket = self._control_shell_channel_socket + while True: + request = await socket.recv_json() # type: ignore[misc] + reply = await self._process_control_request(request, subshell_task) + await socket.send_json(reply) # type: ignore[func-returns-value] + + async def listen_from_subshells(self) -> None: + """Listen for reply messages on inproc sockets of all subshells and resend + those messages to the client via the shell_socket. + + Runs in the shell channel thread. + """ + assert current_thread().name == SHELL_CHANNEL_THREAD_NAME + + async with create_task_group() as tg: + tg.start_soon(self._listen_for_subshell_reply, None) + async for subshell_id in self._receive_stream: + tg.start_soon(self._listen_for_subshell_reply, subshell_id) + + def subshell_id_from_thread_id(self, thread_id: int) -> str | None: + """Return subshell_id of the specified thread_id. + + Raises RuntimeError if thread_id is not the main shell or a subshell. + + Only used by %subshell magic so does not have to be fast/cached. + """ + with self._lock_cache: + if thread_id == main_thread().ident: + return None + for id, subshell in self._cache.items(): + if subshell.thread.ident == thread_id: + return id + msg = f"Thread id {thread_id!r} does not correspond to a subshell of this kernel" + raise RuntimeError(msg) + + def _create_inproc_pair_socket( + self, name: str | None, shell_channel_end: bool + ) -> zmq.asyncio.Socket: + """Create and return a single ZMQ inproc pair socket.""" + address = self._get_inproc_socket_address(name) + socket = self._context.socket(zmq.PAIR) + if shell_channel_end: + socket.bind(address) + else: + socket.connect(address) + return socket + + async def _create_subshell(self, subshell_task: t.Any) -> str: + """Create and start a new subshell thread.""" + assert current_thread().name == SHELL_CHANNEL_THREAD_NAME + + subshell_id = str(uuid.uuid4()) + thread = SubshellThread(subshell_id) + + with self._lock_cache: + assert subshell_id not in self._cache + shell_channel_socket = self._create_inproc_pair_socket(subshell_id, True) + self._cache[subshell_id] = Subshell(thread, shell_channel_socket) + + # Tell task running listen_from_subshells to create a new task to listen for + # reply messages from the new subshell to resend to the client. + await self._send_stream.send(subshell_id) + + address = self._get_inproc_socket_address(subshell_id) + thread.add_task(thread.create_pair_socket, self._context, address) + thread.add_task(subshell_task, subshell_id) + thread.start() + + return subshell_id + + def _delete_subshell(self, subshell_id: str) -> None: + """Delete subshell identified by subshell_id. + + Raises key error if subshell_id not in cache. + """ + assert current_thread().name == SHELL_CHANNEL_THREAD_NAME + + with self._lock_cache: + subshell = self._cache.pop(subshell_id) + + self._stop_subshell(subshell) + + def _get_inproc_socket_address(self, name: str | None) -> str: + full_name = f"subshell-{name}" if name else "subshell" + return f"inproc://{full_name}" + + def _get_shell_channel_socket(self, subshell_id: str | None) -> zmq.asyncio.Socket: + if subshell_id is None: + return self._parent_shell_channel_socket + with self._lock_cache: + return self._cache[subshell_id].shell_channel_socket + + def _is_subshell(self, subshell_id: str | None) -> bool: + if subshell_id is None: + return True + with self._lock_cache: + return subshell_id in self._cache + + async def _listen_for_subshell_reply(self, subshell_id: str | None) -> None: + """Listen for reply messages on specified subshell inproc socket and + resend to the client via the shell_socket. + + Runs in the shell channel thread. + """ + assert current_thread().name == SHELL_CHANNEL_THREAD_NAME + + shell_channel_socket = self._get_shell_channel_socket(subshell_id) + + try: + while True: + msg = await shell_channel_socket.recv_multipart(copy=False) + with self._lock_shell_socket: + await self._shell_socket.send_multipart(msg) + except BaseException: + if not self._is_subshell(subshell_id): + # Subshell no longer exists so exit gracefully + return + raise + + async def _process_control_request( + self, request: dict[str, t.Any], subshell_task: t.Any + ) -> dict[str, t.Any]: + """Process a control request message received on the control inproc + socket and return the reply. Runs in the shell channel thread. + """ + assert current_thread().name == SHELL_CHANNEL_THREAD_NAME + + try: + type = request["type"] + reply: dict[str, t.Any] = {"status": "ok"} + + if type == "create": + reply["subshell_id"] = await self._create_subshell(subshell_task) + elif type == "delete": + subshell_id = request["subshell_id"] + self._delete_subshell(subshell_id) + elif type == "list": + reply["subshell_id"] = self.list_subshell() + else: + msg = f"Unrecognised message type {type!r}" + raise RuntimeError(msg) + except BaseException as err: + reply = { + "status": "error", + "evalue": str(err), + } + return reply + + def _stop_subshell(self, subshell: Subshell) -> None: + """Stop a subshell thread and close all of its resources.""" + assert current_thread().name == SHELL_CHANNEL_THREAD_NAME + + thread = subshell.thread + if thread.is_alive(): + thread.stop() + thread.join() + + # Closing the shell_channel_socket terminates the task that is listening on it. + subshell.shell_channel_socket.close() diff --git a/ipykernel/thread.py b/ipykernel/thread.py new file mode 100644 index 00000000..a63011de --- /dev/null +++ b/ipykernel/thread.py @@ -0,0 +1,42 @@ +"""Base class for threads.""" +import typing as t +from threading import Event, Thread + +from anyio import create_task_group, run, to_thread + +CONTROL_THREAD_NAME = "Control" +SHELL_CHANNEL_THREAD_NAME = "Shell channel" + + +class BaseThread(Thread): + """Base class for threads.""" + + def __init__(self, **kwargs): + """Initialize the thread.""" + super().__init__(**kwargs) + self.pydev_do_not_trace = True + self.is_pydev_daemon_thread = True + self.__stop = Event() + self._tasks_and_args: t.List[t.Tuple[t.Any, t.Any]] = [] + + def add_task(self, task: t.Any, *args: t.Any) -> None: + # May only add tasks before the thread is started. + self._tasks_and_args.append((task, args)) + + def run(self) -> t.Any: + """Run the thread.""" + return run(self._main) + + async def _main(self) -> None: + async with create_task_group() as tg: + for task, args in self._tasks_and_args: + tg.start_soon(task, *args) + await to_thread.run_sync(self.__stop.wait) + tg.cancel_scope.cancel() + + def stop(self) -> None: + """Stop the thread. + + This method is threadsafe. + """ + self.__stop.set() diff --git a/ipykernel/zmqshell.py b/ipykernel/zmqshell.py index bc99d000..3f97e817 100644 --- a/ipykernel/zmqshell.py +++ b/ipykernel/zmqshell.py @@ -16,9 +16,9 @@ import os import sys +import threading import warnings from pathlib import Path -from threading import local from IPython.core import page, payloadpage from IPython.core.autocall import ZMQExitAutocall @@ -69,7 +69,7 @@ def _flush_streams(self): @default("_thread_local") def _default_thread_local(self): """Initialize our thread local storage""" - return local() + return threading.local() @property def _hooks(self): @@ -439,6 +439,39 @@ def autosave(self, arg_s): else: print("Autosave disabled") + @line_magic + def subshell(self, arg_s): + """ + List all current subshells + """ + from ipykernel.kernelapp import IPKernelApp + + if not IPKernelApp.initialized(): + msg = "Not in a running Kernel" + raise RuntimeError(msg) + + app = IPKernelApp.instance() + kernel = app.kernel + + if not getattr(kernel, "_supports_kernel_subshells", False): + print("Kernel does not support subshells") + return + + thread_id = threading.current_thread().ident + manager = kernel.shell_channel_thread.manager + try: + subshell_id = manager.subshell_id_from_thread_id(thread_id) + except RuntimeError: + subshell_id = "unknown" + subshell_id_list = manager.list_subshell() + + print(f"subshell id: {subshell_id}") + print(f"thread id: {thread_id}") + print(f"main thread id: {threading.main_thread().ident}") + print(f"pid: {os.getpid()}") + print(f"thread count: {threading.active_count()}") + print(f"subshell list: {subshell_id_list}") + class ZMQInteractiveShell(InteractiveShell): """A subclass of InteractiveShell for ZMQ.""" diff --git a/tests/test_ipkernel_direct.py b/tests/test_ipkernel_direct.py index cea2ec99..dfd0445c 100644 --- a/tests/test_ipkernel_direct.py +++ b/tests/test_ipkernel_direct.py @@ -27,6 +27,10 @@ async def test_properties(ipkernel: IPythonKernel) -> None: async def test_direct_kernel_info_request(ipkernel): reply = await ipkernel.test_shell_message("kernel_info_request", {}) assert reply["header"]["msg_type"] == "kernel_info_reply" + assert ( + "supported_features" not in reply["content"] + or "kernel subshells" not in reply["content"]["supported_features"] + ) async def test_direct_execute_request(ipkernel: MockIPyKernel) -> None: diff --git a/tests/test_kernel_direct.py b/tests/test_kernel_direct.py index ea3c6fe7..50801b03 100644 --- a/tests/test_kernel_direct.py +++ b/tests/test_kernel_direct.py @@ -16,6 +16,10 @@ async def test_direct_kernel_info_request(kernel): reply = await kernel.test_shell_message("kernel_info_request", {}) assert reply["header"]["msg_type"] == "kernel_info_reply" + assert ( + "supported_features" not in reply["content"] + or "kernel subshells" not in reply["content"]["supported_features"] + ) async def test_direct_execute_request(kernel): diff --git a/tests/test_message_spec.py b/tests/test_message_spec.py index d98503ee..694de44b 100644 --- a/tests/test_message_spec.py +++ b/tests/test_message_spec.py @@ -239,6 +239,21 @@ class HistoryReply(Reply): history = List(List()) +# Subshell control messages + + +class CreateSubshellReply(Reply): + subshell_id = Unicode() + + +class DeleteSubshellReply(Reply): + pass + + +class ListSubshellReply(Reply): + subshell_id = List(Unicode()) + + references = { "execute_reply": ExecuteReply(), "inspect_reply": InspectReply(), @@ -255,6 +270,9 @@ class HistoryReply(Reply): "stream": Stream(), "display_data": DisplayData(), "header": RHeader(), + "create_subshell_reply": CreateSubshellReply(), + "delete_subshell_reply": DeleteSubshellReply(), + "list_subshell_reply": ListSubshellReply(), } # ----------------------------------------------------------------------------- @@ -498,6 +516,8 @@ def test_kernel_info_request(): msg_id = KC.kernel_info() reply = get_reply(KC, msg_id, TIMEOUT) validate_message(reply, "kernel_info_reply", msg_id) + assert "supported_features" in reply["content"] + assert "kernel subshells" in reply["content"]["supported_features"] def test_connect_request(): @@ -509,6 +529,29 @@ def test_connect_request(): validate_message(reply, "connect_reply", msg_id) +def test_subshell(): + flush_channels() + + msg = KC.session.msg("create_subshell_request") + KC.control_channel.send(msg) + msg_id = msg["header"]["msg_id"] + reply = get_reply(KC, msg_id, TIMEOUT, channel="control") + validate_message(reply, "create_subshell_reply", msg_id) + subshell_id = reply["content"]["subshell_id"] + + msg = KC.session.msg("list_subshell_request") + KC.control_channel.send(msg) + msg_id = msg["header"]["msg_id"] + reply = get_reply(KC, msg_id, TIMEOUT, channel="control") + validate_message(reply, "list_subshell_reply", msg_id) + + msg = KC.session.msg("delete_subshell_request", {"subshell_id": subshell_id}) + KC.control_channel.send(msg) + msg_id = msg["header"]["msg_id"] + reply = get_reply(KC, msg_id, TIMEOUT, channel="control") + validate_message(reply, "delete_subshell_reply", msg_id) + + @pytest.mark.skipif( version_info < (5, 0), reason="earlier Jupyter Client don't have comm_info", diff --git a/tests/test_subshells.py b/tests/test_subshells.py new file mode 100644 index 00000000..f1328dda --- /dev/null +++ b/tests/test_subshells.py @@ -0,0 +1,269 @@ +"""Test kernel subshells.""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. +from __future__ import annotations + +import platform +import time +from datetime import datetime, timedelta + +import pytest +from jupyter_client.blocking.client import BlockingKernelClient + +from .utils import TIMEOUT, get_replies, get_reply, new_kernel + +# Helpers + + +def create_subshell_helper(kc: BlockingKernelClient): + msg = kc.session.msg("create_subshell_request") + kc.control_channel.send(msg) + msg_id = msg["header"]["msg_id"] + reply = get_reply(kc, msg_id, TIMEOUT, channel="control") + return reply["content"] + + +def delete_subshell_helper(kc: BlockingKernelClient, subshell_id: str): + msg = kc.session.msg("delete_subshell_request", {"subshell_id": subshell_id}) + kc.control_channel.send(msg) + msg_id = msg["header"]["msg_id"] + reply = get_reply(kc, msg_id, TIMEOUT, channel="control") + return reply["content"] + + +def list_subshell_helper(kc: BlockingKernelClient): + msg = kc.session.msg("list_subshell_request") + kc.control_channel.send(msg) + msg_id = msg["header"]["msg_id"] + reply = get_reply(kc, msg_id, TIMEOUT, channel="control") + return reply["content"] + + +def execute_request_subshell_id( + kc: BlockingKernelClient, code: str, subshell_id: str | None, terminator: str = "\n" +): + msg = kc.session.msg("execute_request", {"code": code}) + msg["header"]["subshell_id"] = subshell_id + msg_id = msg["msg_id"] + kc.shell_channel.send(msg) + stdout = "" + while True: + msg = kc.get_iopub_msg() + # Get the stream messages corresponding to msg_id + if ( + msg["msg_type"] == "stream" + and msg["parent_header"]["msg_id"] == msg_id + and msg["content"]["name"] == "stdout" + ): + stdout += msg["content"]["text"] + if stdout.endswith(terminator): + break + return stdout.strip() + + +def execute_thread_count(kc: BlockingKernelClient) -> int: + code = "import threading as t; print(t.active_count())" + return int(execute_request_subshell_id(kc, code, None)) + + +def execute_thread_ids(kc: BlockingKernelClient, subshell_id: str | None = None) -> tuple[str, str]: + code = "import threading as t; print(t.get_ident(), t.main_thread().ident)" + return execute_request_subshell_id(kc, code, subshell_id).split() + + +# Tests + + +def test_supported(): + with new_kernel() as kc: + msg_id = kc.kernel_info() + reply = get_reply(kc, msg_id, TIMEOUT) + assert "supported_features" in reply["content"] + assert "kernel subshells" in reply["content"]["supported_features"] + + +def test_subshell_id_lifetime(): + with new_kernel() as kc: + assert list_subshell_helper(kc)["subshell_id"] == [] + subshell_id = create_subshell_helper(kc)["subshell_id"] + assert list_subshell_helper(kc)["subshell_id"] == [subshell_id] + delete_subshell_helper(kc, subshell_id) + assert list_subshell_helper(kc)["subshell_id"] == [] + + +def test_delete_non_existent(): + with new_kernel() as kc: + reply = delete_subshell_helper(kc, "unknown_subshell_id") + assert reply["status"] == "error" + assert "evalue" in reply + + +def test_thread_counts(): + with new_kernel() as kc: + nthreads = execute_thread_count(kc) + + subshell_id = create_subshell_helper(kc)["subshell_id"] + nthreads2 = execute_thread_count(kc) + assert nthreads2 > nthreads + + delete_subshell_helper(kc, subshell_id) + nthreads3 = execute_thread_count(kc) + assert nthreads3 == nthreads + + +def test_thread_ids(): + with new_kernel() as kc: + subshell_id = create_subshell_helper(kc)["subshell_id"] + + thread_id, main_thread_id = execute_thread_ids(kc) + assert thread_id == main_thread_id + + thread_id, main_thread_id = execute_thread_ids(kc, subshell_id) + assert thread_id != main_thread_id + + delete_subshell_helper(kc, subshell_id) + + +@pytest.mark.parametrize("are_subshells", [(False, True), (True, False), (True, True)]) +@pytest.mark.parametrize("overlap", [True, False]) +def test_run_concurrently_sequence(are_subshells, overlap): + with new_kernel() as kc: + subshell_ids = [ + create_subshell_helper(kc)["subshell_id"] if is_subshell else None + for is_subshell in are_subshells + ] + if overlap: + codes = [ + "import time; start0=True; end0=False; time.sleep(0.2); end0=True", + "assert start0; assert not end0; time.sleep(0.2); assert end0", + ] + else: + codes = [ + "import time; start0=True; end0=False; time.sleep(0.2); assert end1", + "assert start0; assert not end0; end1=True", + ] + + msgs = [] + for subshell_id, code in zip(subshell_ids, codes): + msg = kc.session.msg("execute_request", {"code": code}) + msg["header"]["subshell_id"] = subshell_id + kc.shell_channel.send(msg) + msgs.append(msg) + if len(msgs) == 1: + time.sleep(0.1) # Wait for first execute_request to start. + + replies = get_replies(kc, [msg["msg_id"] for msg in msgs]) + + for subshell_id in subshell_ids: + if subshell_id: + delete_subshell_helper(kc, subshell_id) + + for reply in replies: + assert reply["content"]["status"] == "ok" + + +@pytest.mark.parametrize("include_main_shell", [True, False]) +def test_run_concurrently_timing(include_main_shell): + with new_kernel() as kc: + subshell_ids = [ + None if include_main_shell else create_subshell_helper(kc)["subshell_id"], + create_subshell_helper(kc)["subshell_id"], + ] + + times = (0.2, 0.2) + # Prepare messages, times are sleep times in seconds. + # Identical times for both subshells is a harder test as preparing and sending + # the execute_reply messages may overlap. + msgs = [] + for id, sleep in zip(subshell_ids, times): + code = f"import time; time.sleep({sleep})" + msg = kc.session.msg("execute_request", {"code": code}) + msg["header"]["subshell_id"] = id + msgs.append(msg) + + # Send messages + start = datetime.now() + for msg in msgs: + kc.shell_channel.send(msg) + + _ = get_replies(kc, [msg["msg_id"] for msg in msgs]) + end = datetime.now() + + for subshell_id in subshell_ids: + if subshell_id: + delete_subshell_helper(kc, subshell_id) + + duration = end - start + assert duration >= timedelta(seconds=max(times)) + # Care is needed with this test as runtime conditions such as gathering + # coverage can slow it down causing the following assert to fail. + # The sleep time of 0.2 is empirically determined to run OK in CI, but + # consider increasing it if the following fails. + assert duration < timedelta(seconds=sum(times)) + + +def test_execution_count(): + with new_kernel() as kc: + subshell_id = create_subshell_helper(kc)["subshell_id"] + + # Prepare messages + times = (0.2, 0.1, 0.4, 0.15) # Sleep seconds + msgs = [] + for id, sleep in zip((None, subshell_id, None, subshell_id), times): + code = f"import time; time.sleep({sleep})" + msg = kc.session.msg("execute_request", {"code": code}) + msg["header"]["subshell_id"] = id + msgs.append(msg) + + for msg in msgs: + kc.shell_channel.send(msg) + + # Wait for replies, may be in any order. + replies = get_replies(kc, [msg["msg_id"] for msg in msgs]) + + delete_subshell_helper(kc, subshell_id) + + execution_counts = [r["content"]["execution_count"] for r in replies] + ec = execution_counts[0] + assert execution_counts == [ec, ec - 1, ec + 2, ec + 1] + + +def test_create_while_execute(): + with new_kernel() as kc: + # Send request to execute code on main subshell. + msg = kc.session.msg("execute_request", {"code": "import time; time.sleep(0.05)"}) + kc.shell_channel.send(msg) + + # Create subshell via control channel. + control_msg = kc.session.msg("create_subshell_request") + kc.control_channel.send(control_msg) + control_reply = get_reply(kc, control_msg["header"]["msg_id"], TIMEOUT, channel="control") + subshell_id = control_reply["content"]["subshell_id"] + control_date = control_reply["header"]["date"] + + # Get result message from main subshell. + shell_date = get_reply(kc, msg["msg_id"])["header"]["date"] + + delete_subshell_helper(kc, subshell_id) + + assert control_date < shell_date + + +@pytest.mark.skipif( + platform.python_implementation() == "PyPy", + reason="does not work on PyPy", +) +def test_shutdown_with_subshell(): + # Based on test_kernel.py::test_shutdown + with new_kernel() as kc: + km = kc.parent + subshell_id = create_subshell_helper(kc)["subshell_id"] + assert list_subshell_helper(kc)["subshell_id"] == [subshell_id] + kc.shutdown() + for _ in range(100): # 10 s timeout + if km.is_alive(): + time.sleep(0.1) + else: + break + assert not km.is_alive() diff --git a/tests/utils.py b/tests/utils.py index b1b4119f..b20e8fcb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,6 +2,7 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations import atexit import os @@ -68,6 +69,28 @@ def get_reply(kc, msg_id, timeout=TIMEOUT, channel="shell"): return reply +def get_replies(kc, msg_ids: list[str], timeout=TIMEOUT, channel="shell"): + # Get replies which may arrive in any order as they may be running on different subshells. + # Replies are returned in the same order as the msg_ids, not in the order of arrival. + t0 = time() + count = 0 + replies = [None] * len(msg_ids) + while count < len(msg_ids): + get_msg = getattr(kc, f"get_{channel}_msg") + reply = get_msg(timeout=timeout) + try: + msg_id = reply["parent_header"]["msg_id"] + replies[msg_ids.index(msg_id)] = reply + count += 1 + except ValueError: + # Allow debugging ignored replies + print(f"Ignoring reply not to any of {msg_ids}: {reply}") + t1 = time() + timeout -= t1 - t0 + t0 = t1 + return replies + + def execute(code="", kc=None, **kwargs): """wrapper for doing common steps for validating an execution request""" from .test_message_spec import validate_message