diff --git a/dvc/proc/manager.py b/dvc/proc/manager.py index 1d4391eb7d..bb32bfab10 100644 --- a/dvc/proc/manager.py +++ b/dvc/proc/manager.py @@ -5,12 +5,12 @@ import os import signal import sys -from typing import Generator, List, Optional, Union +from typing import Generator, List, Optional, Tuple, Union from funcy.flow import reraise from shortuuid import uuid -from .exceptions import UnsupportedSignalError +from .exceptions import ProcessNotTerminatedError, UnsupportedSignalError from .process import ManagedProcess, ProcessInfo logger = logging.getLogger(__name__) @@ -27,8 +27,11 @@ class ProcessManager: def __init__(self, wdir: Optional[str] = None): self.wdir = wdir or "." - def __iter__(self): - return self.processes() + def __iter__(self) -> Generator[str, None, None]: + if not os.path.exists(self.wdir): + return + for name in os.listdir(self.wdir): + yield name def __getitem__(self, key: str) -> "ProcessInfo": info_path = os.path.join(self.wdir, key, f"{key}.json") @@ -44,18 +47,23 @@ def __setitem__(self, key: str, value: "ProcessInfo"): with open(info_path, "w", encoding="utf-8") as fobj: return json.dump(value.asdict(), fobj) + def __delitem__(self, key: str) -> None: + from dvc.utils.fs import remove + + path = os.path.join(self.wdir, key) + if os.path.exists(path): + remove(path) + def get(self, key: str, default=None): try: return self[key] except KeyError: return default - def processes(self) -> Generator["ProcessInfo", None, None]: - if not os.path.exists(self.wdir): - return - for name in os.listdir(self.wdir): + def processes(self) -> Generator[Tuple[str, "ProcessInfo"], None, None]: + for name in self: try: - yield self[name] + yield name, self[name] except KeyError: continue @@ -125,8 +133,22 @@ def remove(self, name: str, force: bool = False): ProcessNotTerminatedError if the specified process is still running and was not forcefully killed. """ - raise NotImplementedError + try: + process_info = self[name] + except KeyError: + return + if process_info.returncode is None and not force: + raise ProcessNotTerminatedError(name) + try: + self.kill(name) + except ProcessLookupError: + pass + del self[name] - def cleanup(self): + def cleanup(self, force: bool = False): """Remove stale (terminated) processes from this manager.""" - raise NotImplementedError + for name in self: + try: + self.remove(name, force) + except ProcessNotTerminatedError: + continue diff --git a/tests/unit/proc/test_manager.py b/tests/unit/proc/test_manager.py index f93d05847b..bfd6f7146e 100644 --- a/tests/unit/proc/test_manager.py +++ b/tests/unit/proc/test_manager.py @@ -5,7 +5,10 @@ import pytest -from dvc.proc.exceptions import UnsupportedSignalError +from dvc.proc.exceptions import ( + ProcessNotTerminatedError, + UnsupportedSignalError, +) from dvc.proc.manager import ProcessManager from dvc.proc.process import ProcessInfo @@ -92,3 +95,24 @@ def test_terminate(tmp_dir, mocker, running_process, finished_process): m.reset_mock() process_manager.terminate(finished_process) m.assert_not_called() + + +def test_remove(mocker, tmp_dir, running_process, finished_process): + mocker.patch("os.kill", return_value=None) + process_manager = ProcessManager(tmp_dir) + process_manager.remove(finished_process) + assert not (tmp_dir / finished_process).exists() + with pytest.raises(ProcessNotTerminatedError): + process_manager.remove(running_process) + assert (tmp_dir / running_process).exists() + process_manager.remove(running_process, True) + assert not (tmp_dir / running_process).exists() + + +@pytest.mark.parametrize("force", [True, False]) +def test_cleanup(mocker, tmp_dir, running_process, finished_process, force): + mocker.patch("os.kill", return_value=None) + process_manager = ProcessManager(tmp_dir) + process_manager.cleanup(force) + assert (tmp_dir / running_process).exists() != force + assert not (tmp_dir / finished_process).exists()