diff --git a/distributed/__init__.py b/distributed/__init__.py index 9ac135d87dd..279796ff97b 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -29,8 +29,10 @@ from distributed.core import Status, connect, rpc from distributed.deploy import Adaptive, LocalCluster, SpecCluster, SSHCluster from distributed.diagnostics.plugin import ( + CondaInstall, Environ, NannyPlugin, + PackageInstall, PipInstall, SchedulerPlugin, UploadDirectory, @@ -109,6 +111,7 @@ def _(): "CancelledError", "Client", "CompatibleExecutor", + "CondaInstall", "Environ", "Event", "Future", @@ -118,6 +121,7 @@ def _(): "MultiLock", "Nanny", "NannyPlugin", + "PackageInstall", "PipInstall", "Pub", "Queue", diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 4b431fb8ff5..a8445feea19 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc import logging import os import socket @@ -8,7 +9,7 @@ import uuid import zipfile from collections.abc import Awaitable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar from dask.utils import funcname, tmpfile @@ -240,8 +241,8 @@ def _get_plugin_name(plugin: SchedulerPlugin | WorkerPlugin | NannyPlugin) -> st return funcname(type(plugin)) + "-" + str(uuid.uuid4()) -class PipInstall(WorkerPlugin): - """A Worker Plugin to pip install a set of packages +class PackageInstall(WorkerPlugin, abc.ABC): + """Abstract parent class for a worker plugin to install a set of packages This accepts a set of packages to install on all workers. You can also optionally ask for the worker to restart itself after @@ -256,29 +257,32 @@ class PipInstall(WorkerPlugin): Parameters ---------- - packages : List[str] - A list of strings to place after "pip install" command - pip_options : List[str] - Additional options to pass to pip. - restart : bool, default False - Whether or not to restart the worker after pip installing + packages + A list of packages (with optional versions) to install + restart + Whether or not to restart the worker after installing the packages Only functions if the worker has an attached nanny process - Examples + See Also -------- - >>> from dask.distributed import PipInstall - >>> plugin = PipInstall(packages=["scikit-learn"], pip_options=["--upgrade"]) - - >>> client.register_worker_plugin(plugin) + CondaInstall + PipInstall """ - name = "pip" + INSTALLER: ClassVar[str] - def __init__(self, packages, pip_options=None, restart=False): + name: str + packages: list[str] + restart: bool + + def __init__( + self, + packages: list[str], + restart: bool, + ): self.packages = packages self.restart = restart - self.pip_options = pip_options or [] - self.id = f"pip-install-{uuid.uuid4()}" + self.name = f"{self.INSTALLER}-install-{uuid.uuid4()}" async def setup(self, worker): from distributed.semaphore import Semaphore @@ -287,9 +291,13 @@ async def setup(self, worker): await Semaphore(max_leases=1, name=socket.gethostname(), register=True) ): if not await self._is_installed(worker): - logger.info("Pip installing the following packages: %s", self.packages) + logger.info( + "%s installing the following packages: %s", + self.INSTALLER, + self.packages, + ) await self._set_installed(worker) - self._install() + self.install() else: logger.info( "The following packages have already been installed: %s", @@ -301,18 +309,9 @@ async def setup(self, worker): await self._set_restarted(worker) worker.loop.add_callback(worker.close_gracefully, restart=True) - def _install(self): - proc = subprocess.Popen( - [sys.executable, "-m", "pip", "install"] + self.pip_options + self.packages, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - _, stderr = proc.communicate() - returncode = proc.wait() - if returncode != 0: - msg = f"Pip install failed with '{stderr.decode().strip()}'" - logger.error(msg) - raise RuntimeError(msg) + @abc.abstractmethod + def install(self) -> None: + """Install the requested packages""" async def _is_installed(self, worker): return await worker.client.get_metadata( @@ -327,7 +326,7 @@ async def _set_installed(self, worker): def _compose_installed_key(self): return [ - self.id, + self.name, "installed", socket.gethostname(), ] @@ -345,7 +344,148 @@ async def _set_restarted(self, worker): ) def _compose_restarted_key(self, worker): - return [self.id, "restarted", worker.nanny] + return [self.name, "restarted", worker.nanny] + + +class CondaInstall(PackageInstall): + """A Worker Plugin to conda install a set of packages + + This accepts a set of packages to install on all workers as well as + options to use when installing. + You can also optionally ask for the worker to restart itself after + performing this installation. + + .. note:: + + This will increase the time it takes to start up + each worker. If possible, we recommend including the + libraries in the worker environment or image. This is + primarily intended for experimentation and debugging. + + Parameters + ---------- + packages + A list of packages (with optional versions) to install using conda + conda_options + Additional options to pass to conda + restart + Whether or not to restart the worker after installing the packages + Only functions if the worker has an attached nanny process + + Examples + -------- + >>> from dask.distributed import CondaInstall + >>> plugin = CondaInstall(packages=["scikit-learn"], conda_options=["--update-deps"]) + + >>> client.register_worker_plugin(plugin) + + See Also + -------- + PackageInstall + PipInstall + """ + + INSTALLER = "conda" + + conda_options: list[str] + + def __init__( + self, + packages: list[str], + conda_options: list[str] | None = None, + restart: bool = False, + ): + super().__init__(packages, restart=restart) + self.conda_options = conda_options or [] + + def install(self) -> None: + try: + from conda.cli.python_api import Commands, run_command + except ModuleNotFoundError as e: # pragma: nocover + msg = ( + "conda install failed because conda could not be found. " + "Please make sure that conda is installed." + ) + logger.error(msg) + raise RuntimeError(msg) from e + try: + _, stderr, returncode = run_command( + Commands.INSTALL, self.conda_options + self.packages + ) + except Exception as e: + msg = "conda install failed" + logger.error(msg) + raise RuntimeError(msg) from e + + if returncode != 0: + msg = f"conda install failed with '{stderr.decode().strip()}'" + logger.error(msg) + raise RuntimeError(msg) + + +class PipInstall(PackageInstall): + """A Worker Plugin to pip install a set of packages + + This accepts a set of packages to install on all workers as well as + options to use when installing. + You can also optionally ask for the worker to restart itself after + performing this installation. + + .. note:: + + This will increase the time it takes to start up + each worker. If possible, we recommend including the + libraries in the worker environment or image. This is + primarily intended for experimentation and debugging. + + Parameters + ---------- + packages + A list of packages (with optional versions) to install using pip + pip_options + Additional options to pass to pip + restart + Whether or not to restart the worker after installing the packages + Only functions if the worker has an attached nanny process + + Examples + -------- + >>> from dask.distributed import PipInstall + >>> plugin = PipInstall(packages=["scikit-learn"], pip_options=["--upgrade"]) + + >>> client.register_worker_plugin(plugin) + + See Also + -------- + PackageInstall + CondaInstall + """ + + INSTALLER = "pip" + + pip_options: list[str] + + def __init__( + self, + packages: list[str], + pip_options: list[str] | None = None, + restart: bool = False, + ): + super().__init__(packages, restart=restart) + self.pip_options = pip_options or [] + + def install(self) -> None: + proc = subprocess.Popen( + [sys.executable, "-m", "pip", "install"] + self.pip_options + self.packages, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + _, stderr = proc.communicate() + returncode = proc.wait() + if returncode != 0: + msg = f"pip install failed with '{stderr.decode().strip()}'" + logger.error(msg) + raise RuntimeError(msg) # Adapted from https://github.com/dask/distributed/issues/3560#issuecomment-596138522 diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 6194c0c5dce..ec4b0584ef1 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -15,7 +15,6 @@ from concurrent.futures.process import BrokenProcessPool from numbers import Number from operator import add -from textwrap import dedent from time import sleep from unittest import mock @@ -45,7 +44,7 @@ from distributed.compatibility import LINUX, WINDOWS, to_thread from distributed.core import CommClosedError, Status, rpc from distributed.diagnostics import nvml -from distributed.diagnostics.plugin import PipInstall +from distributed.diagnostics.plugin import CondaInstall, PackageInstall, PipInstall from distributed.metrics import time from distributed.protocol import pickle from distributed.scheduler import Scheduler @@ -1645,13 +1644,37 @@ async def test_pip_install(c, s, a): await c.register_worker_plugin( PipInstall(packages=["requests"], pip_options=["--upgrade"]) ) - + assert Popen.call_count == 1 args = Popen.call_args[0][0] assert "python" in args[0] assert args[1:] == ["-m", "pip", "install", "--upgrade", "requests"] - assert Popen.call_count == 1 logs = logger.getvalue() - assert "Pip installing" in logs + assert "pip installing" in logs + assert "failed" not in logs + assert "restart" not in logs + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_conda_install(c, s, a): + with captured_logger( + "distributed.diagnostics.plugin", level=logging.INFO + ) as logger: + run_command_mock = mock.Mock(name="run_command_mock") + run_command_mock.configure_mock(return_value=(b"", b"", 0)) + module_mock = mock.Mock(name="conda_cli_python_api_mock") + module_mock.run_command = run_command_mock + module_mock.Commands.INSTALL = "INSTALL" + with mock.patch.dict("sys.modules", {"conda.cli.python_api": module_mock}): + await c.register_worker_plugin( + CondaInstall(packages=["requests"], conda_options=["--update-deps"]) + ) + assert run_command_mock.call_count == 1 + command = run_command_mock.call_args[0][0] + assert command == "INSTALL" + arguments = run_command_mock.call_args[0][1] + assert arguments == ["--update-deps", "requests"] + logs = logger.getvalue() + assert "conda installing" in logs assert "failed" not in logs assert "restart" not in logs @@ -1683,77 +1706,121 @@ async def test_pip_install_fails(c, s, a, b): assert "not-a-package" in logs -@gen_cluster(client=True, nthreads=[]) -async def test_pip_install_restarts_on_nanny(c, s): - preload = dedent( - """\ - from unittest import mock - - mock.patch( - "distributed.diagnostics.plugin.PipInstall._install", return_value=None - ).start() - """ - ) - async with Nanny(s.address, preload=preload): - (addr,) = s.workers - await c.register_worker_plugin( - PipInstall(packages=["requests"], pip_options=["--upgrade"], restart=True) - ) +@gen_cluster(client=True, nthreads=[("", 2), ("", 2)]) +async def test_conda_install_fails_when_conda_not_found(c, s, a, b): + with captured_logger( + "distributed.diagnostics.plugin", level=logging.ERROR + ) as logger: + with mock.patch.dict("sys.modules", {"conda": None}): + with pytest.raises(RuntimeError): + await c.register_worker_plugin(CondaInstall(packages=["not-a-package"])) + logs = logger.getvalue() + assert "install failed" in logs + assert "conda could not be found" in logs - # Wait until the worker is restarted - while len(s.workers) != 1 or set(s.workers) == {addr}: - await asyncio.sleep(0.01) +@gen_cluster(client=True, nthreads=[("", 2), ("", 2)]) +async def test_conda_install_fails_when_conda_raises(c, s, a, b): + with captured_logger( + "distributed.diagnostics.plugin", level=logging.ERROR + ) as logger: + run_command_mock = mock.Mock(name="run_command_mock") + run_command_mock.configure_mock(side_effect=RuntimeError) + module_mock = mock.Mock(name="conda_cli_python_api_mock") + module_mock.run_command = run_command_mock + module_mock.Commands.INSTALL = "INSTALL" + with mock.patch.dict("sys.modules", {"conda.cli.python_api": module_mock}): + with pytest.raises(RuntimeError): + await c.register_worker_plugin(CondaInstall(packages=["not-a-package"])) + assert run_command_mock.call_count == 1 + logs = logger.getvalue() + assert "install failed" in logs -@gen_cluster(client=True, nthreads=[]) -async def test_pip_install_failing_does_not_restart_on_nanny(c, s): - preload = dedent( - """\ - from unittest import mock - - mock.patch( - "distributed.diagnostics.plugin.PipInstall._install", side_effect=RuntimeError - ).start() - """ - ) - async with Nanny(s.address, preload=preload) as n: - (addr,) = s.workers - with pytest.raises(RuntimeError): - await c.register_worker_plugin( - PipInstall( - packages=["requests"], pip_options=["--upgrade"], restart=True - ) - ) - # Nanny does not restart - assert n.status is Status.running - assert set(s.workers) == {addr} + +@gen_cluster(client=True, nthreads=[("", 2), ("", 2)]) +async def test_conda_install_fails_on_returncode(c, s, a, b): + with captured_logger( + "distributed.diagnostics.plugin", level=logging.ERROR + ) as logger: + run_command_mock = mock.Mock(name="run_command_mock") + run_command_mock.configure_mock(return_value=(b"", b"", 1)) + module_mock = mock.Mock(name="conda_cli_python_api_mock") + module_mock.run_command = run_command_mock + module_mock.Commands.INSTALL = "INSTALL" + with mock.patch.dict("sys.modules", {"conda.cli.python_api": module_mock}): + with pytest.raises(RuntimeError): + await c.register_worker_plugin(CondaInstall(packages=["not-a-package"])) + assert run_command_mock.call_count == 1 + logs = logger.getvalue() + assert "install failed" in logs + + +class StubInstall(PackageInstall): + INSTALLER = "stub" + + def __init__(self, packages: list[str], restart: bool = False): + super().__init__(packages=packages, restart=restart) + + def install(self) -> None: + pass @gen_cluster(client=True, nthreads=[("", 1), ("", 1)]) -async def test_pip_install_multiple_workers(c, s, a, b): +async def test_package_install_installs_once_with_multiple_workers(c, s, a, b): with captured_logger( "distributed.diagnostics.plugin", level=logging.INFO ) as logger: - mocked = mock.Mock() - mocked.configure_mock( - **{"communicate.return_value": (b"", b""), "wait.return_value": 0} - ) - with mock.patch( - "distributed.diagnostics.plugin.subprocess.Popen", return_value=mocked - ) as Popen: + install_mock = mock.Mock(name="install") + with mock.patch.object(StubInstall, "install", install_mock): await c.register_worker_plugin( - PipInstall(packages=["requests"], pip_options=["--upgrade"]) + StubInstall( + packages=["requests"], + ) ) - - args = Popen.call_args[0][0] - assert "python" in args[0] - assert args[1:] == ["-m", "pip", "install", "--upgrade", "requests"] - assert Popen.call_count == 1 + assert install_mock.call_count == 1 logs = logger.getvalue() - assert "Pip installing" in logs assert "already been installed" in logs +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_package_install_restarts_on_nanny(c, s, a): + (addr,) = s.workers + await c.register_worker_plugin( + StubInstall( + packages=["requests"], + restart=True, + ) + ) + # Wait until the worker is restarted + while len(s.workers) != 1 or set(s.workers) == {addr}: + await asyncio.sleep(0.01) + + +class FailingInstall(PackageInstall): + INSTALLER = "fail" + + def __init__(self, packages: list[str], restart: bool = False): + super().__init__(packages=packages, restart=restart) + + def install(self) -> None: + raise RuntimeError() + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_package_install_failing_does_not_restart_on_nanny(c, s, a): + (addr,) = s.workers + with pytest.raises(RuntimeError): + await c.register_worker_plugin( + FailingInstall( + packages=["requests"], + restart=True, + ) + ) + # Nanny does not restart + assert a.status is Status.running + assert set(s.workers) == {addr} + + @gen_cluster(nthreads=[]) async def test_update_latency(s): async with Worker(s.address) as w: