From 0d367a8779f5e0a701bdb478f478edfc307eaaf1 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 18 May 2018 07:58:40 -0400 Subject: [PATCH] log when we go above maximum number of workers --- dask_kubernetes/core.py | 7 +++++-- dask_kubernetes/tests/test_core.py | 19 ++++++++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/dask_kubernetes/core.py b/dask_kubernetes/core.py index ce2f50089..0c6b1033c 100644 --- a/dask_kubernetes/core.py +++ b/dask_kubernetes/core.py @@ -392,8 +392,11 @@ def scale_up(self, n, pods=None, **kwargs): -------- >>> cluster.scale_up(20) # ask for twenty workers """ - if dask.config.get('kubernetes.count.max') is not None: - n = min(n, dask.config.get('kubernetes.count.max')) + maximum = dask.config.get('kubernetes.count.max') + if maximum is not None and maximum < n: + logger.info("Tried to scale beyond maximum number of workers %d > %d", + n, maximum) + n = maximum pods = pods or self._cleanup_succeeded_pods(self.pods()) to_create = n - len(pods) new_pods = [] diff --git a/dask_kubernetes/tests/test_core.py b/dask_kubernetes/tests/test_core.py index 3f6f0d18c..0d5815ccc 100644 --- a/dask_kubernetes/tests/test_core.py +++ b/dask_kubernetes/tests/test_core.py @@ -8,7 +8,7 @@ import pytest from dask_kubernetes import KubeCluster, make_pod_spec from dask.distributed import Client, wait -from distributed.utils_test import loop # noqa: F401 +from distributed.utils_test import loop, captured_logger # noqa: F401 from distributed.utils import tmpfile import kubernetes from random import random @@ -456,3 +456,20 @@ def test_escape_username(pod_spec, loop, ns): assert '!' not in cluster.name finally: os.environ['LOGNAME'] = old_logname + + +def test_maximum(cluster): + with dask.config.set(**{'kubernetes.count.max': 1}): + with captured_logger('dask-kubernetes') as logger: + cluster.scale(10) + + start = time() + while len(cluster.scheduler.workers) <= 0: + sleep(0.1) + assert time() < start + 60 + + sleep(0.5) + assert len(cluster.scheduler.workers) == 1 + + result = logger.getvalue() + assert "scale beyond maximum number of workers" in result.lower()