diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 713f104fd..98d6e88d0 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -134,7 +134,7 @@ jobs:
         if: ${{ matrix.executors == 'dask' && matrix.python-version != '3.8' }}
         run: |
           cd tests
-          PYTEST_EXECUTORS=dask poetry run python -m pytest -sv test_all.py
+          PYTEST_EXECUTORS=dask poetry run python -m pytest -sv test_all.py test_dask.py
 
   webknossos_linux:
     needs: changes
diff --git a/cluster_tools/Changelog.md b/cluster_tools/Changelog.md
index 9ae58b858..a9019840c 100644
--- a/cluster_tools/Changelog.md
+++ b/cluster_tools/Changelog.md
@@ -16,6 +16,7 @@ For upgrade instructions, please check the respective *Breaking Changes* section
 ### Changed
 
 ### Fixed
+- Fixed working directory propagation in DaskExecutor. [#994](https://github.com/scalableminds/webknossos-libs/pull/994)
 
 
 ## [0.14.14](https://github.com/scalableminds/webknossos-libs/releases/tag/v0.14.14) - 2024-01-12
diff --git a/cluster_tools/cluster_tools/executors/dask.py b/cluster_tools/cluster_tools/executors/dask.py
index 35ea39bc4..076ff7212 100644
--- a/cluster_tools/cluster_tools/executors/dask.py
+++ b/cluster_tools/cluster_tools/executors/dask.py
@@ -41,6 +41,8 @@ def _run_in_nanny(
         for key, value in __env.items():
             os.environ[key] = value
 
+        if "PWD" in os.environ:
+            os.chdir(os.environ["PWD"])
         ret = __fn(*args, **kwargs)
         queue.put({"value": ret})
     except Exception as exc:
@@ -174,7 +176,9 @@ def submit(  # type: ignore[override]
                 ),
             )
 
-        kwargs["__env"] = os.environ.copy()
+        __env = os.environ.copy()
+        __env["PWD"] = os.getcwd()
+        kwargs["__env"] = __env
 
         # We run the functions in dask as a separate process to not hold the
         # GIL for too long, because dask workers need to be able to communicate
diff --git a/cluster_tools/tests/test_dask.py b/cluster_tools/tests/test_dask.py
new file mode 100644
index 000000000..f7393cda7
--- /dev/null
+++ b/cluster_tools/tests/test_dask.py
@@ -0,0 +1,38 @@
+import os
+from typing import TYPE_CHECKING, List, Optional
+
+if TYPE_CHECKING:
+    from distributed import LocalCluster
+
+import cluster_tools
+
+_dask_cluster: Optional["LocalCluster"] = None
+
+
+def job(_arg: None) -> str:
+    """Return the working directory of the process that executes the job."""
+    return os.getcwd()
+
+
+def test_pass_cwd() -> None:
+    """Check that a job submitted to the dask executor runs in the submitter's cwd."""
+    global _dask_cluster
+    if not _dask_cluster:
+        from distributed import LocalCluster, Worker
+
+        _dask_cluster = LocalCluster(
+            worker_class=Worker, resources={"mem": 20e9, "cpus": 4}, nthreads=6
+        )
+
+    # macOS redirects `/tmp` to `/private/tmp`
+    tmp_path = os.path.realpath("/tmp")
+    previous_cwd = os.getcwd()
+    try:
+        with cluster_tools.get_executor(
+            "dask", job_resources={"address": _dask_cluster}
+        ) as executor:
+            os.chdir(tmp_path)
+            assert list(executor.map(job, [None])) == [tmp_path]
+    finally:
+        # Restore the cwd so that later tests in the same session are not affected.
+        os.chdir(previous_cwd)