Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure no blank Sec-Websocket-Protocol headers and warn if websocket subprotocol edge case occur #458

Merged
merged 5 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ jobs:
pip-install-constraints: >-
jupyter-server==1.0
simpervisor==1.0
tornado==5.0
tornado==5.1
traitlets==4.2.1

steps:
Expand Down
39 changes: 32 additions & 7 deletions jupyter_server_proxy/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def __init__(self, *args, **kwargs):
"rewrite_response",
tuple(),
)
self.subprotocols = None
super().__init__(*args, **kwargs)

# Support/use jupyter_server config arguments allow_origin and allow_origin_pat
Expand Down Expand Up @@ -489,15 +488,28 @@ async def start_websocket_connection():
self.log.info(f"Trying to establish websocket connection to {client_uri}")
self._record_activity()
request = httpclient.HTTPRequest(url=client_uri, headers=headers)
subprotocols = (
[self.selected_subprotocol] if self.selected_subprotocol else None
)
self.ws = await pingable_ws_connect(
request=request,
on_message_callback=message_cb,
on_ping_callback=ping_cb,
subprotocols=self.subprotocols,
subprotocols=subprotocols,
resolver=resolver,
)
self._record_activity()
self.log.info(f"Websocket connection established to {client_uri}")
if (
subprotocols
and self.ws.selected_subprotocol != self.selected_subprotocol
):
self.log.warn(
f"Websocket subprotocol between proxy/server ({self.ws.selected_subprotocol}) "
f"became different than for client/proxy ({self.selected_subprotocol}) "
"due to https://github.com/jupyterhub/jupyter-server-proxy/issues/459. "
f"Requested subprotocols were {subprotocols}."
)

# Wait for the WebSocket to be connected before resolving.
# Otherwise, messages sent by the client before the
Expand Down Expand Up @@ -531,12 +543,25 @@ def check_xsrf_cookie(self):
"""

def select_subprotocol(self, subprotocols):
"""Select a single Sec-WebSocket-Protocol during handshake."""
self.subprotocols = subprotocols
if isinstance(subprotocols, list) and subprotocols:
self.log.debug(f"Client sent subprotocols: {subprotocols}")
"""
Select a single Sec-WebSocket-Protocol during handshake.

Note that this subprotocol selection should really be delegated to the
server we proxy to, but we don't! For this to happen, we would need to
delay accepting the handshake with the client until we have successfully
handshaked with the server. This issue is tracked via
https://github.com/jupyterhub/jupyter-server-proxy/issues/459.

