SlurmRunner proof-of-concept (#3)
* slurm: initial proof-of-concept

Use `SLURM_PROCID` to get the rank, and `scheduler_file` to get the scheduler address. The `SlurmRunner` doesn't actually need to know the scheduler address; it just needs to pass along the scheduler file.

* base: allow Runners not to have a scheduler address. This lets the SlurmRunner just pass along the scheduler_file rather than opening it.

* slurm: add test

* slurm: wait for scheduler file to exist before launching the client/workers. Also append the job ID to the filename stem.

* slurm: store the scheduler_address in the SlurmRunner for compatibility with Client

* slurm: smarter default scheduler_file name
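
For context, a minimal end-to-end sketch of the intended usage, mirroring the `slurm_core_context.py` test added below (the `scheduler_file` template and task count are illustrative):

    # example.py -- run under a Slurm allocation, e.g.: srun -n 4 python example.py
    # (rank 0 becomes the scheduler, rank 1 the client, the remaining ranks workers)
    from dask.distributed import Client
    from dask_hpc_runner import SlurmRunner

    with SlurmRunner(scheduler_file="scheduler-{}.json") as runner:  # {} -> SLURM_JOB_ID
        with Client(runner) as client:
            assert client.submit(lambda x: x + 1, 10).result() == 11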
lgarrison authored May 2, 2024
1 parent dce935f commit 0e4d3d3
Showing 6 changed files with 214 additions and 3 deletions.
1 change: 1 addition & 0 deletions dask_hpc_runner/__init__.py
@@ -1,2 +1,3 @@
from .base import AsyncCommWorld, AsyncRunner # noqa
from .mpi import MPIRunner
from .slurm import SlurmRunner
8 changes: 5 additions & 3 deletions dask_hpc_runner/base.py
@@ -156,7 +156,8 @@ async def _start(self) -> None:
sys.exit(0)
elif self.role == Role.client:
self.scheduler_address = await self.get_scheduler_address()
-            self.scheduler_comm = rpc(self.scheduler_address)
+            if self.scheduler_address:
+                self.scheduler_comm = rpc(self.scheduler_address)
await self.before_client_start()
self.status = Status.running

@@ -179,8 +180,9 @@ async def start_worker(self) -> None:
async def _close(self) -> None:
print(f"stopping {self.role}")
if self.status == Status.running:
-            with suppress(CommClosedError):
-                await self.scheduler_comm.terminate()
+            if self.scheduler_comm:
+                with suppress(CommClosedError):
+                    await self.scheduler_comm.terminate()
self.status = Status.closed

def close(self) -> None:
4 changes: 4 additions & 0 deletions dask_hpc_runner/conftest.py
@@ -24,3 +24,7 @@ def mpirun(allow_run_as_root):
return ["mpirun", "--allow-run-as-root"]
else:
return ["mpirun"]

@pytest.fixture
def srun():
return ["srun", "--mpi=none"]
175 changes: 175 additions & 0 deletions dask_hpc_runner/slurm.py
@@ -0,0 +1,175 @@
import asyncio
import atexit
import json
import os
from pathlib import Path

import dask
from dask.distributed import Scheduler
from .base import Role, BaseRunner

_RUNNER_REF = None


class WorldTooSmallException(RuntimeError):
"""Not enough Slurm tasks to start all required processes."""


class SlurmRunner(BaseRunner):
def __init__(self, *args, scheduler_file="scheduler-{}.json", **kwargs):
try:
self.rank = int(os.environ["SLURM_PROCID"])
self.world_size = self.n_workers = int(os.environ["SLURM_NTASKS"])
self.job_id = int(os.environ["SLURM_JOB_ID"])
except KeyError as e:
raise RuntimeError("SLURM_PROCID, SLURM_NTASKS, and SLURM_JOB_ID must be present "
"in the environment."
) from e
        if not scheduler_file:
            scheduler_file = kwargs.get("scheduler_options", {}).get("scheduler_file")

        if not scheduler_file:
            raise RuntimeError(
                "scheduler_file must be specified either in scheduler_options "
                "or as a keyword argument to SlurmRunner."
            )

# Encourage filename uniqueness by inserting the job ID
scheduler_file = scheduler_file.format(self.job_id)
scheduler_file = Path(scheduler_file)

if isinstance(kwargs.get("scheduler_options"), dict):
kwargs["scheduler_options"]["scheduler_file"] = scheduler_file
else:
kwargs["scheduler_options"] = {"scheduler_file": scheduler_file}
if isinstance(kwargs.get("worker_options"), dict):
kwargs["worker_options"]["scheduler_file"] = scheduler_file
else:
kwargs["worker_options"] = {"scheduler_file": scheduler_file}

self.scheduler_file = scheduler_file

super().__init__(*args, **kwargs)

async def get_role(self) -> str:
if self.scheduler and self.client and self.world_size < 3:
raise WorldTooSmallException(
f"Not enough Slurm tasks to start cluster, found {self.world_size}, "
"needs at least 3, one each for the scheduler, client and a worker."
)
elif self.scheduler and self.world_size < 2:
raise WorldTooSmallException(
f"Not enough Slurm tasks to start cluster, found {self.world_size}, "
"needs at least 2, one each for the scheduler and a worker."
)
self.n_workers -= int(self.scheduler) + int(self.client)
if self.rank == 0 and self.scheduler:
return Role.scheduler
elif self.rank == 1 and self.client:
return Role.client
else:
return Role.worker

async def set_scheduler_address(self, scheduler: Scheduler) -> None:
return

async def get_scheduler_address(self) -> str:
return

async def on_scheduler_start(self, scheduler: Scheduler) -> None:
return

async def before_worker_start(self) -> None:
while not self.scheduler_file.exists():
await asyncio.sleep(0.2)
self.load_scheduler_address()

async def before_client_start(self) -> None:
while not self.scheduler_file.exists():
await asyncio.sleep(0.2)
self.load_scheduler_address()

def load_scheduler_address(self):
with self.scheduler_file.open() as f:
cfg = json.load(f)
self.scheduler_address = cfg["address"]

async def get_worker_name(self) -> str:
return self.rank

async def _close(self):
await super()._close()


def initialize(
interface=None,
nthreads=1,
local_directory="",
memory_limit="auto",
nanny=False,
dashboard=True,
dashboard_address=":8787",
protocol=None,
worker_class="distributed.Worker",
worker_options=None,
scheduler_file="scheduler-{}.json",
):
"""
    Initialize a Dask cluster using Slurm.

    Using Slurm, the user launches 3 or more tasks, usually with the 'srun' command.
    The first task becomes the scheduler, the second task becomes the client, and the
    remaining tasks become workers. Each task identifies its task ID using the
    'SLURM_PROCID' environment variable, which is like the MPI rank. The scheduler
    address is communicated using the 'scheduler_file' keyword argument.

    Parameters
    ----------
interface : str
Network interface like 'eth0' or 'ib0'
nthreads : int
Number of threads per worker
local_directory : str
Directory to place worker files
memory_limit : int, float, or 'auto'
Number of bytes before spilling data to disk. This can be an
integer (nbytes), float (fraction of total memory), or 'auto'.
nanny : bool
Start workers in nanny process for management (deprecated, use worker_class instead)
dashboard : bool
Enable Bokeh visual diagnostics
dashboard_address : str
Bokeh port for visual diagnostics
protocol : str
Protocol like 'inproc' or 'tcp'
worker_class : str
Class to use when creating workers
worker_options : dict
Options to pass to workers
scheduler_file : str
Filename to use when saving scheduler connection information. A format placeholder
will be replaced with the job ID.
"""

    global _RUNNER_REF

    scheduler_options = {
"interface": interface,
"protocol": protocol,
"dashboard": dashboard,
"dashboard_address": dashboard_address,
"scheduler_file": scheduler_file,
}
    worker_options = {
        "interface": interface,
        "protocol": protocol,
        "nthreads": nthreads,
        "memory_limit": memory_limit,
        "local_directory": local_directory,
        **(worker_options or {}),  # don't discard caller-supplied worker_options
    }
    worker_class = "dask.distributed.Nanny" if nanny else worker_class
runner = SlurmRunner(
scheduler_options=scheduler_options,
worker_class=worker_class,
worker_options=worker_options,
)
    dask.config.set(scheduler_file=str(runner.scheduler_file))  # the job-ID-formatted name
_RUNNER_REF = runner # Keep a reference to avoid gc
atexit.register(_RUNNER_REF.close)
return runner
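
For completeness, a hedged sketch of the `initialize` entry point from the client rank's perspective (this assumes the dask-mpi-style flow, where the scheduler and worker ranks exit inside the runner and only the client rank returns; the computation is illustrative):

    # sketch only -- launched as e.g.: srun -n 4 python script.py
    from dask.distributed import Client
    from dask_hpc_runner.slurm import initialize

    initialize(nthreads=2)  # rank 0 -> scheduler, rank 1 -> client, ranks 2+ -> workers
    client = Client()       # only the client rank continues; connects via the
                            # scheduler_file stored in dask.config
    print(client.submit(lambda x: x * 2, 21).result())  # 42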
7 changes: 7 additions & 0 deletions dask_hpc_runner/tests/slurm_core_context.py
@@ -0,0 +1,7 @@
from dask.distributed import Client
from dask_hpc_runner import SlurmRunner

with SlurmRunner() as runner:
with Client(runner) as client:
assert client.submit(lambda x: x + 1, 10).result() == 11
assert client.submit(lambda x: x + 1, 20, workers=2).result() == 21
22 changes: 22 additions & 0 deletions dask_hpc_runner/tests/test_slurm.py
@@ -0,0 +1,22 @@
import os
import subprocess
import sys


def test_context(srun):
script_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "slurm_core_context.py")

p = subprocess.Popen(srun + ["-n", "4", sys.executable, script_file])

p.communicate()
assert p.returncode == 0


def test_small_world(srun):
script_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "slurm_core_context.py")

p = subprocess.Popen(srun + ["-n", "1", sys.executable, script_file], stderr=subprocess.PIPE)

_, std_err = p.communicate()
assert p.returncode != 0
assert "Not enough Slurm tasks" in std_err.decode(sys.getfilesystemencoding())
