
WIP co-assign related root-ish tasks #4899

Closed
Changes from 12 commits
199 changes: 181 additions & 18 deletions distributed/scheduler.py
@@ -950,6 +950,9 @@ class TaskGroup:
_start: double
_stop: double
_all_durations: object
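# Bookkeeping for co-assigning root-ish tasks (see ``decide_worker``): the last
# worker this group's tasks were assigned to, how many more tasks may be sent to
# it before a new worker is picked, and the priority of the most recently
# assigned task (used to detect out-of-priority-order calls).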
_last_worker: WorkerState
_last_worker_tasks_left: int # TODO Py_ssize_t?
_last_worker_priority: tuple # TODO remove (debugging only)

def __init__(self, name: str):
self._name = name
@@ -964,6 +967,9 @@ def __init__(self, name: str):
self._start = 0.0
self._stop = 0.0
self._all_durations = defaultdict(float)
self._last_worker = None
self._last_worker_tasks_left = 0
self._last_worker_priority = ()

@property
def name(self):
@@ -1009,6 +1015,26 @@ def start(self):
def stop(self):
return self._stop

@property
def last_worker(self):
return self._last_worker

@property
def last_worker_tasks_left(self):
return self._last_worker_tasks_left

@last_worker_tasks_left.setter
def last_worker_tasks_left(self, n: int):
self._last_worker_tasks_left = n

@property
def last_worker_priority(self):
return self._last_worker_priority

@last_worker_priority.setter
def last_worker_priority(self, x: tuple):
self._last_worker_priority = x

@ccall
def add(self, o):
ts: TaskState = o
@@ -2337,14 +2363,20 @@ def decide_worker(self, ts: TaskState) -> WorkerState:
ts.state = "no-worker"
return ws

if (
ts._dependencies
or valid_workers is not None
or ts._group._last_worker is not None
):
ws = decide_worker(
ts,
self._workers_dv.values(),
valid_workers,
partial(self.worker_objective, ts),
self._total_nthreads,
)
else:
# Fastpath when there are no related tasks or restrictions
worker_pool = self._idle or self._workers
worker_pool_dv = cast(dict, worker_pool)
wp_vals = worker_pool.values()
@@ -2366,6 +2398,14 @@ def decide_worker(self, ts: TaskState) -> WorkerState:
else: # dumb but fast in large case
ws = wp_vals[self._n_tasks % n_workers]

# TODO repeated logic from `decide_worker`
print(f"nodeps / no last worker fastpah - {ts.group_key} -> {ws.name}")
ts._group._last_worker = ws
ts._group._last_worker_tasks_left = math.floor(
len(ts._group) / self._total_nthreads
)
ts._group._last_worker_priority = ts._priority
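# For example (hypothetical numbers, not from this PR): a group of 1,000 root
# tasks on a cluster with self._total_nthreads == 40 gives floor(1000 / 40) == 25,
# so roughly the next 25 tasks from this group will be sent to this same worker
# before a new one is chosen.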

if self._validate:
assert ws is None or isinstance(ws, WorkerState), (
type(ws),
@@ -7468,14 +7508,58 @@ def _reevaluate_occupancy_worker(state: SchedulerState, ws: WorkerState):
steal.put_key_in_stealable(ts)


NOT_ROOT_ISH = WorkerState()
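# Sentinel stored in ``TaskGroup._last_worker`` once a group has been classified
# as not root-ish, so ``decide_worker`` below can skip the co-assignment logic
# for that group on later calls.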


@cfunc
@exceptval(check=False)
def decide_worker(
ts: TaskState,
all_workers,
valid_workers: set,
objective,
total_nthreads: Py_ssize_t,
) -> WorkerState:
"""
r"""
Decide which worker should take task *ts*.

There are two modes: root(ish) tasks, and normal tasks.

Root(ish) tasks
~~~~~~~~~~~~~~~

Root(ish) tasks have no (or very, very few) dependencies and fan out widely:
they belong to TaskGroups that contain more tasks than there are workers.
We want neighboring root tasks to run on the same worker, since there's a
good chance those neighbors will be combined in a downstream operation:

i j
/ \ / \
e f g h
| | | |
a b c d
\ \ / /
X

In the above case, we want ``a`` and ``b`` to run on the same worker,
and ``c`` and ``d`` to run on the same worker, reducing future
data transfer. We can also ignore the location of ``X``, because
as a common dependency, it will eventually get transferred everywhere.
Review comment (Member):

<3 the ascii art

Comment/question: Do we want to explain all of this here? Historically I haven't put the logic behind heuristics in the code. This is a subjective opinion, and far from universal, but I find that heavily commented/documented logic makes it harder to understand the code at a glance. I really like that the current decide_worker implementation fits in a terminal window. I think that single-line comments are cool, but that long multi-line comments would better be written as documentation.

Thoughts? If you are not in disagreement then I would encourage us to write up a small docpage or maybe a blogpost and then link to that external resource from the code.

Reply (Collaborator, Author):

I was also planning on updating https://distributed.dask.org/en/latest/scheduling-policies.html#choosing-workers, probably with this same ascii art. So just linking to that page in the docstring seems appropriate.
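
For concreteness, the fan-out shape in the ascii art corresponds to graphs like the following (an illustrative example, not taken from the PR)::

    import dask.array as da

    # The 100 chunk tasks of `x` are root-ish: they have no dependencies, and
    # there are many more of them than workers.
    x = da.random.random((10_000, 10_000), chunks=(1_000, 1_000))
    # Neighbouring chunks are combined along axis 1, so chunks placed on the
    # same worker can be reduced without any transfer.
    y = x.sum(axis=1)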


Calculating this directly from the graph would be expensive, so instead
we use task priority as a proxy. We aim to send tasks close in priority
within a `TaskGroup` to the same worker. To do this efficiently, we rely
on the fact that `decide_worker` is generally called in priority order
for root tasks (because `Scheduler.update_graph` creates recommendations
in priority order), and track only the last worker used for a `TaskGroup`,
and how many more tasks can be assigned to it before picking a new one.

By colocating related root tasks, we ensure that the placement of their
downstream normal tasks is set up for success.

Normal tasks
~~~~~~~~~~~~

We choose the worker that has the data on which *ts* depends.

If several workers have dependencies then we choose the less-busy worker.
@@ -7488,36 +7572,115 @@ def decide_worker(
of bytes sent between workers. This is determined by calling the
*objective* function.
"""
wws: WorkerState
dts: TaskState

group: TaskGroup = ts._group
ws: WorkerState = group._last_worker

group_tasks_per_worker: float
if ws is None or (ws is not NOT_ROOT_ISH and group._last_worker_tasks_left == 0):
# Calculate the ratio of tasks in the task group to number of workers.
# We only need to do this computation when 1) seeing a task group for the first time,
# or 2) this is a root-ish task group, and we've just filled up the worker we were
# sending tasks to and need to pick a new one.
if valid_workers is not None:
total_nthreads = sum(wws._nthreads for wws in valid_workers)

group_tasks_per_worker = len(group) / total_nthreads
else:
group_tasks_per_worker = float("nan")

is_root_ish: bool
if ws is None:
# Very first task in the group - we haven't yet determined whether this is a root-ish task group
if (
group_tasks_per_worker > 1 # group is larger than cluster
and ( # is a root-like task (task group is large, but depends on very few tasks)
sum(map(len, group._dependencies)) < 5 # TODO what number
)
):
is_root_ish = True
else:
is_root_ish = False
group._last_worker = NOT_ROOT_ISH
else:
# We've seen this task group before and already made the above determination
is_root_ish = ws is not NOT_ROOT_ISH

if is_root_ish and ws is not None and group._last_worker_tasks_left > 0:
# Root-ish task and previous worker not fully assigned - reuse previous worker.
# (When the previous worker _is_ fully assigned, we fall through here to the pick-a-worker logic.)
if group._last_worker_priority < ts.priority:
group._last_worker_priority = ts.priority
group._last_worker_tasks_left -= 1
# print(f"reusing worker - {ts.group_key} -> {ws.name}")
return ws

# We're not being called in priority order---this is probably not actually a
# root-ish task; disable root task mode for its whole task group.
# print(
# f"decide_worker called out of priority order: {group._last_worker_priority} >= {ts.priority}.\n"
# f"{ts=}\n"
# f"{group.last_worker=}\n"
# f"{group.last_worker_tasks_left=}\n"
# f"{group_tasks_per_worker=}\n"
# )
group._last_worker = NOT_ROOT_ISH
group._last_worker_tasks_left = 0
is_root_ish = False

# Pick a worker to run this task
deps: set = ts._dependencies
candidates: set
assert all([dts._who_has for dts in deps])
if is_root_ish:
    # Previous worker is fully assigned (or unknown), so pick a new worker.
    # Since this is a root-like task, we should ignore the placement of its
    # dependencies while selecting workers. Every worker is going to end up
    # running this type of task eventually, and any dependencies will have to
    # be transferred to all workers, so there's no gain from only considering
    # workers where the dependencies already live. Indeed, we _must_ consider
    # all workers, otherwise we would keep picking the same "new" worker(s)
    # every time, since there are only N workers to choose from that actually
    # have the dependency (where N <= n_deps).
    candidates = valid_workers if valid_workers is not None else set(all_workers)
else:
    # Restrict placement of this task to workers that hold its dependencies
    if ts._actor:
        candidates = set(all_workers)
    else:
        candidates = {wws for dts in deps for wws in dts._who_has}
    if valid_workers is None:
        if not candidates:
            candidates = set(all_workers)
    else:
        candidates &= valid_workers
        if not candidates:
            candidates = valid_workers
            if not candidates:
                if ts._loose_restrictions:
                    ws = decide_worker(
                        ts, all_workers, None, objective, total_nthreads
                    )
                return ws

ncandidates: Py_ssize_t = len(candidates)
if ncandidates == 0:
# print(f"no candidates - {ts.group_key}")
return None
elif ncandidates == 1:
# NOTE: this is the ideal case: all the deps are already on the same worker.
# We did a good job in previous `decide_worker`s!
for ws in candidates:
break
# print(f"1 candidate - {ts.group_key} -> {ws.name}")
else:
ws = min(candidates, key=objective)
# print(f"picked worker - {ts.group_key} -> {ws.name}")

if is_root_ish:
group._last_worker = ws
group._last_worker_tasks_left = math.floor(group_tasks_per_worker)
group._last_worker_priority = ts.priority

return ws
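
As a rough illustration of the batching heuristic above, here is a simplified
standalone sketch (not the PR's implementation: the real code picks each new
worker with the scheduler's objective function rather than round-robin, and
also handles restrictions, actors, and non-root-ish groups):

    import math

    def assign_root_tasks(tasks, workers, total_nthreads):
        """Toy model: walk tasks in priority order, sending batches of
        floor(len(tasks) / total_nthreads) consecutive tasks to one worker
        before moving on to the next."""
        batch = max(1, math.floor(len(tasks) / total_nthreads))
        assignment = {}
        for i, t in enumerate(tasks):
            assignment[t] = workers[(i // batch) % len(workers)]
        return assignment

    # 8 root tasks on 2 workers with 2 threads each -> batches of 2:
    # a, b -> w1; c, d -> w2; e, f -> w1; g, h -> w2
    print(assign_root_tasks(list("abcdefgh"), ["w1", "w2"], total_nthreads=4))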

