Skip to content

Commit

Permalink
Yank state machine out of Worker class (#6566)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored Jun 14, 2022
1 parent 0fcc724 commit 344868a
Show file tree
Hide file tree
Showing 15 changed files with 2,886 additions and 2,417 deletions.
3 changes: 2 additions & 1 deletion distributed/active_memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,9 @@ def run(
) -> SuggestionGenerator:
"""This method is invoked by the ActiveMemoryManager every few seconds, or
whenever the user invokes ``client.amm.run_once``.
It is an iterator that must emit
:class:`~distributed.active_memory_manager.Suggestion`s:
:class:`~distributed.active_memory_manager.Suggestion` objects:
- ``Suggestion("replicate", <TaskState>)``
- ``Suggestion("replicate", <TaskState>, {subset of potential workers to replicate to})``
Expand Down
2 changes: 1 addition & 1 deletion distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def __init__(self, filepath):

async def setup(self, worker):
response = await worker.upload_file(
comm=None, filename=self.filename, data=self.data, load=True
filename=self.filename, data=self.data, load=True
)
assert len(self.data) == response["nbytes"]

Expand Down
7 changes: 4 additions & 3 deletions distributed/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,16 @@ def stop_services(self):
def service_ports(self):
return {k: v.port for k, v in self.services.items()}

def _setup_logging(self, logger):
def _setup_logging(self, *loggers):
self._deque_handler = DequeHandler(
n=dask.config.get("distributed.admin.log-length")
)
self._deque_handler.setFormatter(
logging.Formatter(dask.config.get("distributed.admin.log-format"))
)
logger.addHandler(self._deque_handler)
weakref.finalize(self, logger.removeHandler, self._deque_handler)
for logger in loggers:
logger.addHandler(self._deque_handler)
weakref.finalize(self, logger.removeHandler, self._deque_handler)

def get_logs(self, start=0, n=None, timestamps=False):
"""
Expand Down
5 changes: 4 additions & 1 deletion distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,10 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict:


class WorkerState:
"""A simple object holding information about a worker."""
"""A simple object holding information about a worker.
Not to be confused with :class:`distributed.worker_state_machine.WorkerState`.
"""

#: This worker's unique key. This can be its connected address
#: (such as ``"tcp://127.0.0.1:8891"``) or an alias (such as ``"alice"``).
Expand Down
2 changes: 1 addition & 1 deletion distributed/shuffle/shuffle_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def __init__(self, worker: Worker) -> None:
# Initialize
self.worker: Worker = worker
self.shuffles: dict[ShuffleId, Shuffle] = {}
self.executor = ThreadPoolExecutor(worker.nthreads)
self.executor = ThreadPoolExecutor(worker.state.nthreads)

# Handlers
##########
Expand Down
6 changes: 3 additions & 3 deletions distributed/tests/test_active_memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,8 +866,8 @@ async def test_RetireWorker_no_recipients(c, s, w1, w2, w3, w4):
assert set(out) in ({w1.address, w3.address}, {w1.address, w4.address})
assert not s.extensions["amm"].policies
assert set(s.workers) in ({w2.address, w3.address}, {w2.address, w4.address})
# After a Scheduler -> Worker -> WorkerState roundtrip, workers that failed to
# retire went back from closing_gracefully to running and can run tasks
# After a Scheduler -> Worker -> Scheduler roundtrip, workers that failed to retire
# went back from closing_gracefully to running and can run tasks
while any(ws.status != Status.running for ws in s.workers.values()):
await asyncio.sleep(0.01)
assert await c.submit(inc, 1) == 2
Expand Down Expand Up @@ -896,7 +896,7 @@ async def test_RetireWorker_all_recipients_are_paused(c, s, a, b):
assert not s.extensions["amm"].policies
assert set(s.workers) == {a.address, b.address}

# After a Scheduler -> Worker -> WorkerState roundtrip, workers that failed to
# After a Scheduler -> Worker -> Scheduler roundtrip, workers that failed to
# retire went back from closing_gracefully to running and can run tasks
while ws_a.status != Status.running:
await asyncio.sleep(0.01)
Expand Down
3 changes: 2 additions & 1 deletion distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1612,6 +1612,7 @@ def g():
os.remove("myfile.zip")


@pytest.mark.slow
@gen_cluster(client=True)
async def test_upload_file_egg(c, s, a, b):
pytest.importorskip("setuptools")
Expand Down Expand Up @@ -6810,7 +6811,7 @@ async def test_workers_collection_restriction(c, s, a, b):
assert a.data and not b.data


@gen_cluster(client=True, nthreads=[("127.0.0.1", 0)])
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
async def test_get_client_functions_spawn_clusters(c, s, a):
# see gh4565

Expand Down
25 changes: 11 additions & 14 deletions distributed/tests/test_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import threading
from contextlib import contextmanager
from time import sleep
from unittest import mock

import pytest
import yaml
Expand Down Expand Up @@ -44,7 +45,8 @@
from distributed.worker_state_machine import (
InvalidTaskState,
InvalidTransition,
StateMachineEvent,
PauseEvent,
WorkerState,
)


Expand Down Expand Up @@ -656,22 +658,17 @@ def test_start_failure_scheduler():


def test_invalid_transitions(capsys):
class BrokenEvent(StateMachineEvent):
pass

class MyWorker(Worker):
@Worker._handle_event.register
def _(self, ev: BrokenEvent):
ts = next(iter(self.tasks.values()))
return {ts: "foo"}, []

@gen_cluster(client=True, Worker=MyWorker, nthreads=[("", 1)])
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_log_invalid_transitions(c, s, a):
x = c.submit(inc, 1, key="task-name")
await x

with pytest.raises(InvalidTransition):
a.handle_stimulus(BrokenEvent(stimulus_id="test"))
ts = a.tasks["task-name"]
ev = PauseEvent(stimulus_id="test")
with mock.patch.object(
WorkerState, "_handle_event", return_value=({ts: "foo"}, [])
):
with pytest.raises(InvalidTransition):
a.handle_stimulus(ev)

while not s.events["invalid-worker-transition"]:
await asyncio.sleep(0.01)
Expand Down
42 changes: 23 additions & 19 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1557,7 +1557,9 @@ async def f(ev):
task for task in asyncio.all_tasks() if "execute(f1)" in task.get_name()
)
start = time()
with captured_logger("distributed.worker", level=logging.ERROR) as logger:
with captured_logger(
"distributed.worker_state_machine", level=logging.ERROR
) as logger:
await a.close(timeout=1)
assert "Failed to cancel asyncio task" in logger.getvalue()
assert time() - start < 5
Expand Down Expand Up @@ -2030,7 +2032,7 @@ async def test_gather_dep_from_remote_workers_if_all_local_workers_are_busy(
assert_story(a.story("receive-dep"), [("receive-dep", rw.address, {"f"})])


@gen_cluster(client=True, nthreads=[("127.0.0.1", 0)])
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
async def test_worker_client_uses_default_no_close(c, s, a):
"""
If a default client is available in the process, the worker will pick this
Expand All @@ -2057,7 +2059,7 @@ def get_worker_client_id():
assert c is c_def


@gen_cluster(nthreads=[("127.0.0.1", 0)])
@gen_cluster(nthreads=[("127.0.0.1", 1)])
async def test_worker_client_closes_if_created_on_worker_one_worker(s, a):
async with Client(s.address, set_as_default=False, asynchronous=True) as c:
with pytest.raises(ValueError):
Expand Down Expand Up @@ -2542,7 +2544,7 @@ def raise_exc(*args):
await asyncio.sleep(0.01)


@gen_cluster(client=True, nthreads=[("127.0.0.1", x) for x in range(4)])
@gen_cluster(client=True, nthreads=[("", x) for x in (1, 2, 3, 4)])
async def test_hold_on_to_replicas(c, s, *workers):
f1 = c.submit(inc, 1, workers=[workers[0].address], key="f1")
f2 = c.submit(inc, 2, workers=[workers[1].address], key="f2")
Expand Down Expand Up @@ -3283,38 +3285,40 @@ async def test_Worker__to_dict(c, s, a):
"type",
"id",
"scheduler",
"nthreads",
"address",
"status",
"thread_id",
"logs",
"config",
"incoming_transfer_log",
"outgoing_transfer_log",
# Attributes of WorkerMemoryManager
"data",
"max_spill",
"memory_limit",
"memory_monitor_interval",
"memory_pause_fraction",
"memory_spill_fraction",
"memory_target_fraction",
# Attributes of WorkerState
"nthreads",
"running",
"ready",
"constrained",
"executing",
"long_running",
"executing_count",
"in_flight_tasks",
"in_flight_workers",
"busy_workers",
"log",
"stimulus_log",
"transition_counter",
"tasks",
"logs",
"config",
"incoming_transfer_log",
"outgoing_transfer_log",
"data_needed",
"data_needed_per_worker",
# attributes of WorkerMemoryManager
"data",
"max_spill",
"memory_limit",
"memory_monitor_interval",
"memory_pause_fraction",
"memory_spill_fraction",
"memory_target_fraction",
}
assert d["tasks"]["x"]["key"] == "x"
assert d["data"] == ["x"]
assert d["data"] == {"x": None}


@gen_cluster(nthreads=[])
Expand Down
69 changes: 69 additions & 0 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
TaskState,
TaskStateState,
UpdateDataEvent,
WorkerState,
merge_recs_instructions,
)

Expand Down Expand Up @@ -72,6 +73,74 @@ def test_TaskState__to_dict():
]


def test_WorkerState__to_dict():
ws = WorkerState(8)
ws.address = "127.0.0.1.1234"
ws.handle_stimulus(
AcquireReplicasEvent(who_has={"x": ["127.0.0.1:1235"]}, stimulus_id="s1")
)
ws.handle_stimulus(
UpdateDataEvent(data={"y": object()}, report=False, stimulus_id="s2")
)

actual = recursive_to_dict(ws)
# Remove timestamps
for ev in actual["log"]:
del ev[-1]
for stim in actual["stimulus_log"]:
del stim["handled"]

expect = {
"address": "127.0.0.1.1234",
"busy_workers": [],
"constrained": [],
"data": {"y": None},
"data_needed": ["x"],
"data_needed_per_worker": {"127.0.0.1:1235": ["x"]},
"executing": [],
"in_flight_tasks": [],
"in_flight_workers": {},
"log": [
["x", "ensure-task-exists", "released", "s1"],
["x", "released", "fetch", "fetch", {}, "s1"],
["y", "put-in-memory", "s2"],
["y", "receive-from-scatter", "s2"],
],
"long_running": [],
"nthreads": 8,
"ready": [],
"running": True,
"stimulus_log": [
{
"cls": "AcquireReplicasEvent",
"stimulus_id": "s1",
"who_has": {"x": ["127.0.0.1:1235"]},
},
{
"cls": "UpdateDataEvent",
"data": {"y": None},
"report": False,
"stimulus_id": "s2",
},
],
"tasks": {
"x": {
"key": "x",
"priority": [1],
"state": "fetch",
"who_has": ["127.0.0.1:1235"],
},
"y": {
"key": "y",
"nbytes": 16,
"state": "memory",
},
},
"transition_counter": 1,
}
assert actual == expect


def traverse_subclasses(cls: type) -> Iterator[type]:
yield cls
for subcls in cls.__subclasses__():
Expand Down
21 changes: 12 additions & 9 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@
reset_logger_locks,
sync,
)
from distributed.worker import WORKER_ANY_RUNNING, InvalidTransition, Worker
from distributed.worker import WORKER_ANY_RUNNING, Worker
from distributed.worker_state_machine import InvalidTransition

try:
import ssl
Expand Down Expand Up @@ -1271,8 +1272,10 @@ def validate_state(*servers: Scheduler | Worker | Nanny) -> None:
Excludes workers wrapped by Nannies and workers manually started by the test.
"""
for s in servers:
if s.validate and hasattr(s, "validate_state"):
s.validate_state() # type: ignore
if isinstance(s, Scheduler) and s.validate:
s.validate_state()
elif isinstance(s, Worker) and s.state.validate:
s.validate_state()


def raises(func, exc=Exception):
Expand Down Expand Up @@ -2322,13 +2325,13 @@ def freeze_data_fetching(w: Worker, *, jump_start: bool = False):
If True, trigger ensure_communicating on exit; this simulates e.g. an unrelated
worker moving out of in_flight_workers.
"""
old_out_connections = w.total_out_connections
old_comm_threshold = w.comm_threshold_bytes
w.total_out_connections = 0
w.comm_threshold_bytes = 0
old_out_connections = w.state.total_out_connections
old_comm_threshold = w.state.comm_threshold_bytes
w.state.total_out_connections = 0
w.state.comm_threshold_bytes = 0
yield
w.total_out_connections = old_out_connections
w.comm_threshold_bytes = old_comm_threshold
w.state.total_out_connections = old_out_connections
w.state.comm_threshold_bytes = old_comm_threshold
if jump_start:
w.status = Status.paused
w.status = Status.running
Expand Down
Loading

0 comments on commit 344868a

Please sign in to comment.