diff --git a/jupyter_server/auth/decorator.py b/jupyter_server/auth/decorator.py index a92866b4e8..46090b2555 100644 --- a/jupyter_server/auth/decorator.py +++ b/jupyter_server/auth/decorator.py @@ -85,3 +85,58 @@ async def inner(self, *args, **kwargs): return cast(FuncT, wrapper(method)) return cast(FuncT, wrapper) + + +def allow_unauthenticated(method: FuncT) -> FuncT: + """A decorator for tornado.web.RequestHandler methods + that allows any user to make the following request. + + Selectively disables the 'authentication' layer of REST API which + is active when `ServerApp.allow_unauthenticated_access = False`. + + To be used exclusively on endpoints which may be considered public, + for example the login page handler. + + .. versionadded:: 2.13 + + Parameters + ---------- + method : bound callable + the endpoint method to remove authentication from. + """ + + @wraps(method) + def wrapper(self, *args, **kwargs): + return method(self, *args, **kwargs) + + setattr(wrapper, "__allow_unauthenticated", True) + + return cast(FuncT, wrapper) + + +def ws_authenticated(method: FuncT) -> FuncT: + """A decorator for websockets derived from `WebSocketHandler` + that authenticates user before allowing to proceed. + + Differently from tornado.web.authenticated, does not redirect + to the login page, which would be meaningless for websockets. + + .. versionadded:: 2.13 + + Parameters + ---------- + method : bound callable + the endpoint method to add authentication for. + """ + + @wraps(method) + def wrapper(self, *args, **kwargs): + user = self.current_user + if user is None: + self.log.warning("Couldn't authenticate WebSocket connection") + raise HTTPError(403) + return method(self, *args, **kwargs) + + setattr(wrapper, "__allow_unauthenticated", False) + + return cast(FuncT, wrapper) diff --git a/jupyter_server/auth/login.py b/jupyter_server/auth/login.py index 22832df341..0628681224 100644 --- a/jupyter_server/auth/login.py +++ b/jupyter_server/auth/login.py @@ -9,6 +9,7 @@ from tornado.escape import url_escape from ..base.handlers import JupyterHandler +from .decorator import allow_unauthenticated from .security import passwd_check, set_password @@ -73,6 +74,7 @@ def _redirect_safe(self, url, default=None): url = default self.redirect(url) + @allow_unauthenticated def get(self): """Get the login form.""" if self.current_user: @@ -81,6 +83,7 @@ def get(self): else: self._render() + @allow_unauthenticated def post(self): """Post a login.""" user = self.current_user = self.identity_provider.process_login_form(self) @@ -110,6 +113,7 @@ def passwd_check(self, a, b): """Check a passwd.""" return passwd_check(a, b) + @allow_unauthenticated def post(self): """Post a login form.""" typed_password = self.get_argument("password", default="") diff --git a/jupyter_server/auth/logout.py b/jupyter_server/auth/logout.py index 3db7f796ba..584404cf4c 100644 --- a/jupyter_server/auth/logout.py +++ b/jupyter_server/auth/logout.py @@ -3,11 +3,13 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. from ..base.handlers import JupyterHandler +from .decorator import allow_unauthenticated class LogoutHandler(JupyterHandler): """An auth logout handler.""" + @allow_unauthenticated def get(self): """Handle a logout.""" self.identity_provider.clear_login_cookie(self) diff --git a/jupyter_server/base/handlers.py b/jupyter_server/base/handlers.py index 10f36b7679..b6f0c90232 100644 --- a/jupyter_server/base/handlers.py +++ b/jupyter_server/base/handlers.py @@ -14,7 +14,7 @@ import warnings from http.client import responses from logging import Logger -from typing import TYPE_CHECKING, Any, Awaitable, Sequence, cast +from typing import TYPE_CHECKING, Any, Awaitable, Coroutine, Sequence, cast from urllib.parse import urlparse import prometheus_client @@ -29,7 +29,7 @@ from jupyter_server import CallContext from jupyter_server._sysinfo import get_sys_info from jupyter_server._tz import utcnow -from jupyter_server.auth.decorator import authorized +from jupyter_server.auth.decorator import allow_unauthenticated, authorized from jupyter_server.auth.identity import User from jupyter_server.i18n import combine_translations from jupyter_server.services.security import csp_report_uri @@ -589,7 +589,7 @@ def check_host(self) -> bool: ) return allow - async def prepare(self) -> Awaitable[None] | None: # type:ignore[override] + async def prepare(self, *, _redirect_to_login=True) -> Awaitable[None] | None: # type:ignore[override] """Prepare a response.""" # Set the current Jupyter Handler context variable. CallContext.set(CallContext.JUPYTER_HANDLER, self) @@ -630,6 +630,25 @@ async def prepare(self) -> Awaitable[None] | None: # type:ignore[override] self.set_cors_headers() if self.request.method not in {"GET", "HEAD", "OPTIONS"}: self.check_xsrf_cookie() + + if not self.settings.get("allow_unauthenticated_access", False): + if not self.request.method: + raise HTTPError(403) + method = getattr(self, self.request.method.lower()) + if not getattr(method, "__allow_unauthenticated", False): + if _redirect_to_login: + # reuse `web.authenticated` logic, which redirects to the login + # page on GET and HEAD and otherwise raises 403 + return web.authenticated(lambda _: super().prepare())(self) + else: + # raise 403 if user is not known without redirecting to login page + user = self.current_user + if user is None: + self.log.warning( + f"Couldn't authenticate {self.__class__.__name__} connection" + ) + raise web.HTTPError(403) + return super().prepare() # --------------------------------------------------------------- @@ -726,7 +745,7 @@ def write_error(self, status_code: int, **kwargs: Any) -> None: class APIHandler(JupyterHandler): """Base class for API handlers""" - async def prepare(self) -> None: + async def prepare(self) -> None: # type:ignore[override] """Prepare an API response.""" await super().prepare() if not self.check_origin(): @@ -794,6 +813,7 @@ def finish(self, *args: Any, **kwargs: Any) -> Future[Any]: self.set_header("Content-Type", set_content_type) return super().finish(*args, **kwargs) + @allow_unauthenticated def options(self, *args: Any, **kwargs: Any) -> None: """Get the options.""" if "Access-Control-Allow-Headers" in self.settings.get("headers", {}): @@ -837,7 +857,7 @@ def options(self, *args: Any, **kwargs: Any) -> None: class Template404(JupyterHandler): """Render our 404 template""" - async def prepare(self) -> None: + async def prepare(self) -> None: # type:ignore[override] """Prepare a 404 response.""" await super().prepare() raise web.HTTPError(404) @@ -1002,6 +1022,18 @@ def compute_etag(self) -> str | None: """Compute the etag.""" return None + # access is allowed as this class is used to serve static assets on login page + # TODO: create an allow-list of files used on login page and remove this decorator + @allow_unauthenticated + def get(self, path: str, include_body: bool = True) -> Coroutine[Any, Any, None]: + return super().get(path, include_body) + + # access is allowed as this class is used to serve static assets on login page + # TODO: create an allow-list of files used on login page and remove this decorator + @allow_unauthenticated + def head(self, path: str) -> Awaitable[None]: + return super().head(path) + @classmethod def get_absolute_path(cls, roots: Sequence[str], path: str) -> str: """locate a file to serve on our static file search path""" @@ -1036,6 +1068,7 @@ class APIVersionHandler(APIHandler): _track_activity = False + @allow_unauthenticated def get(self) -> None: """Get the server version info.""" # not authenticated, so give as few info as possible @@ -1048,6 +1081,7 @@ class TrailingSlashHandler(web.RequestHandler): This should be the first, highest priority handler. """ + @allow_unauthenticated def get(self) -> None: """Handle trailing slashes in a get.""" assert self.request.uri is not None @@ -1064,6 +1098,7 @@ def get(self) -> None: class MainHandler(JupyterHandler): """Simple handler for base_url.""" + @allow_unauthenticated def get(self) -> None: """Get the main template.""" html = self.render_template("main.html") @@ -1104,18 +1139,20 @@ async def redirect_to_files(self: Any, path: str) -> None: self.log.debug("Redirecting %s to %s", self.request.path, url) self.redirect(url) + @allow_unauthenticated async def get(self, path: str = "") -> None: return await self.redirect_to_files(self, path) class RedirectWithParams(web.RequestHandler): - """Sam as web.RedirectHandler, but preserves URL parameters""" + """Same as web.RedirectHandler, but preserves URL parameters""" def initialize(self, url: str, permanent: bool = True) -> None: """Initialize a redirect handler.""" self._url = url self._permanent = permanent + @allow_unauthenticated def get(self) -> None: """Get a redirect.""" sep = "&" if "?" in self._url else "?" @@ -1128,6 +1165,7 @@ class PrometheusMetricsHandler(JupyterHandler): Return prometheus metrics for this server """ + @allow_unauthenticated def get(self) -> None: """Get prometheus metrics.""" if self.settings["authenticate_prometheus"] and not self.logged_in: @@ -1137,6 +1175,18 @@ def get(self) -> None: self.write(prometheus_client.generate_latest(prometheus_client.REGISTRY)) +class PublicStaticFileHandler(web.StaticFileHandler): + """Same as web.StaticFileHandler, but decorated to acknowledge that auth is not required.""" + + @allow_unauthenticated + def head(self, path: str) -> Awaitable[None]: + return super().head(path) + + @allow_unauthenticated + def get(self, path: str, include_body: bool = True) -> Coroutine[Any, Any, None]: + return super().get(path, include_body) + + # ----------------------------------------------------------------------------- # URL pattern fragments for reuse # ----------------------------------------------------------------------------- @@ -1152,6 +1202,6 @@ def get(self) -> None: default_handlers = [ (r".*/", TrailingSlashHandler), (r"api", APIVersionHandler), - (r"/(robots\.txt|favicon\.ico)", web.StaticFileHandler), + (r"/(robots\.txt|favicon\.ico)", PublicStaticFileHandler), (r"/metrics", PrometheusMetricsHandler), ] diff --git a/jupyter_server/base/websocket.py b/jupyter_server/base/websocket.py index a27b7a72a7..8780d7afc4 100644 --- a/jupyter_server/base/websocket.py +++ b/jupyter_server/base/websocket.py @@ -1,11 +1,15 @@ """Base websocket classes.""" import re +import warnings from typing import Optional, no_type_check from urllib.parse import urlparse -from tornado import ioloop +from tornado import ioloop, web from tornado.iostream import IOStream +from jupyter_server.base.handlers import JupyterHandler +from jupyter_server.utils import JupyterServerAuthWarning + # ping interval for keeping websockets alive (30 seconds) WS_PING_INTERVAL = 30000 @@ -82,6 +86,40 @@ def check_origin(self, origin: Optional[str] = None) -> bool: def clear_cookie(self, *args, **kwargs): """meaningless for websockets""" + @no_type_check + def _maybe_auth(self): + """Verify authentication if required. + + Only used when the websocket class does not inherit from JupyterHandler. + """ + if not self.settings.get("allow_unauthenticated_access", False): + if not self.request.method: + raise web.HTTPError(403) + method = getattr(self, self.request.method.lower()) + if not getattr(method, "__allow_unauthenticated", False): + # rather than re-using `web.authenticated` which also redirects + # to login page on GET, just raise 403 if user is not known + user = self.current_user + if user is None: + self.log.warning("Couldn't authenticate WebSocket connection") + raise web.HTTPError(403) + + @no_type_check + def prepare(self, *args, **kwargs): + """Handle a get request.""" + if not isinstance(self, JupyterHandler): + should_authenticate = not self.settings.get("allow_unauthenticated_access", False) + if "identity_provider" in self.settings and should_authenticate: + warnings.warn( + "WebSocketMixin sub-class does not inherit from JupyterHandler" + " preventing proper authentication using custom identity provider.", + JupyterServerAuthWarning, + stacklevel=2, + ) + self._maybe_auth() + return super().prepare(*args, **kwargs) + return super().prepare(*args, **kwargs, _redirect_to_login=False) + @no_type_check def open(self, *args, **kwargs): """Open the websocket.""" diff --git a/jupyter_server/extension/application.py b/jupyter_server/extension/application.py index aeeab5a94d..b676086747 100644 --- a/jupyter_server/extension/application.py +++ b/jupyter_server/extension/application.py @@ -358,7 +358,7 @@ def _prepare_handlers(self): ) new_handlers.append(handler) - webapp.add_handlers(".*$", new_handlers) # type:ignore[arg-type] + webapp.add_handlers(".*$", new_handlers) def _prepare_templates(self): """Add templates to web app settings if extension has templates.""" diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index 5002715c90..821650d6c7 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -41,6 +41,7 @@ from tornado.httputil import url_concat from tornado.log import LogFormatter, access_log, app_log, gen_log from tornado.netutil import bind_sockets +from tornado.routing import Matcher, Rule if not sys.platform.startswith("win"): from tornado.netutil import bind_unix_socket @@ -115,6 +116,7 @@ ) from jupyter_server.services.sessions.sessionmanager import SessionManager from jupyter_server.utils import ( + JupyterServerAuthWarning, check_pid, fetch, unix_socket_in_use, @@ -257,7 +259,7 @@ def __init__( warnings.warn( "authorizer unspecified. Using permissive AllowAllAuthorizer." " Specify an authorizer to avoid this message.", - RuntimeWarning, + JupyterServerAuthWarning, stacklevel=2, ) authorizer = AllowAllAuthorizer(parent=jupyter_app, identity_provider=identity_provider) @@ -284,8 +286,49 @@ def __init__( ) handlers = self.init_handlers(default_services, settings) + undecorated_methods = [] + for matcher, handler, *_ in handlers: + undecorated_methods.extend(self._check_handler_auth(matcher, handler)) + + if undecorated_methods: + message = ( + "Core endpoints without @allow_unauthenticated, @ws_authenticated, nor @web.authenticated:\n" + + "\n".join(undecorated_methods) + ) + if jupyter_app.allow_unauthenticated_access: + warnings.warn( + message, + JupyterServerAuthWarning, + stacklevel=2, + ) + else: + raise Exception(message) + super().__init__(handlers, **settings) + def add_handlers(self, host_pattern, host_handlers): + undecorated_methods = [] + for rule in host_handlers: + if isinstance(rule, Rule): + matcher = rule.matcher + handler = rule.target + else: + matcher, handler, *_ = rule + undecorated_methods.extend(self._check_handler_auth(matcher, handler)) + + if undecorated_methods and not self.settings["allow_unauthenticated_access"]: + message = ( + "Extension endpoints without @allow_unauthenticated, @ws_authenticated, nor @web.authenticated:\n" + + "\n".join(undecorated_methods) + ) + warnings.warn( + message, + JupyterServerAuthWarning, + stacklevel=2, + ) + + return super().add_handlers(host_pattern, host_handlers) + def init_settings( self, jupyter_app, @@ -376,6 +419,7 @@ def init_settings( "login_url": url_path_join(base_url, "/login"), "xsrf_cookies": True, "disable_check_xsrf": jupyter_app.disable_check_xsrf, + "allow_unauthenticated_access": jupyter_app.allow_unauthenticated_access, "allow_remote_access": jupyter_app.allow_remote_access, "local_hostnames": jupyter_app.local_hostnames, "authenticate_prometheus": jupyter_app.authenticate_prometheus, @@ -494,6 +538,40 @@ def last_activity(self): sources.extend(self.settings["last_activity_times"].values()) return max(sources) + def _check_handler_auth( + self, matcher: t.Union[str, Matcher], handler: type[web.RequestHandler] + ): + missing_authentication = [] + for method_name in handler.SUPPORTED_METHODS: + method = getattr(handler, method_name.lower()) + is_unimplemented = method == web.RequestHandler._unimplemented_method + is_allowlisted = hasattr(method, "__allow_unauthenticated") + is_blocklisted = _has_tornado_web_authenticated(method) + if not is_unimplemented and not is_allowlisted and not is_blocklisted: + missing_authentication.append( + f"- {method_name} of {handler.__name__} registered for {matcher}" + ) + return missing_authentication + + +def _has_tornado_web_authenticated(method: t.Callable[..., t.Any]) -> bool: + """Check if given method was decorated with @web.authenticated. + + Note: it is ok if we reject on @authorized @web.authenticated + because the correct order is @web.authenticated @authorized. + """ + if not hasattr(method, "__wrapped__"): + return False + if not hasattr(method, "__code__"): + return False + code = method.__code__ + if hasattr(code, "co_qualname"): + # new in 3.11 + return code.co_qualname.startswith("authenticated") # type:ignore[no-any-return] + elif hasattr(code, "co_filename"): + return code.co_filename.replace("\\", "/").endswith("tornado/web.py") + return False + class JupyterPasswordApp(JupyterApp): """Set a password for the Jupyter server. @@ -1222,6 +1300,33 @@ def _deprecated_password_config(self, change: t.Any) -> None: """, ) + _allow_unauthenticated_access_env = "JUPYTER_SERVER_ALLOW_UNAUTHENTICATED_ACCESS" + + allow_unauthenticated_access = Bool( + True, + config=True, + help=f"""Allow unauthenticated access to endpoints without authentication rule. + + When set to `True` (default in jupyter-server 2.0, subject to change + in the future), any request to an endpoint without an authentication rule + (either `@tornado.web.authenticated`, or `@allow_unauthenticated`) + will be permitted, regardless of whether user has logged in or not. + + When set to `False`, logging in will be required for access to each endpoint, + excluding the endpoints marked with `@allow_unauthenticated` decorator. + + This option can be configured using `{_allow_unauthenticated_access_env}` + environment variable: any non-empty value other than "true" and "yes" will + prevent unauthenticated access to endpoints without `@allow_unauthenticated`. + """, + ) + + @default("allow_unauthenticated_access") + def _allow_unauthenticated_access_default(self): + if os.getenv(self._allow_unauthenticated_access_env): + return os.environ[self._allow_unauthenticated_access_env].lower() in ["true", "yes"] + return True + allow_remote_access = Bool( config=True, help="""Allow requests where the Host header doesn't point to a local server diff --git a/jupyter_server/services/api/handlers.py b/jupyter_server/services/api/handlers.py index 8b9e44f9cf..625d9ca372 100644 --- a/jupyter_server/services/api/handlers.py +++ b/jupyter_server/services/api/handlers.py @@ -25,6 +25,11 @@ def initialize(self): """Initialize the API spec handler.""" web.StaticFileHandler.initialize(self, path=os.path.dirname(__file__)) + @web.authenticated + @authorized + def head(self): + return self.get("api.yaml", include_body=False) + @web.authenticated @authorized def get(self): diff --git a/jupyter_server/services/contents/handlers.py b/jupyter_server/services/contents/handlers.py index a7c7ffff17..8044310f60 100644 --- a/jupyter_server/services/contents/handlers.py +++ b/jupyter_server/services/contents/handlers.py @@ -16,7 +16,7 @@ from jupyter_core.utils import ensure_async from tornado import web -from jupyter_server.auth.decorator import authorized +from jupyter_server.auth.decorator import allow_unauthenticated, authorized from jupyter_server.base.handlers import APIHandler, JupyterHandler, path_regex from jupyter_server.utils import url_escape, url_path_join @@ -400,6 +400,7 @@ class NotebooksRedirectHandler(JupyterHandler): "DELETE", ) # type:ignore[assignment] + @allow_unauthenticated def get(self, path): """Handle a notebooks redirect.""" self.log.warning("/api/notebooks is deprecated, use /api/contents") diff --git a/jupyter_server/services/events/handlers.py b/jupyter_server/services/events/handlers.py index 1ca28b948c..ce580048f2 100644 --- a/jupyter_server/services/events/handlers.py +++ b/jupyter_server/services/events/handlers.py @@ -12,7 +12,7 @@ from jupyter_core.utils import ensure_async from tornado import web, websocket -from jupyter_server.auth.decorator import authorized +from jupyter_server.auth.decorator import authorized, ws_authenticated from jupyter_server.base.handlers import JupyterHandler from ...base.handlers import APIHandler @@ -29,16 +29,11 @@ class SubscribeWebsocket( auth_resource = AUTH_RESOURCE async def pre_get(self): - """Handles authentication/authorization when + """Handles authorization when attempting to subscribe to events emitted by Jupyter Server's eventbus. """ - # authenticate the request before opening the websocket user = self.current_user - if user is None: - self.log.warning("Couldn't authenticate WebSocket connection") - raise web.HTTPError(403) - # authorize the user. authorized = await ensure_async( self.authorizer.is_authorized(self, user, "execute", "events") @@ -46,6 +41,7 @@ async def pre_get(self): if not authorized: raise web.HTTPError(403) + @ws_authenticated async def get(self, *args, **kwargs): """Get an event socket.""" await ensure_async(self.pre_get()) diff --git a/jupyter_server/services/kernels/websocket.py b/jupyter_server/services/kernels/websocket.py index 4c2c1c8914..374df76f3e 100644 --- a/jupyter_server/services/kernels/websocket.py +++ b/jupyter_server/services/kernels/websocket.py @@ -6,6 +6,7 @@ from tornado import web from tornado.websocket import WebSocketHandler +from jupyter_server.auth.decorator import ws_authenticated from jupyter_server.base.handlers import JupyterHandler from jupyter_server.base.websocket import WebSocketMixin @@ -34,11 +35,7 @@ def get_compression_options(self): async def pre_get(self): """Handle a pre_get.""" - # authenticate first user = self.current_user - if user is None: - self.log.warning("Couldn't authenticate WebSocket connection") - raise web.HTTPError(403) # authorize the user. authorized = await ensure_async( @@ -61,6 +58,7 @@ async def pre_get(self): if hasattr(self.connection, "prepare"): await self.connection.prepare() + @ws_authenticated async def get(self, kernel_id): """Handle a get request for a kernel.""" self.kernel_id = kernel_id diff --git a/jupyter_server/utils.py b/jupyter_server/utils.py index 968d1fd27c..8ef0e21a7f 100644 --- a/jupyter_server/utils.py +++ b/jupyter_server/utils.py @@ -439,3 +439,11 @@ def import_item(name: str) -> Any: else: # called with un-dotted string return __import__(parts[0]) + + +class JupyterServerAuthWarning(RuntimeWarning): + """Emitted when authentication configuration issue is detected. + + Intended for filtering out expected warnings in tests, including + downstream tests, rather than for users to silence this warning. + """ diff --git a/pyproject.toml b/pyproject.toml index be3a03d49b..a89f887451 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ test = [ "ipykernel", "pytest-console-scripts", "pytest-timeout", - "pytest-jupyter[server]>=0.4", + "pytest-jupyter[server]>=0.7", "pytest>=7.0", "requests", "pre-commit", @@ -241,6 +241,7 @@ filterwarnings = [ "error", "ignore:datetime.datetime.utc:DeprecationWarning", "module:add_callback_from_signal is deprecated:DeprecationWarning", + "ignore::jupyter_server.utils.JupyterServerAuthWarning" ] [tool.coverage.report] diff --git a/tests/base/test_handlers.py b/tests/base/test_handlers.py index 370100fe9d..c2960c1775 100644 --- a/tests/base/test_handlers.py +++ b/tests/base/test_handlers.py @@ -1,12 +1,15 @@ """Test Base Handlers""" import os import warnings -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch +import pytest +from tornado.httpclient import HTTPClientError from tornado.httpserver import HTTPRequest from tornado.httputil import HTTPHeaders -from jupyter_server.auth import AllowAllAuthorizer, IdentityProvider +from jupyter_server.auth import AllowAllAuthorizer, IdentityProvider, User +from jupyter_server.auth.decorator import allow_unauthenticated from jupyter_server.base.handlers import ( APIHandler, APIVersionHandler, @@ -18,6 +21,7 @@ RedirectWithParams, ) from jupyter_server.serverapp import ServerApp +from jupyter_server.utils import url_path_join def test_authenticated_handler(jp_serverapp): @@ -61,6 +65,134 @@ def test_jupyter_handler(jp_serverapp): assert handler.check_referer() is True +class NoAuthRulesHandler(JupyterHandler): + def options(self) -> None: + self.finish({}) + + def get(self) -> None: + self.finish({}) + + +class PermissiveHandler(JupyterHandler): + @allow_unauthenticated + def options(self) -> None: + self.finish({}) + + +@pytest.mark.parametrize( + "jp_server_config", [{"ServerApp": {"allow_unauthenticated_access": True}}] +) +async def test_jupyter_handler_auth_permissive(jp_serverapp, jp_fetch): + app: ServerApp = jp_serverapp + app.web_app.add_handlers( + ".*$", + [ + (url_path_join(app.base_url, "no-rules"), NoAuthRulesHandler), + (url_path_join(app.base_url, "permissive"), PermissiveHandler), + ], + ) + + # should always permit access when `@allow_unauthenticated` is used + res = await jp_fetch("permissive", method="OPTIONS", headers={"Authorization": ""}) + assert res.code == 200 + + # should allow access when no authentication rules are set up + res = await jp_fetch("no-rules", method="OPTIONS", headers={"Authorization": ""}) + assert res.code == 200 + + +@pytest.mark.parametrize( + "jp_server_config", [{"ServerApp": {"allow_unauthenticated_access": False}}] +) +async def test_jupyter_handler_auth_required(jp_serverapp, jp_fetch): + app: ServerApp = jp_serverapp + app.web_app.add_handlers( + ".*$", + [ + (url_path_join(app.base_url, "no-rules"), NoAuthRulesHandler), + (url_path_join(app.base_url, "permissive"), PermissiveHandler), + ], + ) + + # should always permit access when `@allow_unauthenticated` is used + res = await jp_fetch("permissive", method="OPTIONS", headers={"Authorization": ""}) + assert res.code == 200 + + # should forbid access when no authentication rules are set up: + # - by redirecting to login page for GET and HEAD + res = await jp_fetch( + "no-rules", + method="GET", + headers={"Authorization": ""}, + follow_redirects=False, + raise_error=False, + ) + assert res.code == 302 + assert "/login" in res.headers["Location"] + + # - by returning 403 immediately for other requests + with pytest.raises(HTTPClientError) as exception: + res = await jp_fetch("no-rules", method="OPTIONS", headers={"Authorization": ""}) + assert exception.value.code == 403 + + +@pytest.mark.parametrize( + "jp_server_config", [{"ServerApp": {"allow_unauthenticated_access": False}}] +) +async def test_jupyter_handler_auth_calls_prepare(jp_serverapp, jp_fetch): + app: ServerApp = jp_serverapp + app.web_app.add_handlers( + ".*$", + [ + (url_path_join(app.base_url, "no-rules"), NoAuthRulesHandler), + (url_path_join(app.base_url, "permissive"), PermissiveHandler), + ], + ) + + # should call `super.prepare()` in `@allow_unauthenticated` code path + with patch.object(JupyterHandler, "prepare", return_value=None) as mock: + res = await jp_fetch("permissive", method="OPTIONS") + assert res.code == 200 + assert mock.call_count == 1 + + # should call `super.prepare()` in code path that checks authentication + with patch.object(JupyterHandler, "prepare", return_value=None) as mock: + res = await jp_fetch("no-rules", method="OPTIONS") + assert res.code == 200 + assert mock.call_count == 1 + + +class IndiscriminateIdentityProvider(IdentityProvider): + async def get_user(self, handler): + return User(username="test") + + +@pytest.mark.parametrize( + "jp_server_config", [{"ServerApp": {"allow_unauthenticated_access": False}}] +) +async def test_jupyter_handler_auth_respsects_identity_provider(jp_serverapp, jp_fetch): + app: ServerApp = jp_serverapp + app.web_app.add_handlers( + ".*$", + [(url_path_join(app.base_url, "no-rules"), NoAuthRulesHandler)], + ) + + def fetch(): + return jp_fetch("no-rules", method="OPTIONS", headers={"Authorization": ""}) + + # If no identity provider is set the following request should fail + # because the default tornado user would not be found: + with pytest.raises(HTTPClientError) as exception: + await fetch() + assert exception.value.code == 403 + + iidp = IndiscriminateIdentityProvider() + # should allow access with the user set be the identity provider + with patch.dict(jp_serverapp.web_app.settings, {"identity_provider": iidp}): + res = await fetch() + assert res.code == 200 + + def test_api_handler(jp_serverapp): app: ServerApp = jp_serverapp headers = HTTPHeaders({"Origin": "foo"}) diff --git a/tests/base/test_websocket.py b/tests/base/test_websocket.py index ee6ee3ee62..17ee009227 100644 --- a/tests/base/test_websocket.py +++ b/tests/base/test_websocket.py @@ -1,15 +1,20 @@ """Test Base Websocket classes""" import logging import time -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest +from tornado.httpclient import HTTPClientError from tornado.httpserver import HTTPRequest from tornado.httputil import HTTPHeaders from tornado.websocket import WebSocketClosedError, WebSocketHandler +from jupyter_server.auth import IdentityProvider, User +from jupyter_server.auth.decorator import allow_unauthenticated +from jupyter_server.base.handlers import JupyterHandler from jupyter_server.base.websocket import WebSocketMixin from jupyter_server.serverapp import ServerApp +from jupyter_server.utils import JupyterServerAuthWarning, url_path_join class MockHandler(WebSocketMixin, WebSocketHandler): @@ -60,3 +65,126 @@ async def test_ping_client_timeout(mixin): mixin.send_ping() with pytest.raises(WebSocketClosedError): mixin.write_message("hello") + + +class MockJupyterHandler(MockHandler, JupyterHandler): + pass + + +class NoAuthRulesWebsocketHandler(MockJupyterHandler): + pass + + +class PermissiveWebsocketHandler(MockJupyterHandler): + @allow_unauthenticated + def get(self, *args, **kwargs) -> None: + return super().get(*args, **kwargs) + + +@pytest.mark.parametrize( + "jp_server_config", [{"ServerApp": {"allow_unauthenticated_access": True}}] +) +async def test_websocket_auth_permissive(jp_serverapp, jp_ws_fetch): + app: ServerApp = jp_serverapp + app.web_app.add_handlers( + ".*$", + [ + (url_path_join(app.base_url, "no-rules"), NoAuthRulesWebsocketHandler), + (url_path_join(app.base_url, "permissive"), PermissiveWebsocketHandler), + ], + ) + + # should always permit access when `@allow_unauthenticated` is used + ws = await jp_ws_fetch("permissive", headers={"Authorization": ""}) + ws.close() + + # should allow access when no authentication rules are set up + ws = await jp_ws_fetch("no-rules", headers={"Authorization": ""}) + ws.close() + + +@pytest.mark.parametrize( + "jp_server_config", [{"ServerApp": {"allow_unauthenticated_access": False}}] +) +async def test_websocket_auth_required(jp_serverapp, jp_ws_fetch): + app: ServerApp = jp_serverapp + app.web_app.add_handlers( + ".*$", + [ + (url_path_join(app.base_url, "no-rules"), NoAuthRulesWebsocketHandler), + (url_path_join(app.base_url, "permissive"), PermissiveWebsocketHandler), + ], + ) + + # should always permit access when `@allow_unauthenticated` is used + ws = await jp_ws_fetch("permissive", headers={"Authorization": ""}) + ws.close() + + # should forbid access when no authentication rules are set up + with pytest.raises(HTTPClientError) as exception: + ws = await jp_ws_fetch("no-rules", headers={"Authorization": ""}) + assert exception.value.code == 403 + + +class IndiscriminateIdentityProvider(IdentityProvider): + async def get_user(self, handler): + return User(username="test") + + +@pytest.mark.parametrize( + "jp_server_config", [{"ServerApp": {"allow_unauthenticated_access": False}}] +) +async def test_websocket_auth_respsects_identity_provider(jp_serverapp, jp_ws_fetch): + app: ServerApp = jp_serverapp + app.web_app.add_handlers( + ".*$", + [(url_path_join(app.base_url, "no-rules"), NoAuthRulesWebsocketHandler)], + ) + + def fetch(): + return jp_ws_fetch("no-rules", headers={"Authorization": ""}) + + # If no identity provider is set the following request should fail + # because the default tornado user would not be found: + with pytest.raises(HTTPClientError) as exception: + await fetch() + assert exception.value.code == 403 + + iidp = IndiscriminateIdentityProvider() + # should allow access with the user set be the identity provider + with patch.dict(jp_serverapp.web_app.settings, {"identity_provider": iidp}): + ws = await fetch() + ws.close() + + +class PermissivePlainWebsocketHandler(MockHandler): + # note: inherits from MockHandler not MockJupyterHandler + @allow_unauthenticated + def get(self, *args, **kwargs) -> None: + return super().get(*args, **kwargs) + + +@pytest.mark.parametrize( + "jp_server_config", + [ + { + "ServerApp": { + "allow_unauthenticated_access": False, + "identity_provider": IndiscriminateIdentityProvider(), + } + } + ], +) +async def test_websocket_auth_warns_mixin_lacks_jupyter_handler(jp_serverapp, jp_ws_fetch): + app: ServerApp = jp_serverapp + app.web_app.add_handlers( + ".*$", + [(url_path_join(app.base_url, "permissive"), PermissivePlainWebsocketHandler)], + ) + + with pytest.warns( + JupyterServerAuthWarning, + match="WebSocketMixin sub-class does not inherit from JupyterHandler", + ): + ws = await jp_ws_fetch("permissive", headers={"Authorization": ""}) + ws.close() diff --git a/tests/extension/test_handler.py b/tests/extension/test_handler.py index 3151cf2b4d..c351c6ef38 100644 --- a/tests/extension/test_handler.py +++ b/tests/extension/test_handler.py @@ -20,6 +20,47 @@ async def test_handler_template(jp_fetch, mock_template): assert r.code == 200 +@pytest.mark.parametrize( + "jp_server_config", + [ + { + "ServerApp": { + "allow_unauthenticated_access": False, + "jpserver_extensions": {"tests.extension.mockextensions": True}, + } + } + ], +) +async def test_handler_gets_blocked(jp_fetch, jp_server_config): + # should redirect to login page if authorization token is missing + r = await jp_fetch( + "mock", + method="GET", + headers={"Authorization": ""}, + follow_redirects=False, + raise_error=False, + ) + assert r.code == 302 + assert "/login" in r.headers["Location"] + # should still work if authorization token is present + r = await jp_fetch("mock", method="GET") + assert r.code == 200 + + +def test_serverapp_warns_of_unauthenticated_handler(jp_configurable_serverapp): + # should warn about the handler missing decorator when unauthenticated access forbidden + expected_warning = "Extension endpoints without @allow_unauthenticated, @ws_authenticated, nor @web.authenticated:" + with pytest.warns(RuntimeWarning, match=expected_warning) as record: + jp_configurable_serverapp(allow_unauthenticated_access=False) + assert any( + [ + "GET of MockExtensionTemplateHandler registered for /a%40b/mock_template" + in r.message.args[0] + for r in record + ] + ) + + @pytest.mark.parametrize( "jp_server_config", [ diff --git a/tests/test_serverapp.py b/tests/test_serverapp.py index df703f550c..e5fe9ddae8 100644 --- a/tests/test_serverapp.py +++ b/tests/test_serverapp.py @@ -9,16 +9,19 @@ import pytest from jupyter_core.application import NoStart +from tornado import web from traitlets import TraitError from traitlets.config import Config from traitlets.tests.utils import check_help_all_output +from jupyter_server.auth.decorator import allow_unauthenticated, authorized from jupyter_server.auth.security import passwd_check from jupyter_server.serverapp import ( JupyterPasswordApp, JupyterServerListApp, ServerApp, ServerWebApplication, + _has_tornado_web_authenticated, list_running_servers, random_ports, ) @@ -163,6 +166,26 @@ def test_server_password(tmp_path, jp_configurable_serverapp): passwd_check(sv.identity_provider.hashed_password, password) +@pytest.mark.parametrize( + "env,expected", + [ + ["yes", True], + ["Yes", True], + ["True", True], + ["true", True], + ["TRUE", True], + ["no", False], + ["nooo", False], + ["FALSE", False], + ["false", False], + ], +) +def test_allow_unauthenticated_env_var(jp_configurable_serverapp, env, expected): + with patch.dict("os.environ", {"JUPYTER_SERVER_ALLOW_UNAUTHENTICATED_ACCESS": env}): + app = jp_configurable_serverapp() + assert app.allow_unauthenticated_access == expected + + def test_list_running_servers(jp_serverapp, jp_web_app): servers = list(list_running_servers(jp_serverapp.runtime_dir)) assert len(servers) >= 1 @@ -617,3 +640,22 @@ def test_immutable_cache_trait(): serverapp.init_configurables() serverapp.init_webapp() assert serverapp.web_app.settings["static_immutable_cache"] == ["/test/immutable"] + + +def test(): + pass + + +@pytest.mark.parametrize( + "method, expected", + [ + [test, False], + [allow_unauthenticated(test), False], + [authorized(test), False], + [web.authenticated(test), True], + [web.authenticated(authorized(test)), True], + [authorized(web.authenticated(test)), False], # wrong order! + ], +) +def test_tornado_authentication_detection(method, expected): + assert _has_tornado_web_authenticated(method) == expected