diff --git a/distributed/scheduler.py b/distributed/scheduler.py index c5f4c0fba6b..1ef637de3fd 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1457,9 +1457,14 @@ def __init__( dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap") / 2.0 ) - self.WORKER_SATURATION = dask.config.get( - "distributed.scheduler.worker-saturation" - ) + + sat = dask.config.get("distributed.scheduler.worker-saturation") + try: + self.WORKER_SATURATION = float(sat) + except ValueError: + raise ValueError( + f"Unsupported `distributed.scheduler.worker-saturation` value {sat!r}. Must be a float." + ) self.transition_counter = 0 self._idle_transition_counter = 0 self.transition_counter_max = transition_counter_max diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 1df0ceb1619..94ab7439f67 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -410,9 +410,10 @@ async def test_queued_remove_add_worker(c, s, a, b): @pytest.mark.parametrize( - "saturation, expected_task_counts", + "saturation_config, expected_task_counts", [ (2.5, (5, 2)), + ("2.5", (5, 2)), (2.0, (4, 2)), (1.0, (2, 1)), (-1.0, (1, 1)), @@ -421,16 +422,17 @@ async def test_queued_remove_add_worker(c, s, a, b): ], ) def test_saturation_factor( - saturation: int | float, expected_task_counts: tuple[int, int] + saturation_config: int | float | str, expected_task_counts: tuple[int, int] ) -> None: @gen_cluster( client=True, nthreads=[("", 2), ("", 1)], config={ - "distributed.scheduler.worker-saturation": saturation, + "distributed.scheduler.worker-saturation": saturation_config, }, ) async def _test_saturation_factor(c, s, a, b): + saturation = float(saturation_config) event = Event() fs = c.map( lambda _: event.wait(), range(10), key=[f"wait-{i}" for i in range(10)] @@ -453,6 +455,14 @@ async def _test_saturation_factor(c, s, a, b): _test_saturation_factor() +@gen_test() +async def test_bad_saturation_factor(): + with pytest.raises(ValueError, match="foo"): + with dask.config.set({"distributed.scheduler.worker-saturation": "foo"}): + async with Scheduler(dashboard_address=":0", validate=True): + pass + + @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) async def test_move_data_over_break_restrictions(client, s, a, b, c): [x] = await client.scatter([1], workers=b.address)