Skip to content

Commit

Permalink
Ensure that identity provider is used for auth,
Browse files Browse the repository at this point in the history
even in websockets (but only if those inherit from JupyterHandler,
and if they do not fallback to previous implementation and warn).
  • Loading branch information
krassowski committed Feb 20, 2024
1 parent 5e7615d commit 4265f4e
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 14 deletions.
21 changes: 15 additions & 6 deletions jupyter_server/base/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def check_host(self) -> bool:
)
return allow

async def prepare(self) -> Awaitable[None] | None: # type:ignore[override]
async def prepare(self, *, _redirect_to_login=True) -> Awaitable[None] | None: # type:ignore[override]
"""Prepare a response."""
# Set the current Jupyter Handler context variable.
CallContext.set(CallContext.JUPYTER_HANDLER, self)
Expand Down Expand Up @@ -636,9 +636,18 @@ async def prepare(self) -> Awaitable[None] | None: # type:ignore[override]
raise HTTPError(403)
method = getattr(self, self.request.method.lower())
if not getattr(method, "__allow_unauthenticated", False):
# reuse `web.authenticated` logic, which redirects to the login
# page on GET and HEAD and otherwise raises 403
return web.authenticated(lambda _: super().prepare)(self)
if _redirect_to_login:
# reuse `web.authenticated` logic, which redirects to the login
# page on GET and HEAD and otherwise raises 403
return web.authenticated(lambda _: super().prepare())(self)
else:
# raise 403 if user is not known without redirecting to login page
user = self.current_user
if user is None:
self.log.warning(
f"Couldn't authenticate {self.__class__.__name__} connection"
)
raise web.HTTPError(403)

return super().prepare()

Expand Down Expand Up @@ -736,7 +745,7 @@ def write_error(self, status_code: int, **kwargs: Any) -> None:
class APIHandler(JupyterHandler):
"""Base class for API handlers"""

async def prepare(self) -> None:
async def prepare(self) -> None: # type:ignore[override]
"""Prepare an API response."""
await super().prepare()
if not self.check_origin():
Expand Down Expand Up @@ -848,7 +857,7 @@ def options(self, *args: Any, **kwargs: Any) -> None:
class Template404(JupyterHandler):
"""Render our 404 template"""

async def prepare(self) -> None:
async def prepare(self) -> None: # type:ignore[override]
"""Prepare a 404 response."""
await super().prepare()
raise web.HTTPError(404)
Expand Down
23 changes: 20 additions & 3 deletions jupyter_server/base/websocket.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""Base websocket classes."""
import re
import warnings
from typing import Optional, no_type_check
from urllib.parse import urlparse

from tornado import ioloop, web
from tornado.iostream import IOStream

from jupyter_server.base.handlers import JupyterHandler
from jupyter_server.utils import JupyterServerAuthWarning

# ping interval for keeping websockets alive (30 seconds)
WS_PING_INTERVAL = 30000

Expand Down Expand Up @@ -84,7 +88,10 @@ def clear_cookie(self, *args, **kwargs):

@no_type_check
def _maybe_auth(self):
"""Verify authentication if required"""
"""Verify authentication if required.
Only used when the websocket class does not inherit from JupyterHandler.
"""
if not self.settings.get("allow_unauthenticated_access", False):
if not self.request.method:
raise web.HTTPError(403)
Expand All @@ -100,8 +107,18 @@ def _maybe_auth(self):
@no_type_check
def prepare(self, *args, **kwargs):
"""Handle a get request."""
self._maybe_auth()
return super().prepare(*args, **kwargs)
if not isinstance(self, JupyterHandler):
should_authenticate = not self.settings.get("allow_unauthenticated_access", False)
if "identity_provider" in self.settings and should_authenticate:
warnings.warn(
"WebSocketMixin sub-class does not inherit from JupyterHandler"
" preventing proper authentication using custom identity provider.",
JupyterServerAuthWarning,
stacklevel=2,
)
self._maybe_auth()
return super().prepare(*args, **kwargs)
return super().prepare(*args, **kwargs, _redirect_to_login=False)

@no_type_check
def open(self, *args, **kwargs):
Expand Down
48 changes: 43 additions & 5 deletions tests/base/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@

from jupyter_server.auth import IdentityProvider, User
from jupyter_server.auth.decorator import allow_unauthenticated
from jupyter_server.base.handlers import JupyterHandler
from jupyter_server.base.websocket import WebSocketMixin
from jupyter_server.serverapp import ServerApp
from jupyter_server.utils import url_path_join
from jupyter_server.utils import JupyterServerAuthWarning, url_path_join


class MockHandler(WebSocketMixin, WebSocketHandler):
Expand Down Expand Up @@ -66,11 +67,15 @@ async def test_ping_client_timeout(mixin):
mixin.write_message("hello")


class NoAuthRulesWebsocketHandler(MockHandler):
class MockJupyterHandler(MockHandler, JupyterHandler):
pass


class PermissiveWebsocketHandler(MockHandler):
class NoAuthRulesWebsocketHandler(MockJupyterHandler):
pass


class PermissiveWebsocketHandler(MockJupyterHandler):
@allow_unauthenticated
def get(self, *args, **kwargs) -> None:
return super().get(*args, **kwargs)
Expand Down Expand Up @@ -148,5 +153,38 @@ def fetch():
iidp = IndiscriminateIdentityProvider()
# should allow access with the user set be the identity provider
with patch.dict(jp_serverapp.web_app.settings, {"identity_provider": iidp}):
res = await fetch()
assert res.code == 200
ws = await fetch()
ws.close()


class PermissivePlainWebsocketHandler(MockHandler):
# note: inherits from MockHandler not MockJupyterHandler
@allow_unauthenticated
def get(self, *args, **kwargs) -> None:
return super().get(*args, **kwargs)


@pytest.mark.parametrize(
"jp_server_config",
[
{
"ServerApp": {
"allow_unauthenticated_access": False,
"identity_provider": IndiscriminateIdentityProvider(),
}
}
],
)
async def test_websocket_auth_warns_mixin_lacks_jupyter_handler(jp_serverapp, jp_ws_fetch):
app: ServerApp = jp_serverapp
app.web_app.add_handlers(
".*$",
[(url_path_join(app.base_url, "permissive"), PermissivePlainWebsocketHandler)],
)

with pytest.warns(
JupyterServerAuthWarning,
match="WebSocketMixin sub-class does not inherit from JupyterHandler",
):
ws = await jp_ws_fetch("permissive", headers={"Authorization": ""})
ws.close()

0 comments on commit 4265f4e

Please sign in to comment.