Skip to content

Commit

Permalink
Refactor resource restriction handling in WorkerState (#6672)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait authored Jul 6, 2022
1 parent d88c1d2 commit f7f6501
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 36 deletions.
4 changes: 2 additions & 2 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1860,7 +1860,7 @@ def transition_waiting_processing(self, key, stimulus_id):
self._set_duration_estimate(ts, ws)
ts.processing_on = ws
ts.state = "processing"
self.consume_resources(ts, ws)
self.acquire_resources(ts, ws)
self.check_idle_saturated(ws)
self.n_tasks += 1

Expand Down Expand Up @@ -2675,7 +2675,7 @@ def valid_workers(self, ts: TaskState) -> set[WorkerState] | None:

return s

def consume_resources(self, ts: TaskState, ws: WorkerState):
def acquire_resources(self, ts: TaskState, ws: WorkerState):
for r, required in ts.resource_restrictions.items():
ws.used_resources[r] += required

Expand Down
10 changes: 5 additions & 5 deletions distributed/tests/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ async def test_submit_many_non_overlapping_2(c, s, a, b):
assert b.state.executing_count <= 1

await wait(futures)
assert a.total_resources == a.state.available_resources
assert b.total_resources == b.state.available_resources
assert a.state.total_resources == a.state.available_resources
assert b.state.total_resources == b.state.available_resources


@gen_cluster(
Expand Down Expand Up @@ -232,7 +232,7 @@ async def test_minimum_resource(c, s, a):
assert a.state.executing_count <= 1

await wait(futures)
assert a.total_resources == a.state.available_resources
assert a.state.total_resources == a.state.available_resources


@gen_cluster(client=True, nthreads=[("127.0.0.1", 2, {"resources": {"A": 1}})])
Expand Down Expand Up @@ -271,7 +271,7 @@ async def test_balance_resources(c, s, a, b):
@gen_cluster(client=True, nthreads=[("127.0.0.1", 2)])
async def test_set_resources(c, s, a):
await a.set_resources(A=2)
assert a.total_resources["A"] == 2
assert a.state.total_resources["A"] == 2
assert a.state.available_resources["A"] == 2
assert s.workers[a.address].resources == {"A": 2}
lock = Lock()
Expand All @@ -281,7 +281,7 @@ async def test_set_resources(c, s, a):
await asyncio.sleep(0.01)

await a.set_resources(A=3)
assert a.total_resources["A"] == 3
assert a.state.total_resources["A"] == 3
assert a.state.available_resources["A"] == 2
assert s.workers[a.address].resources == {"A": 3}

Expand Down
1 change: 1 addition & 0 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2450,6 +2450,7 @@ def ws_with_running_task(ws, request):
The task may or may not raise secede(); each test using this fixture runs twice.
"""
ws.available_resources = {"R": 1}
ws.total_resources = {"R": 1}
instructions = ws.handle_stimulus(
ComputeTaskEvent.dummy(
key="x", resource_restrictions={"R": 1}, stimulus_id="compute"
Expand Down
19 changes: 10 additions & 9 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,6 @@ class Worker(BaseWorker, ServerNode):
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
total_resources: dict[str, float]
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
Expand Down Expand Up @@ -628,7 +627,6 @@ def __init__(
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.total_resources = resources.copy()

self.death_timeout = parse_timedelta(death_timeout)

Expand Down Expand Up @@ -754,7 +752,7 @@ def __init__(
data=self.memory_manager.data,
threads=self.threads,
plugins=self.plugins,
resources=self.total_resources,
resources=resources,
total_out_connections=total_out_connections,
validate=validate,
transition_counter_max=transition_counter_max,
Expand Down Expand Up @@ -877,6 +875,7 @@ def data(self) -> MutableMapping[str, Any]:
tasks = DeprecatedWorkerStateAttribute()
target_message_size = DeprecatedWorkerStateAttribute()
total_out_connections = DeprecatedWorkerStateAttribute()
total_resources = DeprecatedWorkerStateAttribute()
transition_counter = DeprecatedWorkerStateAttribute()
transition_counter_max = DeprecatedWorkerStateAttribute()
validate = DeprecatedWorkerStateAttribute()
Expand Down Expand Up @@ -1100,7 +1099,7 @@ async def _register_with_scheduler(self) -> None:
},
types={k: typename(v) for k, v in self.data.items()},
now=time(),
resources=self.total_resources,
resources=self.state.total_resources,
memory_limit=self.memory_manager.memory_limit,
local_directory=self.local_directory,
services=self.service_ports,
Expand Down Expand Up @@ -1752,17 +1751,19 @@ def update_data(
)
return {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"}

async def set_resources(self, **resources) -> None:
async def set_resources(self, **resources: float) -> None:
for r, quantity in resources.items():
if r in self.total_resources:
self.state.available_resources[r] += quantity - self.total_resources[r]
if r in self.state.total_resources:
self.state.available_resources[r] += (
quantity - self.state.total_resources[r]
)
else:
self.state.available_resources[r] = quantity
self.total_resources[r] = quantity
self.state.total_resources[r] = quantity

await retry_operation(
self.scheduler.set_resources,
resources=self.total_resources,
resources=self.state.total_resources,
worker=self.contact_address,
)

Expand Down
47 changes: 27 additions & 20 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@
}
READY: set[TaskStateState] = {"ready", "constrained"}


NO_VALUE = "--no-value-sentinel--"


Expand Down Expand Up @@ -1027,8 +1026,12 @@ class WorkerState:
#: determining a last-in-first-out order between them.
generation: int

#: ``{resource name: amount}``. Total resources available for task execution.
#: See :doc:`resources`.
total_resources: dict[str, float]

#: ``{resource name: amount}``. Current resources that aren't being currently
#: consumed by task execution. Always less or equal to ``Worker.total_resources``.
#: consumed by task execution. Always less than or equal to :attr:`total_resources`.
#: See :doc:`resources`.
available_resources: dict[str, float]

Expand Down Expand Up @@ -1102,7 +1105,8 @@ def __init__(
self.data = data if data is not None else {}
self.threads = threads if threads is not None else {}
self.plugins = plugins if plugins is not None else {}
self.available_resources = dict(resources) if resources is not None else {}
self.total_resources = dict(resources) if resources is not None else {}
self.available_resources = self.total_resources.copy()

self.validate = validate
self.tasks = {}
Expand Down Expand Up @@ -1445,15 +1449,11 @@ def _ensure_computing(self) -> RecsInstrs:
if ts in recs:
continue

if any(
self.available_resources[resource] < needed
for resource, needed in ts.resource_restrictions.items()
):
if not self._resource_restrictions_satisfied(ts):
break

self.constrained.popleft()
for resource, needed in ts.resource_restrictions.items():
self.available_resources[resource] -= needed
self._acquire_resources(ts)

recs[ts] = "executing"
self.executing.add(ts)
Expand Down Expand Up @@ -1734,8 +1734,7 @@ def _transition_executing_rescheduled(
# Reschedule(), which is "cancelled"
assert ts.state in ("executing", "long-running"), ts

for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity
self._release_resources(ts)
self.executing.discard(ts)

return merge_recs_instructions(
Expand Down Expand Up @@ -1831,8 +1830,7 @@ def _transition_executing_error(
*,
stimulus_id: str,
) -> RecsInstrs:
for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity
self._release_resources(ts)
self.executing.discard(ts)

return merge_recs_instructions(
Expand Down Expand Up @@ -1977,9 +1975,7 @@ def _transition_cancelled_released(
self.executing.discard(ts)
self.in_flight_tasks.discard(ts)

for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity

self._release_resources(ts)
return self._transition_generic_released(ts, stimulus_id=stimulus_id)

def _transition_executing_released(
Expand All @@ -2006,10 +2002,7 @@ def _transition_generic_memory(
f"Tried to transition task {ts} to `memory` without data available"
)

if ts.resource_restrictions is not None:
for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity

self._release_resources(ts)
self.executing.discard(ts)
self.in_flight_tasks.discard(ts)
ts.coming_from = None
Expand Down Expand Up @@ -2351,6 +2344,20 @@ def _transition(
)
return recs, instructions

def _resource_restrictions_satisfied(self, ts: TaskState) -> bool:
    """Return True if every resource required by ``ts`` is currently available.

    Each ``{name: amount}`` entry of ``ts.resource_restrictions`` is compared
    against ``self.available_resources[name]``. A task with no restrictions
    is trivially satisfied.

    The lookup is a plain index, so a restriction naming a resource that is
    absent from ``available_resources`` propagates the mapping's lookup error
    (``KeyError`` for a dict) instead of returning False.
    """
    for name, amount in ts.resource_restrictions.items():
        # Stop at the first resource that cannot cover the requested amount.
        if not (self.available_resources[name] >= amount):
            return False
    return True

def _acquire_resources(self, ts: TaskState) -> None:
    """Deduct the resources required by ``ts`` from ``self.available_resources``.

    For every ``{name: amount}`` entry in ``ts.resource_restrictions`` the
    matching entry of ``self.available_resources`` is decreased by ``amount``.
    No availability check is performed here.
    """
    restrictions = ts.resource_restrictions
    for name in restrictions:
        self.available_resources[name] -= restrictions[name]

def _release_resources(self, ts: TaskState) -> None:
    """Return the resources required by ``ts`` to ``self.available_resources``.

    For every ``{name: amount}`` entry in ``ts.resource_restrictions`` the
    matching entry of ``self.available_resources`` is increased by ``amount``.
    The result is not capped or validated here.
    """
    restrictions = ts.resource_restrictions
    for name in restrictions:
        self.available_resources[name] += restrictions[name]

def _transitions(self, recommendations: Recs, *, stimulus_id: str) -> Instructions:
"""Process transitions until none are left
Expand Down

0 comments on commit f7f6501

Please sign in to comment.