Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

buffer messages when websocket connection is interrupted #2871

Merged
merged 5 commits into from
Oct 6, 2017
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions notebook/services/kernels/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def post(self, kernel_id, action):


class ZMQChannelsHandler(AuthenticatedZMQStreamHandler):
'''There is one ZMQChannelsHandler per running kernel and it oversees all
the sessions.
'''

# class-level registry of open sessions
# allows checking for conflict on session-id,
Expand Down Expand Up @@ -252,10 +255,14 @@ def _register_session(self):
self.log.warning("Replacing stale connection: %s", self.session_key)
yield stale_handler.close()
self._open_sessions[self.session_key] = self

def open(self, kernel_id):
super(ZMQChannelsHandler, self).open()
self.kernel_manager.notify_connect(kernel_id)

# on new connections, flush the message buffer
replay_buffer = self.kernel_manager.stop_buffering(kernel_id, self.session_key)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok cool you made the kernel manager dictate stopping the buffering


try:
self.create_stream()
except web.HTTPError as e:
Expand All @@ -266,9 +273,16 @@ def open(self, kernel_id):
if not stream.closed():
stream.close()
self.close()
else:
for channel, stream in self.channels.items():
stream.on_recv_stream(self._on_zmq_reply)
return

if replay_buffer:
self.log.info("Replaying %s buffered messages", len(replay_buffer))
for channel, msg_list in replay_buffer:
stream = self.channels[channel]
self._on_zmq_reply(stream, msg_list)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What should we do if we fail during the replay?


for channel, stream in self.channels.items():
stream.on_recv_stream(self._on_zmq_reply)

def on_message(self, msg):
if not self.channels:
Expand All @@ -288,7 +302,7 @@ def on_message(self, msg):
return
stream = self.channels[channel]
self.session.send(stream, msg)

def _on_zmq_reply(self, stream, msg_list):
idents, fed_msg_list = self.session.feed_identities(msg_list)
msg = self.session.deserialize(fed_msg_list)
Expand All @@ -301,7 +315,6 @@ def write_stderr(error_message):
)
msg['channel'] = 'iopub'
self.write_message(json.dumps(msg, default=date_default))

channel = getattr(stream, 'channel', None)
msg_type = msg['header']['msg_type']

Expand Down Expand Up @@ -408,6 +421,7 @@ def on_close(self):
# unregister myself as an open session (only if it's really me)
if self._open_sessions.get(self.session_key) is self:
self._open_sessions.pop(self.session_key)

km = self.kernel_manager
if self.kernel_id in km:
km.notify_disconnect(self.kernel_id)
Expand All @@ -417,6 +431,13 @@ def on_close(self):
km.remove_restart_callback(
self.kernel_id, self.on_restart_failed, 'dead',
)

# start buffering instead of closing if this was the last connection
if km._kernel_connections[self.kernel_id] == 0:
km.start_buffering(self.kernel_id, self.session_key, self.channels)
self._close_future.set_result(None)
return

# This method can be called twice, once by self.kernel_died and once
# from the WebSocket close event. If the WebSocket connection is
# closed before the ZMQ streams are setup, they could be None.
Expand Down
86 changes: 83 additions & 3 deletions notebook/services/kernels/kernelmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

from collections import defaultdict
from functools import partial
import os

from tornado import gen, web
Expand All @@ -15,13 +17,13 @@

from jupyter_client.session import Session
from jupyter_client.multikernelmanager import MultiKernelManager
from traitlets import Bool, Dict, List, Unicode, TraitError, Integer, default, validate
from traitlets import Any, Bool, Dict, List, Unicode, TraitError, Integer, default, validate

from notebook.utils import to_os_path, exists
from notebook._tz import utcnow, isoformat
from ipython_genutils.py3compat import getcwd

from datetime import datetime, timedelta
from datetime import timedelta


class MappingKernelManager(MultiKernelManager):
Expand Down Expand Up @@ -81,6 +83,11 @@ def _update_root_dir(self, proposal):
Only effective if cull_idle_timeout is not 0."""
)

_kernel_buffers = Any()
@default('_kernel_buffers')
def _default_kernel_buffers(self):
return defaultdict(lambda: {'buffer': [], 'session_key': '', 'channels': {}})

#-------------------------------------------------------------------------
# Methods for managing kernels and sessions
#-------------------------------------------------------------------------
Expand Down Expand Up @@ -142,10 +149,82 @@ def start_kernel(self, kernel_id=None, path=None, **kwargs):
# py2-compat
raise gen.Return(kernel_id)

def start_buffering(self, kernel_id, session_key, channels):
"""Start buffering messages for a kernel

Parameters
----------
kernel_id : str
The id of the kernel to stop buffering.
session_key: str
The session_key, if any, that should get the buffer.
If the session_key matches the current buffered session_key,
the buffer will be returned.
channels: dict({'channel': ZMQStream})
The zmq channels whose messages should be buffered.
"""
self.log.info("Starting buffering for %s", session_key)
self._check_kernel_id(kernel_id)
# clear previous buffering state
self.stop_buffering(kernel_id)
buffer_info = self._kernel_buffers[kernel_id]
# record the session key because only one session can buffer
buffer_info['session_key'] = session_key
# TODO: the buffer should likely be a memory bounded queue, we're starting with a list to keep it simple
buffer_info['buffer'] = []
buffer_info['channels'] = channels

# forward any future messages to the internal buffer
def buffer_msg(channel, msg_parts):
self.log.debug("Buffering msg on %s:%s", kernel_id, channel)
buffer_info['buffer'].append((channel, msg_parts))

for channel, stream in channels.items():
stream.on_recv(partial(buffer_msg, channel))

def stop_buffering(self, kernel_id, session_key=None):
"""Stop buffering kernel messages

if session_key matches the current buffered session for the kernel,
the buffer will be returned. Otherwise, an empty list will be returned.

Parameters
----------
kernel_id : str
The id of the kernel to stop buffering.
session_key: str, optional
The session_key, if any, that should get the buffer.
If the session_key matches the current buffered session_key,
the buffer will be returned.
"""
self.log.debug("Clearing buffer for %s", kernel_id)
self._check_kernel_id(kernel_id)

if kernel_id not in self._kernel_buffers:
return
buffer_info = self._kernel_buffers.pop(kernel_id)
# close buffering streams
for stream in buffer_info['channels'].values():
if not stream.closed():
stream.on_recv(None)
stream.socket.close()
stream.close()

msg_buffer = buffer_info['buffer']
if msg_buffer and buffer_info['session_key'] != session_key:
self.log.info("Discarding %s buffered messages for %s",
len(msg_buffer), buffer_info['session_key'])
msg_buffer = []

# return previous buffer if it matched the session key
return msg_buffer

def shutdown_kernel(self, kernel_id, now=False):
"""Shutdown a kernel by kernel_id"""
self._check_kernel_id(kernel_id)
self._kernels[kernel_id]._activity_stream.close()
kernel = self._kernels[kernel_id]
kernel._activity_stream.close()
self.stop_buffering(kernel_id)
self._kernel_connections.pop(kernel_id, None)
return super(MappingKernelManager, self).shutdown_kernel(kernel_id, now=now)

Expand Down Expand Up @@ -256,6 +335,7 @@ def record_activity(msg_list):

idents, fed_msg_list = session.feed_identities(msg_list)
msg = session.deserialize(fed_msg_list)

msg_type = msg['header']['msg_type']
self.log.debug("activity on %s: %s", kernel_id, msg_type)
if msg_type == 'status':
Expand Down