From e77dd56028e3e9800fb965e1fa974102ff314c9b Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 6 Oct 2023 12:39:31 +0200 Subject: [PATCH] Improve typing --- distributed/core.py | 66 +++++++++++++++++++--- distributed/http/worker/prometheus/core.py | 1 + distributed/scheduler.py | 4 +- distributed/system_monitor.py | 4 +- distributed/worker.py | 9 ++- 5 files changed, 67 insertions(+), 17 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index bee894a4db..6dadbd5a50 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -43,6 +43,7 @@ normalize_address, unparse_host_port, ) +from distributed.comm.core import Listener from distributed.compatibility import PeriodicCallback from distributed.counter import Counter from distributed.diskutils import WorkDir, WorkSpace @@ -64,6 +65,8 @@ if TYPE_CHECKING: from typing_extensions import ParamSpec, Self + from distributed.counter import Digest + P = ParamSpec("P") R = TypeVar("R") T = TypeVar("T") @@ -315,12 +318,55 @@ class Server: """ - default_ip = "" - default_port = 0 + default_ip: ClassVar[str] = "" + default_port: ClassVar[int] = 0 + + id: str + blocked_handlers: list[str] + handlers: dict[str, Callable] + stream_handlers: dict[str, Callable] + listeners: list[Listener] + counters: defaultdict[str, Counter] + deserialize: bool + local_directory: str + + monitor: SystemMonitor + io_loop: IOLoop + thread_id: int + + periodic_callbacks: dict[str, PeriodicCallback] + digests: defaultdict[Hashable, Digest] | None + digests_total: defaultdict[Hashable, float] + digests_total_since_heartbeat: defaultdict[Hashable, float] + digests_max: defaultdict[Hashable, float] + + _last_tick: float + _tick_counter: int + _last_tick_counter: int + _tick_interval: float + _tick_interval_observed: float + + _status: Status + + _address: str | None + _listen_address: str | None + _host: str | None + _port: int | None + + _comms: dict[Comm, str | None] + + _ongoing_background_tasks: AsyncTaskGroup + _event_finished: asyncio.Event + + _original_local_dir: str + _updated_sys_path: bool _workspace: WorkSpace _workdir: None | WorkDir + _startup_lock: asyncio.Lock + __startup_exc: Exception | None + def __init__( self, handlers, @@ -673,13 +719,13 @@ def stop(self) -> None: self.monitor.close() if not (stop_listeners := self._stop_listeners()).done(): self._ongoing_background_tasks.call_soon( - asyncio.wait_for(stop_listeners, timeout=None) + asyncio.wait_for(stop_listeners, timeout=None) # type: ignore[arg-type] ) if self._workdir is not None: self._workdir.release() @property - def listener(self): + def listener(self) -> Listener | None: if self.listeners: return self.listeners[0] else: @@ -722,6 +768,7 @@ def address(self) -> str: if self.listener is None: raise ValueError("cannot get address of non-running Server") self._address = self.listener.contact_address + assert self._address return self._address @property @@ -784,7 +831,8 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict: Client.dump_cluster_state distributed.utils.recursive_to_dict """ - info = self.identity() + info: dict = {} + info.update(self.identity()) extra = { "address": self.address, "status": self.status.name, @@ -816,7 +864,7 @@ async def listen(self, port_or_addr=None, allow_offload=True, **kwargs): ) self.listeners.append(listener) - def handle_comm(self, comm): + def handle_comm(self, comm: Comm) -> NoOpAwaitable: """Start a background task that dispatches new communications to coroutine-handlers""" try: self._ongoing_background_tasks.call_soon(self._handle_comm, comm) @@ -824,7 +872,7 @@ def handle_comm(self, comm): comm.abort() return NoOpAwaitable() - async def _handle_comm(self, comm): + async def _handle_comm(self, comm: Comm) -> None: """Dispatch new communications to coroutine-handlers Handlers is a dictionary mapping operation names to functions or @@ -963,7 +1011,9 @@ async def _handle_comm(self, comm): "Failed while closing connection to %r: %s", address, e ) - async def handle_stream(self, comm, extra=None): + async def handle_stream( + self, comm: Comm, extra: dict[str, str] | None = None + ) -> None: extra = extra or {} logger.info("Starting established connection to %s", comm.peer_address) diff --git a/distributed/http/worker/prometheus/core.py b/distributed/http/worker/prometheus/core.py index 23aef2d9f6..f62027cf47 100644 --- a/distributed/http/worker/prometheus/core.py +++ b/distributed/http/worker/prometheus/core.py @@ -170,6 +170,7 @@ def collect_crick(self) -> Iterator[Metric]: # The following metrics will export NaN, if the corresponding digests are None if not self.crick_available: return + assert self.server.digests yield GaugeMetricFamily( self.build_name("tick_duration_median"), diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 93a5cbdf1a..9262eeb90b 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5863,7 +5863,7 @@ def worker_send(self, worker: str, msg: dict[str, Any]) -> None: stream_comms[worker].send(msg) except (CommClosedError, AttributeError): self._ongoing_background_tasks.call_soon( - self.remove_worker, + self.remove_worker, # type: ignore[arg-type] address=worker, stimulus_id=f"worker-send-comm-fail-{time()}", ) @@ -5909,7 +5909,7 @@ def send_all(self, client_msgs: Msgs, worker_msgs: Msgs) -> None: pass except (CommClosedError, AttributeError): self._ongoing_background_tasks.call_soon( - self.remove_worker, + self.remove_worker, # type: ignore[arg-type] address=worker, stimulus_id=f"send-all-comm-fail-{time()}", ) diff --git a/distributed/system_monitor.py b/distributed/system_monitor.py index 193fcdcc72..bccce514ca 100644 --- a/distributed/system_monitor.py +++ b/distributed/system_monitor.py @@ -138,7 +138,7 @@ def __init__( self.update() - def recent(self) -> dict[str, Any]: + def recent(self) -> dict[str, float]: return {k: v[-1] for k, v in self.quantities.items()} def get_process_memory(self) -> int: @@ -224,7 +224,7 @@ def __repr__(self) -> str: "N/A" if WINDOWS else self.quantities["num_fds"][-1], ) - def range_query(self, start: int) -> dict[str, list]: + def range_query(self, start: int) -> dict[str, list[float | None]]: if start >= self.count: return {k: [] for k in self.quantities} diff --git a/distributed/worker.py b/distributed/worker.py index 044b4fbd1a..53688d828d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1050,7 +1050,7 @@ async def get_metrics(self) -> dict: self.digests_total_since_heartbeat.clear() - out = dict( + out: dict = dict( task_counts=self.state.task_counter.current_count(by_prefix=False), bandwidth={ "total": self.bandwidth, @@ -1348,9 +1348,8 @@ async def gather(self, who_has: dict[Key, list[str]]) -> dict[Key, object]: else: return {"status": "OK"} - def get_monitor_info( - self, recent: bool = False, start: int = 0 - ) -> dict[str, float]: + # FIXME: Improve typing + def get_monitor_info(self, recent: bool = False, start: int = 0) -> dict[str, Any]: result = dict( range_query=( self.monitor.recent() @@ -2457,7 +2456,7 @@ async def get_profile( ): now = time() + self.scheduler_delay if server: - history = self.io_loop.profile + history = self.io_loop.profile # type: ignore[attr-defined] elif key is None: history = self.profile_history else: