diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 320edc391a..43cdbe5508 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -139,7 +139,8 @@ if TYPE_CHECKING:
     # TODO import from typing (requires Python >=3.10)
-    from typing_extensions import TypeAlias
+    # TODO import from typing (requires Python >=3.11)
+    from typing_extensions import Self, TypeAlias
 
     from dask.highlevelgraph import HighLevelGraph
@@ -428,7 +429,7 @@ class WorkerState:
     versions: dict[str, Any]
 
     #: Address of the associated :class:`~distributed.nanny.Nanny`, if present
-    nanny: str
+    nanny: str | None
 
     #: Read-only worker status, synced one way from the remote Worker object
     status: Status
@@ -522,7 +523,7 @@ def __init__(
         nthreads: int = 0,
         memory_limit: int,
         local_directory: str,
-        nanny: str,
+        nanny: str | None,
         server_id: str,
         services: dict[str, int] | None = None,
         versions: dict[str, Any] | None = None,
@@ -1943,9 +1944,9 @@ def _transition(
         if self.transition_counter_max:
             assert self.transition_counter < self.transition_counter_max
 
-        recommendations: dict = {}
-        worker_msgs: dict = {}
-        client_msgs: dict = {}
+        recommendations: Recs = {}
+        worker_msgs: Msgs = {}
+        client_msgs: Msgs = {}
 
         if self.plugins:
             dependents = set(ts.dependents)
@@ -3723,7 +3724,7 @@ async def post(self):
         self._last_client = None
         self._last_time = 0
         unrunnable = set()
-        queued: HeapSet[TaskState] = HeapSet(key=operator.attrgetter("priority"))
+        queued = HeapSet(key=operator.attrgetter("priority"))
 
         self.datasets = {}
@@ -3894,7 +3895,7 @@ async def post(self):
     # Administration #
     ##################
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return (
            f"<Scheduler {self.address!r}, "
            f"workers: {len(self.workers)}, "
@@ -3902,7 +3903,7 @@ def __repr__(self):
            f"tasks: {len(self.tasks)}>"
         )
 
-    def _repr_html_(self):
+    def _repr_html_(self) -> str:
         return get_template("scheduler.html.j2").render(
             address=self.address,
             workers=self.workers,
@@ -3910,7 +3911,7 @@ def _repr_html_(self):
             tasks=self.tasks,
         )
 
-    def identity(self):
+    def identity(self) -> dict[str, Any]:
         """Basic information about ourselves and our cluster"""
         d = {
             "type": type(self).__name__,
@@ -4029,7 +4030,7 @@ def get_worker_service_addr(
         else:
             return ws.host, port
 
-    async def start_unsafe(self):
+    async def start_unsafe(self) -> Self:
         """Clear out old state and restart all running coroutines"""
         await super().start_unsafe()
 
@@ -4080,7 +4081,7 @@ async def start_unsafe(self):
             fn = self.scheduler_file  # remove file when we close the process
 
-            def del_scheduler_file():
+            def del_scheduler_file() -> None:
                 if os.path.exists(fn):
                     os.remove(fn)
@@ -4207,7 +4208,7 @@ def heartbeat_worker(
         local_now = time()
         host_info = host_info or {}
 
-        dh: dict = self.host_info.setdefault(host, {})
+        dh = self.host_info.setdefault(host, {})
         dh["last-seen"] = local_now
 
         frac = 1 / len(self.workers)
@@ -5012,7 +5013,9 @@ def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None:
                 assert qts.state == "processing"
                 assert not self.queued or self.queued.peek() != qts
 
-    def stimulus_task_finished(self, key, worker, stimulus_id, run_id, **kwargs):
+    def stimulus_task_finished(
+        self, key: Key, worker: str, stimulus_id: str, run_id: int, **kwargs: Any
+    ) -> RecsMsgs:
         """Mark that a task has finished execution on a particular worker"""
         logger.debug("Stimulus task finished %s[%d] %s", key, run_id, worker)
 
@@ -5020,8 +5023,7 @@ def stimulus_task_finished(self, key, worker, stimulus_id, run_id, **kwargs):
         client_msgs: Msgs = {}
         worker_msgs: Msgs = {}
 
-        ws: WorkerState = self.workers[worker]
-        ts: TaskState = self.tasks.get(key)
+        ts = self.tasks.get(key)
         if ts is None or ts.state in ("released", "queued", "no-worker"):
             logger.debug(
                 "Received already computed task, worker: %s, state: %s"
@@ -5063,14 +5065,8 @@ def stimulus_task_finished(self, key, worker, stimulus_id, run_id, **kwargs):
             if ts.metadata is None:
                 ts.metadata = dict()
             ts.metadata.update(kwargs["metadata"])
-            r: tuple = self._transition(
-                key, "memory", stimulus_id, worker=worker, **kwargs
-            )
-            recommendations, client_msgs, worker_msgs = r
+            return self._transition(key, "memory", stimulus_id, worker=worker, **kwargs)
 
-            if ts.state == "memory":
-                assert ts.who_has
-                assert ws in ts.who_has
         return recommendations, client_msgs, worker_msgs
 
     def stimulus_task_erred(
@@ -5086,7 +5082,7 @@ def stimulus_task_erred(
         """Mark that a task has erred on a particular worker"""
         logger.debug("Stimulus task erred %s, %s", key, worker)
 
-        ts: TaskState = self.tasks.get(key)
+        ts = self.tasks.get(key)
         if ts is None or ts.state != "processing":
             return {}, {}, {}
 
@@ -5110,7 +5106,9 @@ def stimulus_task_erred(
             **kwargs,
         )
 
-    def stimulus_retry(self, keys, client=None):
+    def stimulus_retry(
+        self, keys: Collection[Key], client: str | None = None
+    ) -> tuple[Key, ...]:
         logger.info("Client %s requests to retry %d keys", client, len(keys))
         if client:
             self.log_event(client, {"action": "retry", "count": len(keys)})
@@ -5346,20 +5344,17 @@ def stimulus_cancel(
     ) -> None:
         """Stop execution on a list of keys"""
         logger.info("Client %s requests to cancel %d keys", client, len(keys))
-        if client:
-            self.log_event(
-                client, {"action": "cancel", "count": len(keys), "force": force}
-            )
+        self.log_event(client, {"action": "cancel", "count": len(keys), "force": force})
+
+        cs = self.clients.get(client)
+        if not cs:
+            return
+
         cancelled_keys = []
         clients = []
         for key in keys:
-            ts: TaskState | None = self.tasks.get(key)
+            ts = self.tasks.get(key)
             if not ts:
                 continue
-            try:
-                cs: ClientState = self.clients[client]
-            except KeyError:
-                return
             if force or ts.who_wants == {cs}:  # no one else wants this key
                 if ts.dependents:
@@ -5378,11 +5373,12 @@ def stimulus_cancel(
         )
         self.report({"op": "cancelled-keys", "keys": cancelled_keys})
 
-    def client_desires_keys(self, keys=None, client=None):
-        cs: ClientState = self.clients.get(client)
+    def client_desires_keys(self, keys: Collection[Key], client: str) -> None:
+        cs = self.clients.get(client)
         if cs is None:
             # For publish, queues etc.
             self.clients[client] = cs = ClientState(client)
+
         for k in keys:
             ts = self.tasks.get(k)
             if ts is None:
@@ -5396,7 +5392,9 @@ def client_desires_keys(self, keys=None, client=None):
         if ts.state in ("memory", "erred"):
             self.report_on_key(ts=ts, client=client)
 
-    def client_releases_keys(self, keys=None, client=None, stimulus_id=None):
+    def client_releases_keys(
+        self, keys: Collection[Key], client: str, stimulus_id: str | None = None
+    ) -> None:
         """Remove keys from client desired list"""
         stimulus_id = stimulus_id or f"client-releases-keys-{time()}"
         if not isinstance(keys, list):
@@ -5409,9 +5407,9 @@ def client_releases_keys(self, keys=None, client=None, stimulus_id=None):
 
         self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id)
 
-    def client_heartbeat(self, client=None):
+    def client_heartbeat(self, client: str) -> None:
         """Handle heartbeats from Client"""
-        cs: ClientState = self.clients[client]
+        cs = self.clients[client]
         cs.last_seen = time()
 
     ###################
@@ -5419,7 +5417,7 @@ def client_heartbeat(self, client=None):
     ###################
 
     def validate_released(self, key: Key) -> None:
-        ts: TaskState = self.tasks[key]
+        ts = self.tasks[key]
         assert ts.state == "released"
         assert not ts.waiters
         assert not ts.waiting_on
@@ -5430,7 +5428,7 @@ def validate_released(self, key: Key) -> None:
         assert ts not in self.queued
 
     def validate_waiting(self, key: Key) -> None:
-        ts: TaskState = self.tasks[key]
+        ts = self.tasks[key]
         assert ts.waiting_on
         assert not ts.who_has
         assert not ts.processing_on
@@ -5442,8 +5440,7 @@ def validate_waiting(self, key: Key) -> None:
             assert ts in (dts.waiters or ())  # XXX even if dts._who_has?
 
     def validate_queued(self, key: Key) -> None:
-        ts: TaskState = self.tasks[key]
-        dts: TaskState
+        ts = self.tasks[key]
         assert ts in self.queued
         assert not ts.waiting_on
         assert not ts.who_has
@@ -5456,8 +5453,7 @@ def validate_queued(self, key: Key) -> None:
             assert ts in (dts.waiters or ())
 
     def validate_processing(self, key: Key) -> None:
-        ts: TaskState = self.tasks[key]
-        dts: TaskState
+        ts = self.tasks[key]
         assert not ts.waiting_on
         ws = ts.processing_on
         assert ws
@@ -5469,8 +5465,7 @@ def validate_processing(self, key: Key) -> None:
             assert ts in (dts.waiters or ())
 
     def validate_memory(self, key: Key) -> None:
-        ts: TaskState = self.tasks[key]
-        dts: TaskState
+        ts = self.tasks[key]
         assert ts.who_has
         assert bool(ts in self.replicated_tasks) == (len(ts.who_has) > 1)
         assert not ts.processing_on
@@ -5484,7 +5479,7 @@ def validate_memory(self, key: Key) -> None:
             assert ts not in (dts.waiting_on or ())
 
     def validate_no_worker(self, key: Key) -> None:
-        ts: TaskState = self.tasks[key]
+        ts = self.tasks[key]
         assert ts in self.unrunnable
         assert not ts.waiting_on
         assert ts in self.unrunnable
@@ -5495,7 +5490,7 @@ def validate_no_worker(self, key: Key) -> None:
             assert dts.who_has
 
     def validate_erred(self, key: Key) -> None:
-        ts: TaskState = self.tasks[key]
+        ts = self.tasks[key]
         assert ts.exception_blame
         assert not ts.who_has
         assert ts not in self.queued
@@ -5628,8 +5623,7 @@ def report(
         if ts is None:
             msg_key = msg.get("key")
             if msg_key is not None:
-                tasks: dict = self.tasks
-                ts = tasks.get(msg_key)
+                ts = self.tasks.get(msg_key)
 
         if ts is None and client is None:
             # Notify all clients
@@ -5733,7 +5727,7 @@ def remove_client(self, client: str, stimulus_id: str | None = None) -> None:
         except Exception as e:
             logger.exception(e)
 
-        async def remove_client_from_events():
+        async def remove_client_from_events() -> None:
            # If the client isn't registered anymore after the delay, remove from events
             if client not in self.clients and client in self.events:
                 del self.events[client]
@@ -5751,7 +5745,7 @@ def send_task_to_worker(
     ) -> None:
         """Send a single computational task to a worker"""
         try:
-            msg: dict = self._task_to_msg(ts, duration)
+            msg = self._task_to_msg(ts, duration)
             self.worker_send(worker, msg)
         except Exception as e:
             logger.exception(e)
@@ -6026,9 +6020,8 @@ def worker_send(self, worker: str, msg: dict[str, Any]) -> None:
         This also handles connection failures by adding a callback to remove
         the worker on the next cycle.
         """
-        stream_comms: dict = self.stream_comms
         try:
-            stream_comms[worker].send(msg)
+            self.stream_comms[worker].send(msg)
         except (CommClosedError, AttributeError):
             self._ongoing_background_tasks.call_soon(
                 self.remove_worker,  # type: ignore[arg-type]
@@ -6038,8 +6031,7 @@ def worker_send(self, worker: str, msg: dict[str, Any]) -> None:
 
     def client_send(self, client, msg):
         """Send message to client"""
-        client_comms: dict = self.client_comms
-        c = client_comms.get(client)
+        c = self.client_comms.get(client)
         if c is None:
             return
         try:
@@ -6318,8 +6310,8 @@ async def broadcast(
         self,
         *,
         msg: dict,
-        workers: list[str] | None = None,
-        hosts: list[str] | None = None,
+        workers: Collection[str] | None = None,
+        hosts: Collection[str] | None = None,
         nanny: bool = False,
         serializers: Any = None,
         on_error: Literal["raise", "return", "return_pickle", "ignore"] = "raise",
@@ -6330,15 +6322,16 @@ async def broadcast(
                 workers = list(self.workers)
             else:
                 workers = []
+        else:
+            workers = list(workers)
 
         if hosts is not None:
             for host in hosts:
-                dh: dict = self.host_info.get(host)  # type: ignore
+                dh = self.host_info.get(host)
                 if dh is not None:
                     workers.extend(dh["addresses"])
 
-        # TODO replace with worker_list
         if nanny:
-            addresses = [self.workers[w].nanny for w in workers]
+            addresses = [n for w in workers if (n := self.workers[w].nanny) is not None]
         else:
             addresses = workers
@@ -6371,10 +6364,7 @@ async def send_message(addr):
                 f"or 'ignore'; got {on_error!r}"
             )
 
-        results = await All(
-            [send_message(address) for address in addresses if address is not None]
-        )
-
+        results = await All([send_message(address) for address in addresses])
         return {k: v for k, v in zip(workers, results) if v is not ERROR}
 
     async def proxy(
@@ -6435,7 +6425,7 @@ async def gather_on_worker(
             raise ValueError(f"Unexpected message from {worker_address}: {result}")
 
         for key in keys_ok:
-            ts: TaskState = self.tasks.get(key)  # type: ignore
+            ts = self.tasks.get(key)
             if ts is None or ts.state != "memory":
                 logger.warning(f"Key lost during replication: {key}")
                 continue
@@ -6475,7 +6465,7 @@ async def delete_worker_data(
             return
 
         for key in keys:
-            ts: TaskState = self.tasks.get(key)  # type: ignore
+            ts = self.tasks.get(key)
             if ts is not None and ws in (ts.who_has or ()):
                 assert ts.state == "memory"
                 self.remove_replica(ts, ws)
@@ -7355,7 +7345,7 @@ def add_keys(
         """
         if worker not in self.workers:
             return "not found"
-        ws: WorkerState = self.workers[worker]
+        ws = self.workers[worker]
         redundant_replicas = []
         for key in keys:
             ts = self.tasks.get(key)
@@ -7482,7 +7472,7 @@ def log_worker_event(
             msg["worker"] = worker
         self.log_event(topic, msg)
 
-    def subscribe_worker_status(self, comm: Comm) -> str:
+    def subscribe_worker_status(self, comm: Comm) -> dict[str, Any]:
         WorkerStatusPlugin(self, comm)
         ident = self.identity()
         for v in ident["workers"].values():
@@ -7918,7 +7908,7 @@ def _reschedule(
     def add_resources(
         self, worker: str, resources: dict | None = None
     ) -> Literal["OK"]:
-        ws: WorkerState = self.workers[worker]
+        ws = self.workers[worker]
         if resources:
             ws.resources.update(resources)
         ws.used_resources = {}
@@ -7931,7 +7921,7 @@ def add_resources(
         return "OK"
 
     def remove_resources(self, worker: str) -> None:
-        ws: WorkerState = self.workers[worker]
+        ws = self.workers[worker]
         for resource in ws.resources:
             dr = self.resources.setdefault(resource, {})
             del dr[worker]