diff --git a/jupyter_server_proxy/handlers.py b/jupyter_server_proxy/handlers.py index 2890dc3a..79d12fbe 100644 --- a/jupyter_server_proxy/handlers.py +++ b/jupyter_server_proxy/handlers.py @@ -493,7 +493,7 @@ async def start_websocket_connection(): request=request, on_message_callback=message_cb, on_ping_callback=ping_cb, - subprotocols=self.subprotocols, + subprotocols=self.subprotocols if self.subprotocols else None, resolver=resolver, ) self._record_activity() diff --git a/tests/resources/websocket.py b/tests/resources/websocket.py index dda24d7c..666698b7 100644 --- a/tests/resources/websocket.py +++ b/tests/resources/websocket.py @@ -36,7 +36,6 @@ def __init__(self): handlers = [ (r"/", MainHandler), (r"/echosocket", EchoWebSocket), - (r"/subprotocolsocket", SubprotocolWebSocket), (r"/headerssocket", HeadersWebSocket), ] settings = dict( @@ -63,19 +62,6 @@ def on_message(self, message): self.write_message(json.dumps(dict(self.request.headers))) -class SubprotocolWebSocket(tornado.websocket.WebSocketHandler): - def __init__(self, *args, **kwargs): - self._subprotocols = None - super().__init__(*args, **kwargs) - - def select_subprotocol(self, subprotocols): - self._subprotocols = subprotocols - return None - - def on_message(self, message): - self.write_message(json.dumps(self._subprotocols)) - - def main(): tornado.options.parse_command_line() app = Application() diff --git a/tests/test_proxies.py b/tests/test_proxies.py index 5605b4d1..4ad8924f 100644 --- a/tests/test_proxies.py +++ b/tests/test_proxies.py @@ -374,11 +374,13 @@ def test_server_proxy_websocket_headers( async def _websocket_subprotocols(a_server_port_and_token: Tuple[int, str]) -> None: PORT, TOKEN = a_server_port_and_token - url = f"ws://{LOCALHOST}:{PORT}/python-websocket/subprotocolsocket" + url = f"ws://{LOCALHOST}:{PORT}/python-websocket/headerssocket" conn = await websocket_connect(url, subprotocols=["protocol_1", "protocol_2"]) await conn.write_message("Hello, world!") msg = await conn.read_message() - assert json.loads(msg) == ["protocol_1", "protocol_2"] + headers = json.loads(msg) + assert "Sec-Websocket-Protocol" in headers + assert headers["Sec-Websocket-Protocol"] == "protocol_1,protocol_2" def test_server_proxy_websocket_subprotocols( @@ -387,6 +389,39 @@ def test_server_proxy_websocket_subprotocols( event_loop.run_until_complete(_websocket_subprotocols(a_server_port_and_token)) +async def _websocket_empty_subprotocols(a_server_port_and_token: Tuple[int, str]) -> None: + PORT, TOKEN = a_server_port_and_token + url = f"ws://{LOCALHOST}:{PORT}/python-websocket/headerssocket" + conn = await websocket_connect(url, subprotocols=[]) + await conn.write_message("Hello, world!") + msg = await conn.read_message() + headers = json.loads(msg) + assert "Sec-Websocket-Protocol" in headers + assert headers["Sec-Websocket-Protocol"] == "" + + +def test_server_proxy_websocket_empty_subprotocols( + event_loop, a_server_port_and_token: Tuple[int, str] +): + event_loop.run_until_complete(_websocket_empty_subprotocols(a_server_port_and_token)) + + +async def _websocket_no_subprotocols(a_server_port_and_token: Tuple[int, str]) -> None: + PORT, TOKEN = a_server_port_and_token + url = f"ws://{LOCALHOST}:{PORT}/python-websocket/headerssocket" + conn = await websocket_connect(url) + await conn.write_message("Hello, world!") + msg = await conn.read_message() + headers = json.loads(msg) + assert "Sec-Websocket-Protocol" not in headers + + +def test_server_proxy_websocket_no_subprotocols( + event_loop, a_server_port_and_token: Tuple[int, str] +): + event_loop.run_until_complete(_websocket_no_subprotocols(a_server_port_and_token)) + + @pytest.mark.parametrize( "proxy_path, status", [