diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index 942eab8ff9a..117fdbe8e59 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -239,6 +239,10 @@ properties: type: object description: Settings for Dask's embedded HTTP Server properties: + api-key: + type: string + description: | + API key required to access private HTTP API methods routes: type: array description: | diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 74d59addb35..2587ea3012e 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -49,6 +49,7 @@ distributed: lease-timeout: 30s # Maximum interval to wait for a Client refresh before a lease is invalidated and released. http: + api-key: "" routes: - distributed.http.scheduler.prometheus - distributed.http.scheduler.info @@ -56,6 +57,7 @@ distributed: - distributed.http.health - distributed.http.proxy - distributed.http.statics + - distributed.http.scheduler.api allowed-imports: - dask diff --git a/distributed/http/scheduler/api.py b/distributed/http/scheduler/api.py index 0a710a58ffd..42300e14e91 100644 --- a/distributed/http/scheduler/api.py +++ b/distributed/http/scheduler/api.py @@ -1,10 +1,44 @@ from __future__ import annotations +import asyncio +import functools import json +import dask.config + from distributed.http.utils import RequestHandler +def require_auth(method_func): + @functools.wraps(method_func) + def wrapper(self): + auth = self.request.headers.get("Authorization", None) + key = dask.config.get("distributed.scheduler.http.api-key") + if key is False or ( + key and auth and auth.startswith("Bearer ") and key == auth.split(" ")[-1] + ): + if not asyncio.iscoroutinefunction(method_func): + return method_func(self) + else: + + async def tmp(): + return await method_func(self) + + return tmp() + else: + self.set_status(403, "Unauthorized") + if not asyncio.iscoroutinefunction(method_func): + return + else: + # When wrapping a coroutine we need to return a coroutine even if it just returns None + async def tmp(): + return + + return tmp() + + return wrapper + + class APIHandler(RequestHandler): def get(self): self.write("API V1") @@ -12,6 +46,7 @@ def get(self): class RetireWorkersHandler(RequestHandler): + @require_auth async def post(self): self.set_header("Content-Type", "application/json") scheduler = self.server diff --git a/distributed/http/scheduler/tests/test_scheduler_http.py b/distributed/http/scheduler/tests/test_scheduler_http.py index e86f004b47e..a9858a57b4c 100644 --- a/distributed/http/scheduler/tests/test_scheduler_http.py +++ b/distributed/http/scheduler/tests/test_scheduler_http.py @@ -251,60 +251,107 @@ async def test_eventstream(c, s, a, b): ws_client.close() -def test_api_disabled_by_default(): - assert "distributed.http.scheduler.api" not in dask.config.get( - "distributed.scheduler.http.routes" - ) +@gen_cluster(client=True, clean_kwargs={"threads": False}) +async def test_api(c, s, a, b): + async with aiohttp.ClientSession() as session: + async with session.get( + "http://localhost:%d/api/v1" % s.http_server.port + ) as resp: + assert resp.status == 200 + assert resp.headers["Content-Type"] == "text/plain" + assert (await resp.text()) == "API V1" + + +@gen_cluster(client=True, clean_kwargs={"threads": False}) +async def test_api_auth_defaults(c, s, a, b): + async with aiohttp.ClientSession() as session: + url = f"http://localhost:{s.http_server.port}/api/v1/retire_workers" + params = {"workers": [a.address, b.address]} + + async with session.post(url, json=params) as resp: + assert resp.status == 403 @gen_cluster( client=True, clean_kwargs={"threads": False}, config={ - "distributed.scheduler.http.routes": DEFAULT_ROUTES - + ["distributed.http.scheduler.api"] + "distributed.scheduler.http.api-key": "abc123", }, ) -async def test_api(c, s, a, b): +async def test_api_auth(c, s, a, b): async with aiohttp.ClientSession() as session: - async with session.get( - "http://localhost:%d/api/v1" % s.http_server.port + url = f"http://localhost:{s.http_server.port}/api/v1/retire_workers" + params = {"workers": [a.address, b.address]} + + async with session.post(url, json=params) as resp: + assert resp.status == 403 + + async with session.post( + url, json=params, headers={"Authorization": "Bearer foobarbaz"} + ) as resp: + assert resp.status == 403 + + async with session.post( + url, json=params, headers={"Authorization": "Bearer abc"} + ) as resp: + assert resp.status == 403 + + async with session.post( + url, json=params, headers={"Authorization": "Bearer "} + ) as resp: + assert resp.status == 403 + + async with session.post( + url, json=params, headers={"Authorization": "Bearer abc123456"} + ) as resp: + assert resp.status == 403 + + async with session.post( + url, json=params, headers={"Authorization": "Bearer abc123"} ) as resp: assert resp.status == 200 - assert resp.headers["Content-Type"] == "text/plain" - assert (await resp.text()) == "API V1" @gen_cluster( client=True, clean_kwargs={"threads": False}, config={ - "distributed.scheduler.http.routes": DEFAULT_ROUTES - + ["distributed.http.scheduler.api"] + "distributed.scheduler.http.api-key": False, }, ) -async def test_retire_workers(c, s, a, b): +async def test_api_auth_disabled(c, s, a, b): async with aiohttp.ClientSession() as session: + url = f"http://localhost:{s.http_server.port}/api/v1/retire_workers" params = {"workers": [a.address, b.address]} - async with session.post( - "http://localhost:%d/api/v1/retire_workers" % s.http_server.port, - json=params, - ) as resp: + + async with session.post(url, json=params) as resp: assert resp.status == 200 - assert resp.headers["Content-Type"] == "application/json" - retired_workers_info = json.loads(await resp.text()) - assert len(retired_workers_info) == 2 @gen_cluster( client=True, clean_kwargs={"threads": False}, config={ - "distributed.scheduler.http.routes": DEFAULT_ROUTES - + ["distributed.http.scheduler.api"] + "distributed.scheduler.http.api-key": "abc123", }, ) -async def test_get_workers(c, s, a, b): +async def test_api_retire_workers(c, s, a, b): + async with aiohttp.ClientSession() as session: + url = f"http://localhost:{s.http_server.port}/api/v1/retire_workers" + params = {"workers": [a.address, b.address]} + + async with session.post( + url, json=params, headers={"Authorization": "Bearer abc123"} + ) as resp: + assert resp.status == 200 + assert resp.headers["Content-Type"] == "application/json" + retired_workers_info = json.loads(await resp.text()) + assert len(retired_workers_info) == 2 + + +@gen_cluster(client=True, clean_kwargs={"threads": False}) +async def test_api_get_workers(c, s, a, b): async with aiohttp.ClientSession() as session: async with session.get( "http://localhost:%d/api/v1/get_workers" % s.http_server.port @@ -316,15 +363,8 @@ async def test_get_workers(c, s, a, b): assert set(workers_address) == {a.address, b.address} -@gen_cluster( - client=True, - clean_kwargs={"threads": False}, - config={ - "distributed.scheduler.http.routes": DEFAULT_ROUTES - + ["distributed.http.scheduler.api"] - }, -) -async def test_adaptive_target(c, s, a, b): +@gen_cluster(client=True, clean_kwargs={"threads": False}) +async def test_api_adaptive_target(c, s, a, b): async with aiohttp.ClientSession() as session: async with session.get( "http://localhost:%d/api/v1/adaptive_target" % s.http_server.port diff --git a/docs/source/http_services.rst b/docs/source/http_services.rst index 31bb62292a7..047fcd46e60 100644 --- a/docs/source/http_services.rst +++ b/docs/source/http_services.rst @@ -54,7 +54,7 @@ Scheduler API Scheduler methods exposed by the API with an example of the request body they take -- ``/api/v1/retire_workers`` : retire certain workers on the scheduler +- ``/api/v1/retire_workers`` : retire certain workers on the scheduler (requires auth) .. code-block:: json @@ -63,7 +63,18 @@ Scheduler methods exposed by the API with an example of the request body they ta } - ``/api/v1/get_workers`` : get all workers on the scheduler -- ``/api/v1/adaptive_target`` : get the target number of workers based on the scheduler's load +- ``/api/v1/adaptive_target`` : get the target number of workers based on the scheduler's load + +.. note:: + API methods that modify the state of the scheduler require an API key to be set in the ``Authorization`` header. + This API key can be set via ``distributed.scheduler.http.api-key`` in the Dask config. + + .. code-block:: console + + $ curl -H "Authorization: Bearer {api-key}" http://localhost:8787/api/v1/retire_workers + +.. warning:: + API authentication can be disabled by setting ``distributed.scheduler.http.api-key`` to ``False`` but this is not recommended. Individual bokeh plots ----------------------