Overrides `tornado.websocket.WebSocketHandler.select_subprotocol` that
includes an informative docstring:
https://github.com/tornadoweb/tornado/blob/v6.4.0/tornado/websocket.py#L337-L360.
"""
if subprotocols:
self.log.debug(
f"Client sent subprotocols: {subprotocols}, selecting the first"
)
return subprotocols[0]
return super().select_subprotocol(subprotocols)
return None


class LocalProxyHandler(ProxyHandler):
Expand Down
23 changes: 18 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,14 @@ dependencies = [
"importlib_metadata >=4.8.3 ; python_version<\"3.10\"",
"jupyter-server >=1.0",
"simpervisor >=1.0",
"tornado >=5.0",
"tornado >=5.1",
consideRatio marked this conversation as resolved.
Show resolved Hide resolved
"traitlets >= 4.2.1",
]

[project.optional-dependencies]
test = [
"pytest",
"pytest-asyncio",
"pytest-cov",
"pytest-html",
]
Expand Down Expand Up @@ -195,21 +196,33 @@ src = "pyproject.toml"
[[tool.tbump.file]]
src = "labextension/package.json"


# pytest is used for running Python based tests
#
# ref: https://docs.pytest.org/en/stable/
#
[tool.pytest.ini_options]
cache_dir = "build/.cache/pytest"
testpaths = ["tests"]
addopts = [
"-vv",
"--verbose",
"--durations=10",
"--color=yes",
"--cov=jupyter_server_proxy",
"--cov-branch",
"--cov-context=test",
"--cov-report=term-missing:skip-covered",
"--cov-report=html:build/coverage",
"--no-cov-on-fail",
"--html=build/pytest/index.html",
"--color=yes",
]
asyncio_mode = "auto"
testpaths = ["tests"]
cache_dir = "build/.cache/pytest"


# pytest-cov / coverage is used to measure code coverage of tests
#
# ref: https://coverage.readthedocs.io/en/stable/config.html
#
[tool.coverage.run]
data_file = "build/.coverage"
concurrency = [
Expand Down
30 changes: 26 additions & 4 deletions tests/resources/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,26 +54,48 @@ def get(self):


class EchoWebSocket(tornado.websocket.WebSocketHandler):
"""Echoes back received messages."""

def on_message(self, message):
self.write_message(message)


class HeadersWebSocket(tornado.websocket.WebSocketHandler):
"""Echoes back incoming request headers."""

def on_message(self, message):
self.write_message(json.dumps(dict(self.request.headers)))


class SubprotocolWebSocket(tornado.websocket.WebSocketHandler):
"""
Echoes back requested subprotocols and selected subprotocol as a JSON
encoded message, and selects subprotocols in a very particular way to help
us test things.
"""

def __init__(self, *args, **kwargs):
self._subprotocols = None
self._requested_subprotocols = None
super().__init__(*args, **kwargs)

def select_subprotocol(self, subprotocols):
self._subprotocols = subprotocols
return None
self._requested_subprotocols = subprotocols if subprotocols else None

if not subprotocols:
return None
if "please_select_no_protocol" in subprotocols:
return None
if "favored" in subprotocols:
return "favored"
else:
return subprotocols[0]

def on_message(self, message):
self.write_message(json.dumps(self._subprotocols))
response = {
"requested_subprotocols": self._requested_subprotocols,
"selected_subprotocol": self.selected_subprotocol,
}
self.write_message(json.dumps(response))


def main():
Expand Down
91 changes: 62 additions & 29 deletions tests/test_proxies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import gzip
import json
import sys
Expand Down Expand Up @@ -332,14 +331,9 @@ def test_server_content_encoding_header(
assert f.read() == b"this is a test"


@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()


async def _websocket_echo(a_server_port_and_token: Tuple[int, str]) -> None:
async def test_server_proxy_websocket_messages(
a_server_port_and_token: Tuple[int, str]
) -> None:
PORT = a_server_port_and_token[0]
url = f"ws://{LOCALHOST}:{PORT}/python-websocket/echosocket"
conn = await websocket_connect(url)
Expand All @@ -349,13 +343,7 @@ async def _websocket_echo(a_server_port_and_token: Tuple[int, str]) -> None:
assert msg == expected_msg


def test_server_proxy_websocket(
event_loop, a_server_port_and_token: Tuple[int, str]
) -> None:
event_loop.run_until_complete(_websocket_echo(a_server_port_and_token))


async def _websocket_headers(a_server_port_and_token: Tuple[int, str]) -> None:
async def test_server_proxy_websocket_headers(a_server_port_and_token: Tuple[int, str]):
PORT = a_server_port_and_token[0]
url = f"ws://{LOCALHOST}:{PORT}/python-websocket/headerssocket"
conn = await websocket_connect(url)
Expand All @@ -366,25 +354,68 @@ async def _websocket_headers(a_server_port_and_token: Tuple[int, str]) -> None:
assert headers["X-Custom-Header"] == "pytest-23456"


def test_server_proxy_websocket_headers(
event_loop, a_server_port_and_token: Tuple[int, str]
@pytest.mark.parametrize(
"client_requested,server_received,server_responded,proxy_responded",
[
(None, None, None, None),
(["first"], ["first"], "first", "first"),
# IMPORTANT: The tests below verify current bugged behavior, and the
# commented out tests is what we want to succeed!
#
# The proxy websocket should actually respond the handshake
# with a subprotocol based on a the server handshake
# response, but we are finalizing the client/proxy handshake
# before the proxy/server handshake, and that makes it
# impossible. We currently instead just pick the first
# requested protocol no matter what what subprotocol the
# server picks.
#
# Bug 1 - server wasn't passed all subprotocols:
(["first", "second"], ["first"], "first", "first"),
# (["first", "second"], ["first", "second"], "first", "first"),
#
# Bug 2 - server_responded doesn't match proxy_responded:
(["first", "favored"], ["first"], "first", "first"),
# (["first", "favored"], ["first", "favored"], "favored", "favored"),
(
["please_select_no_protocol"],
["please_select_no_protocol"],
None,
"please_select_no_protocol",
),
# (["please_select_no_protocol"], ["please_select_no_protocol"], None, None),
],
)
async def test_server_proxy_websocket_subprotocols(
a_server_port_and_token: Tuple[int, str],
client_requested,
server_received,
server_responded,
proxy_responded,
):
event_loop.run_until_complete(_websocket_headers(a_server_port_and_token))


async def _websocket_subprotocols(a_server_port_and_token: Tuple[int, str]) -> None:
PORT, TOKEN = a_server_port_and_token
url = f"ws://{LOCALHOST}:{PORT}/python-websocket/subprotocolsocket"
conn = await websocket_connect(url, subprotocols=["protocol_1", "protocol_2"])
conn = await websocket_connect(url, subprotocols=client_requested)
await conn.write_message("Hello, world!")

# verify understanding of websocket_connect that this test relies on
if client_requested:
assert "Sec-Websocket-Protocol" in conn.request.headers
else:
assert "Sec-Websocket-Protocol" not in conn.request.headers

msg = await conn.read_message()
assert json.loads(msg) == ["protocol_1", "protocol_2"]
info = json.loads(msg)

assert info["requested_subprotocols"] == server_received
assert info["selected_subprotocol"] == server_responded
assert conn.selected_subprotocol == proxy_responded

def test_server_proxy_websocket_subprotocols(
event_loop, a_server_port_and_token: Tuple[int, str]
):
event_loop.run_until_complete(_websocket_subprotocols(a_server_port_and_token))
# verify proxy response headers directly
if proxy_responded is None:
assert "Sec-Websocket-Protocol" not in conn.headers
else:
assert "Sec-Websocket-Protocol" in conn.headers


@pytest.mark.parametrize(
Expand All @@ -410,7 +441,9 @@ def test_bad_server_proxy_url(
assert "X-ProxyContextPath" not in r.headers


def test_callable_environment_formatting(a_server_port_and_token: Tuple[int, str]) -> None:
def test_callable_environment_formatting(
a_server_port_and_token: Tuple[int, str]
) -> None:
PORT, TOKEN = a_server_port_and_token
r = request_get(PORT, "/python-http-callable-env/test", TOKEN)
assert r.code == 200