Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

flush control queue prior to handling shell messages #658

Merged
merged 3 commits into from
May 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ipykernel/inprocess/ipkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ async def _abort_queues(self):
""" The in-process kernel doesn't abort requests. """
pass

async def _flush_control_queue(self):
    """No-op: flushing control queues is not needed for the in-process kernel."""
    return None

def _input_request(self, prompt, ident, parent, password=False):
# Flush output before making the request.
self.raw_input_str = None
Expand Down
35 changes: 33 additions & 2 deletions ipykernel/kernelbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Distributed under the terms of the Modified BSD License.

import asyncio
import concurrent.futures
from datetime import datetime
from functools import partial
import itertools
Expand Down Expand Up @@ -213,8 +214,34 @@ def dispatch_control(self, msg):
async def poll_control_queue(self):
    """Consume ``control_queue`` forever, dispatching each control message.

    Items pulled from the queue are either raw control messages, which are
    handed to ``process_control``, or tracer futures placed there by
    ``_flush_control_queue``. Reaching a tracer means every message queued
    ahead of it has been processed, so the tracer is resolved with ``None``.
    """
    while True:
        msg = await self.control_queue.get()
        # handle tracers from _flush_control_queue
        if isinstance(msg, (concurrent.futures.Future, asyncio.Future)):
            # A tracer may already be cancelled or resolved by whoever holds
            # it; set_result() would then raise InvalidStateError and kill
            # this loop, silently stopping all control-message handling.
            if not msg.done():
                msg.set_result(None)
            continue
        await self.process_control(msg)

async def _flush_control_queue(self):
    """Flush the control queue, wait for processing of any pending messages.

    A tracer future is queued on ``control_queue`` behind everything that
    ``control_stream.flush()`` enqueues; ``poll_control_queue`` resolves it
    once it reaches it. This coroutine does not finish until that happens.

    Returns the (by then completed) awaitable tracer future, so a caller
    that awaits the return value a second time keeps working.
    """
    if self.control_thread:
        control_loop = self.control_thread.io_loop
        # concurrent.futures.Futures are threadsafe
        # and can be used to await across threads
        tracer_future = concurrent.futures.Future()
        awaitable_future = asyncio.wrap_future(tracer_future)
    else:
        control_loop = self.io_loop
        tracer_future = awaitable_future = asyncio.Future()

    def _flush():
        # control_stream.flush puts messages on the queue
        self.control_stream.flush()
        # put Future on the queue after all of those,
        # so we can wait for all queued messages to be processed
        self.control_queue.put(tracer_future)

    control_loop.add_callback(_flush)
    # Merely returning the future from an ``async def`` would hand it back
    # un-awaited: ``await self._flush_control_queue()`` would complete before
    # ``_flush`` had even run. Await it here so callers really wait.
    await awaitable_future
    return awaitable_future

async def process_control(self, msg):
"""dispatch control requests"""
idents, msg = self.session.feed_identities(msg, copy=False)
Expand Down Expand Up @@ -265,6 +292,10 @@ def should_handle(self, stream, msg, idents):

async def dispatch_shell(self, msg):
"""dispatch shell requests"""

# flush control queue before handling shell requests
await self._flush_control_queue()

idents, msg = self.session.feed_identities(msg, copy=False)
try:
msg = self.session.deserialize(msg, content=True, copy=False)
Expand Down Expand Up @@ -630,7 +661,7 @@ async def inspect_request(self, stream, ident, parent):
content.get('detail_level', 0),
)
if inspect.isawaitable(reply_content):
reply_content = await reply_content
reply_content = await reply_content

# Before we send this object over, we scrub it for JSON usage
reply_content = json_clean(reply_content)
Expand Down Expand Up @@ -944,7 +975,7 @@ def _input_request(self, prompt, ident, parent, password=False):
raise KeyboardInterrupt("Interrupted by user") from None
except Exception as e:
self.log.warning("Invalid Message:", exc_info=True)

try:
value = reply["content"]["value"]
except Exception:
Expand Down
45 changes: 45 additions & 0 deletions ipykernel/tests/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,48 @@ def test_interrupt_during_pdb_set_trace():
# If we failed to interrupt, this will time out:
reply = get_reply(kc, msg_id2, TIMEOUT)
validate_message(reply, 'execute_reply', msg_id2)


def test_control_thread_priority():
    """Check that control requests are answered ahead of queued shell requests.

    A slow execution keeps the shell channel busy while N more shell requests
    and then N control requests are sent; the last control reply must not be
    dated later than the first of the queued shell replies.
    """
    N = 5
    with new_kernel() as kc:
        # warm-up round trip before the timed part of the test
        get_reply(kc, kc.execute("pass"))

        sleep_msg_id = kc.execute("import asyncio; await asyncio.sleep(2)")

        # submit N shell messages behind the sleeping execution
        shell_msg_ids = [kc.execute(f"i = {i}") for i in range(N)]

        # ensure all shell messages have arrived at the kernel before any
        # control messages: they should now be waiting in msg_queue rather
        # than in zmq, while the kernel is still busy with the first execution
        time.sleep(0.5)

        # now send N control messages
        control_msg_ids = []
        for _ in range(N):
            request = kc.session.msg("kernel_info_request", {})
            kc.control_channel.send(request)
            control_msg_ids.append(request["header"]["msg_id"])

        # finally, collect the replies on both channels for comparison
        get_reply(kc, sleep_msg_id)
        shell_replies = [get_reply(kc, msg_id) for msg_id in shell_msg_ids]
        control_replies = [
            get_reply(kc, msg_id, channel="control") for msg_id in control_msg_ids
        ]

    # verify that all control messages were handled before all shell messages
    shell_dates = [reply["header"]["date"] for reply in shell_replies]
    control_dates = [reply["header"]["date"] for reply in control_replies]
    # queues preserve order, so comparing last-control to first-shell suffices;
    # <= tolerates very fast handling and/or low-resolution timers
    last_control, first_shell = control_dates[-1], shell_dates[0]
    assert last_control <= first_shell
8 changes: 5 additions & 3 deletions ipykernel/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,15 @@ def flush_channels(kc=None):
validate_message(msg)


def get_reply(kc, msg_id, timeout):
timeout = TIMEOUT
def get_reply(kc, msg_id, timeout=TIMEOUT, channel='shell'):
t0 = time()
while True:
reply = kc.get_shell_msg(timeout=timeout)
get_msg = getattr(kc, f'get_{channel}_msg')
reply = get_msg(timeout=timeout)
if reply['parent_header']['msg_id'] == msg_id:
break
# Allow debugging ignored replies
print(f"Ignoring reply not to {msg_id}: {reply}")
t1 = time()
timeout -= t1 - t0
t0 = t1
Expand Down