diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index dd75d559..5d6c8ef4 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -95,7 +95,7 @@ jobs: pip-install-constraints: >- jupyter-server==1.0 simpervisor==1.0 - tornado==5.0 + tornado==5.1 traitlets==4.2.1 steps: diff --git a/jupyter_server_proxy/handlers.py b/jupyter_server_proxy/handlers.py index 2890dc3a..a0e12260 100644 --- a/jupyter_server_proxy/handlers.py +++ b/jupyter_server_proxy/handlers.py @@ -116,7 +116,6 @@ def __init__(self, *args, **kwargs): "rewrite_response", tuple(), ) - self.subprotocols = None super().__init__(*args, **kwargs) # Support/use jupyter_server config arguments allow_origin and allow_origin_pat @@ -489,15 +488,28 @@ async def start_websocket_connection(): self.log.info(f"Trying to establish websocket connection to {client_uri}") self._record_activity() request = httpclient.HTTPRequest(url=client_uri, headers=headers) + subprotocols = ( + [self.selected_subprotocol] if self.selected_subprotocol else None + ) self.ws = await pingable_ws_connect( request=request, on_message_callback=message_cb, on_ping_callback=ping_cb, - subprotocols=self.subprotocols, + subprotocols=subprotocols, resolver=resolver, ) self._record_activity() self.log.info(f"Websocket connection established to {client_uri}") + if ( + subprotocols + and self.ws.selected_subprotocol != self.selected_subprotocol + ): + self.log.warn( + f"Websocket subprotocol between proxy/server ({self.ws.selected_subprotocol}) " + f"became different than for client/proxy ({self.selected_subprotocol}) " + "due to https://github.com/jupyterhub/jupyter-server-proxy/issues/459. " + f"Requested subprotocols were {subprotocols}." + ) # Wait for the WebSocket to be connected before resolving. 
# Otherwise, messages sent by the client before the @@ -531,12 +543,25 @@ def check_xsrf_cookie(self): """ def select_subprotocol(self, subprotocols): - """Select a single Sec-WebSocket-Protocol during handshake.""" - self.subprotocols = subprotocols - if isinstance(subprotocols, list) and subprotocols: - self.log.debug(f"Client sent subprotocols: {subprotocols}") + """ + Select a single Sec-WebSocket-Protocol during handshake. + + Note that this subprotocol selection should really be delegated to the + server we proxy to, but we don't! For this to happen, we would need to + delay accepting the handshake with the client until we have successfully + handshaked with the server. This issue is tracked via + https://github.com/jupyterhub/jupyter-server-proxy/issues/459. + + Overrides `tornado.websocket.WebSocketHandler.select_subprotocol` that + includes an informative docstring: + https://github.com/tornadoweb/tornado/blob/v6.4.0/tornado/websocket.py#L337-L360. + """ + if subprotocols: + self.log.debug( + f"Client sent subprotocols: {subprotocols}, selecting the first" + ) return subprotocols[0] - return super().select_subprotocol(subprotocols) + return None class LocalProxyHandler(ProxyHandler): diff --git a/pyproject.toml b/pyproject.toml index 2a1e1ad8..f13e94f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,13 +50,14 @@ dependencies = [ "importlib_metadata >=4.8.3 ; python_version<\"3.10\"", "jupyter-server >=1.0", "simpervisor >=1.0", - "tornado >=5.0", + "tornado >=5.1", "traitlets >= 4.2.1", ] [project.optional-dependencies] test = [ "pytest", + "pytest-asyncio", "pytest-cov", "pytest-html", ] @@ -195,11 +196,16 @@ src = "pyproject.toml" [[tool.tbump.file]] src = "labextension/package.json" + +# pytest is used for running Python based tests +# +# ref: https://docs.pytest.org/en/stable/ +# [tool.pytest.ini_options] -cache_dir = "build/.cache/pytest" -testpaths = ["tests"] addopts = [ - "-vv", + "--verbose", + "--durations=10", + "--color=yes", 
"--cov=jupyter_server_proxy", "--cov-branch", "--cov-context=test", @@ -207,9 +213,16 @@ addopts = [ "--cov-report=html:build/coverage", "--no-cov-on-fail", "--html=build/pytest/index.html", - "--color=yes", ] +asyncio_mode = "auto" +testpaths = ["tests"] +cache_dir = "build/.cache/pytest" + +# pytest-cov / coverage is used to measure code coverage of tests +# +# ref: https://coverage.readthedocs.io/en/stable/config.html +# [tool.coverage.run] data_file = "build/.coverage" concurrency = [ diff --git a/tests/resources/websocket.py b/tests/resources/websocket.py index dda24d7c..fe8bf82a 100644 --- a/tests/resources/websocket.py +++ b/tests/resources/websocket.py @@ -54,26 +54,48 @@ def get(self): class EchoWebSocket(tornado.websocket.WebSocketHandler): + """Echoes back received messages.""" + def on_message(self, message): self.write_message(message) class HeadersWebSocket(tornado.websocket.WebSocketHandler): + """Echoes back incoming request headers.""" + def on_message(self, message): self.write_message(json.dumps(dict(self.request.headers))) class SubprotocolWebSocket(tornado.websocket.WebSocketHandler): + """ + Echoes back requested subprotocols and selected subprotocol as a JSON + encoded message, and selects subprotocols in a very particular way to help + us test things. 
+ """ + def __init__(self, *args, **kwargs): - self._subprotocols = None + self._requested_subprotocols = None super().__init__(*args, **kwargs) def select_subprotocol(self, subprotocols): - self._subprotocols = subprotocols - return None + self._requested_subprotocols = subprotocols if subprotocols else None + + if not subprotocols: + return None + if "please_select_no_protocol" in subprotocols: + return None + if "favored" in subprotocols: + return "favored" + else: + return subprotocols[0] def on_message(self, message): - self.write_message(json.dumps(self._subprotocols)) + response = { + "requested_subprotocols": self._requested_subprotocols, + "selected_subprotocol": self.selected_subprotocol, + } + self.write_message(json.dumps(response)) def main(): diff --git a/tests/test_proxies.py b/tests/test_proxies.py index 5605b4d1..7e16d849 100644 --- a/tests/test_proxies.py +++ b/tests/test_proxies.py @@ -1,4 +1,3 @@ -import asyncio import gzip import json import sys @@ -332,14 +331,9 @@ def test_server_content_encoding_header( assert f.read() == b"this is a test" -@pytest.fixture(scope="module") -def event_loop(): - loop = asyncio.get_event_loop() - yield loop - loop.close() - - -async def _websocket_echo(a_server_port_and_token: Tuple[int, str]) -> None: +async def test_server_proxy_websocket_messages( + a_server_port_and_token: Tuple[int, str] +) -> None: PORT = a_server_port_and_token[0] url = f"ws://{LOCALHOST}:{PORT}/python-websocket/echosocket" conn = await websocket_connect(url) @@ -349,13 +343,7 @@ async def _websocket_echo(a_server_port_and_token: Tuple[int, str]) -> None: assert msg == expected_msg -def test_server_proxy_websocket( - event_loop, a_server_port_and_token: Tuple[int, str] -) -> None: - event_loop.run_until_complete(_websocket_echo(a_server_port_and_token)) - - -async def _websocket_headers(a_server_port_and_token: Tuple[int, str]) -> None: +async def test_server_proxy_websocket_headers(a_server_port_and_token: Tuple[int, str]): PORT = 
a_server_port_and_token[0] url = f"ws://{LOCALHOST}:{PORT}/python-websocket/headerssocket" conn = await websocket_connect(url) @@ -366,25 +354,68 @@ async def _websocket_headers(a_server_port_and_token: Tuple[int, str]) -> None: assert headers["X-Custom-Header"] == "pytest-23456" -def test_server_proxy_websocket_headers( - event_loop, a_server_port_and_token: Tuple[int, str] +@pytest.mark.parametrize( + "client_requested,server_received,server_responded,proxy_responded", + [ + (None, None, None, None), + (["first"], ["first"], "first", "first"), + # IMPORTANT: The tests below verify current bugged behavior, and the + # commented out tests are what we want to succeed! + # + # The proxy websocket should actually respond to the handshake + # with a subprotocol based on the server handshake + # response, but we are finalizing the client/proxy handshake + # before the proxy/server handshake, and that makes it + # impossible. We currently instead just pick the first + # requested protocol no matter what subprotocol the + # server picks. 
+ # + # Bug 1 - server wasn't passed all subprotocols: + (["first", "second"], ["first"], "first", "first"), + # (["first", "second"], ["first", "second"], "first", "first"), + # + # Bug 2 - server_responded doesn't match proxy_responded: + (["first", "favored"], ["first"], "first", "first"), + # (["first", "favored"], ["first", "favored"], "favored", "favored"), + ( + ["please_select_no_protocol"], + ["please_select_no_protocol"], + None, + "please_select_no_protocol", + ), + # (["please_select_no_protocol"], ["please_select_no_protocol"], None, None), + ], +) +async def test_server_proxy_websocket_subprotocols( + a_server_port_and_token: Tuple[int, str], + client_requested, + server_received, + server_responded, + proxy_responded, ): - event_loop.run_until_complete(_websocket_headers(a_server_port_and_token)) - - -async def _websocket_subprotocols(a_server_port_and_token: Tuple[int, str]) -> None: PORT, TOKEN = a_server_port_and_token url = f"ws://{LOCALHOST}:{PORT}/python-websocket/subprotocolsocket" - conn = await websocket_connect(url, subprotocols=["protocol_1", "protocol_2"]) + conn = await websocket_connect(url, subprotocols=client_requested) await conn.write_message("Hello, world!") + + # verify understanding of websocket_connect that this test relies on + if client_requested: + assert "Sec-Websocket-Protocol" in conn.request.headers + else: + assert "Sec-Websocket-Protocol" not in conn.request.headers + msg = await conn.read_message() - assert json.loads(msg) == ["protocol_1", "protocol_2"] + info = json.loads(msg) + assert info["requested_subprotocols"] == server_received + assert info["selected_subprotocol"] == server_responded + assert conn.selected_subprotocol == proxy_responded -def test_server_proxy_websocket_subprotocols( - event_loop, a_server_port_and_token: Tuple[int, str] -): - event_loop.run_until_complete(_websocket_subprotocols(a_server_port_and_token)) + # verify proxy response headers directly + if proxy_responded is None: + assert 
"Sec-Websocket-Protocol" not in conn.headers + else: + assert "Sec-Websocket-Protocol" in conn.headers @pytest.mark.parametrize( @@ -410,7 +441,9 @@ def test_bad_server_proxy_url( assert "X-ProxyContextPath" not in r.headers -def test_callable_environment_formatting(a_server_port_and_token: Tuple[int, str]) -> None: +def test_callable_environment_formatting( + a_server_port_and_token: Tuple[int, str] +) -> None: PORT, TOKEN = a_server_port_and_token r = request_get(PORT, "/python-http-callable-env/test", TOKEN) assert r.code == 200