From 38224fbc1635591c2e698616e3643af508bcaa03 Mon Sep 17 00:00:00 2001 From: Min RK Date: Tue, 3 Oct 2017 19:48:00 +0200 Subject: [PATCH] restore actual zmq channels when resuming connection rather than establishing new connections fixes failure to resume shell channel --- notebook/services/kernels/handlers.py | 41 ++++++++++++---------- notebook/services/kernels/kernelmanager.py | 35 ++++++++++++------ 2 files changed, 47 insertions(+), 29 deletions(-) diff --git a/notebook/services/kernels/handlers.py b/notebook/services/kernels/handlers.py index dfdd8a3d65..006047e27c 100644 --- a/notebook/services/kernels/handlers.py +++ b/notebook/services/kernels/handlers.py @@ -261,25 +261,28 @@ def open(self, kernel_id): self.kernel_manager.notify_connect(kernel_id) # on new connections, flush the message buffer - replay_buffer = self.kernel_manager.stop_buffering(kernel_id, self.session_key) - - try: - self.create_stream() - except web.HTTPError as e: - self.log.error("Error opening stream: %s", e) - # WebSockets don't response to traditional error codes so we - # close the connection. - for channel, stream in self.channels.items(): - if not stream.closed(): - stream.close() - self.close() - return - - if replay_buffer: - self.log.info("Replaying %s buffered messages", len(replay_buffer)) - for channel, msg_list in replay_buffer: - stream = self.channels[channel] - self._on_zmq_reply(stream, msg_list) + buffer_info = self.kernel_manager.get_buffer(kernel_id, self.session_key) + if buffer_info and buffer_info['session_key'] == self.session_key: + self.log.info("Restoring connection for %s", self.session_key) + self.channels = buffer_info['channels'] + replay_buffer = buffer_info['buffer'] + if replay_buffer: + self.log.info("Replaying %s buffered messages", len(replay_buffer)) + for channel, msg_list in replay_buffer: + stream = self.channels[channel] + self._on_zmq_reply(stream, msg_list) + else: + try: + self.create_stream() + except web.HTTPError as e: + self.log.error("Error opening stream: %s", e) + # WebSockets don't response to traditional error codes so we + # close the connection. + for channel, stream in self.channels.items(): + if not stream.closed(): + stream.close() + self.close() + return for channel, stream in self.channels.items(): stream.on_recv_stream(self._on_zmq_reply) diff --git a/notebook/services/kernels/kernelmanager.py b/notebook/services/kernels/kernelmanager.py index ed733a9c1c..6f47ed6a71 100644 --- a/notebook/services/kernels/kernelmanager.py +++ b/notebook/services/kernels/kernelmanager.py @@ -182,11 +182,9 @@ def buffer_msg(channel, msg_parts): for channel, stream in channels.items(): stream.on_recv(partial(buffer_msg, channel)) - def stop_buffering(self, kernel_id, session_key=None): - """Stop buffering kernel messages - - if session_key matches the current buffered session for the kernel, - the buffer will be returned. Otherwise, an empty list will be returned. + + def get_buffer(self, kernel_id, session_key): + """Get the buffer for a given kernel Parameters ---------- @@ -197,6 +195,27 @@ def stop_buffering(self, kernel_id, session_key=None): If the session_key matches the current buffered session_key, the buffer will be returned. """ + self.log.debug("Getting buffer for %s", kernel_id) + if kernel_id not in self._kernel_buffers: + return + + buffer_info = self._kernel_buffers[kernel_id] + if buffer_info['session_key'] == session_key: + # remove buffer + self._kernel_buffers.pop(kernel_id) + # only return buffer_info if it's a match + return buffer_info + else: + self.stop_buffering(kernel_id) + + def stop_buffering(self, kernel_id): + """Stop buffering kernel messages + + Parameters + ---------- + kernel_id : str + The id of the kernel to stop buffering. + """ self.log.debug("Clearing buffer for %s", kernel_id) self._check_kernel_id(kernel_id) @@ -211,13 +230,9 @@ def stop_buffering(self, kernel_id, session_key=None): stream.close() msg_buffer = buffer_info['buffer'] - if msg_buffer and buffer_info['session_key'] != session_key: + if msg_buffer: self.log.info("Discarding %s buffered messages for %s", len(msg_buffer), buffer_info['session_key']) - msg_buffer = [] - - # return previous buffer if it matched the session key - return msg_buffer def shutdown_kernel(self, kernel_id, now=False): """Shutdown a kernel by kernel_id"""