Skip to content

Commit

Permalink
Improve typing (#8239)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait authored Oct 9, 2023
1 parent 1e6794f commit 275db75
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 17 deletions.
66 changes: 58 additions & 8 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
normalize_address,
unparse_host_port,
)
from distributed.comm.core import Listener
from distributed.compatibility import PeriodicCallback
from distributed.counter import Counter
from distributed.diskutils import WorkDir, WorkSpace
Expand All @@ -64,6 +65,8 @@
if TYPE_CHECKING:
from typing_extensions import ParamSpec, Self

from distributed.counter import Digest

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
Expand Down Expand Up @@ -315,12 +318,55 @@ class Server:
"""

default_ip = ""
default_port = 0
default_ip: ClassVar[str] = ""
default_port: ClassVar[int] = 0

id: str
blocked_handlers: list[str]
handlers: dict[str, Callable]
stream_handlers: dict[str, Callable]
listeners: list[Listener]
counters: defaultdict[str, Counter]
deserialize: bool

local_directory: str

monitor: SystemMonitor
io_loop: IOLoop
thread_id: int

periodic_callbacks: dict[str, PeriodicCallback]
digests: defaultdict[Hashable, Digest] | None
digests_total: defaultdict[Hashable, float]
digests_total_since_heartbeat: defaultdict[Hashable, float]
digests_max: defaultdict[Hashable, float]

_last_tick: float
_tick_counter: int
_last_tick_counter: int
_tick_interval: float
_tick_interval_observed: float

_status: Status

_address: str | None
_listen_address: str | None
_host: str | None
_port: int | None

_comms: dict[Comm, str | None]

_ongoing_background_tasks: AsyncTaskGroup
_event_finished: asyncio.Event

_original_local_dir: str
_updated_sys_path: bool
_workspace: WorkSpace
_workdir: None | WorkDir

_startup_lock: asyncio.Lock
__startup_exc: Exception | None

def __init__(
self,
handlers,
Expand Down Expand Up @@ -673,13 +719,13 @@ def stop(self) -> None:
self.monitor.close()
if not (stop_listeners := self._stop_listeners()).done():
self._ongoing_background_tasks.call_soon(
asyncio.wait_for(stop_listeners, timeout=None)
asyncio.wait_for(stop_listeners, timeout=None) # type: ignore[arg-type]
)
if self._workdir is not None:
self._workdir.release()

@property
def listener(self):
def listener(self) -> Listener | None:
if self.listeners:
return self.listeners[0]
else:
Expand Down Expand Up @@ -722,6 +768,7 @@ def address(self) -> str:
if self.listener is None:
raise ValueError("cannot get address of non-running Server")
self._address = self.listener.contact_address
assert self._address
return self._address

@property
Expand Down Expand Up @@ -784,7 +831,8 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
Client.dump_cluster_state
distributed.utils.recursive_to_dict
"""
info = self.identity()
info: dict = {}
info.update(self.identity())
extra = {
"address": self.address,
"status": self.status.name,
Expand Down Expand Up @@ -816,15 +864,15 @@ async def listen(self, port_or_addr=None, allow_offload=True, **kwargs):
)
self.listeners.append(listener)

def handle_comm(self, comm):
def handle_comm(self, comm: Comm) -> NoOpAwaitable:
"""Start a background task that dispatches new communications to coroutine-handlers"""
try:
self._ongoing_background_tasks.call_soon(self._handle_comm, comm)
except AsyncTaskGroupClosedError:
comm.abort()
return NoOpAwaitable()

async def _handle_comm(self, comm):
async def _handle_comm(self, comm: Comm) -> None:
"""Dispatch new communications to coroutine-handlers
Handlers is a dictionary mapping operation names to functions or
Expand Down Expand Up @@ -963,7 +1011,9 @@ async def _handle_comm(self, comm):
"Failed while closing connection to %r: %s", address, e
)

async def handle_stream(self, comm, extra=None):
async def handle_stream(
self, comm: Comm, extra: dict[str, str] | None = None
) -> None:
extra = extra or {}
logger.info("Starting established connection to %s", comm.peer_address)

Expand Down
1 change: 1 addition & 0 deletions distributed/http/worker/prometheus/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def collect_crick(self) -> Iterator[Metric]:
# The following metrics will export NaN, if the corresponding digests are None
if not self.crick_available:
return
assert self.server.digests

yield GaugeMetricFamily(
self.build_name("tick_duration_median"),
Expand Down
4 changes: 2 additions & 2 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5863,7 +5863,7 @@ def worker_send(self, worker: str, msg: dict[str, Any]) -> None:
stream_comms[worker].send(msg)
except (CommClosedError, AttributeError):
self._ongoing_background_tasks.call_soon(
self.remove_worker,
self.remove_worker, # type: ignore[arg-type]
address=worker,
stimulus_id=f"worker-send-comm-fail-{time()}",
)
Expand Down Expand Up @@ -5909,7 +5909,7 @@ def send_all(self, client_msgs: Msgs, worker_msgs: Msgs) -> None:
pass
except (CommClosedError, AttributeError):
self._ongoing_background_tasks.call_soon(
self.remove_worker,
self.remove_worker, # type: ignore[arg-type]
address=worker,
stimulus_id=f"send-all-comm-fail-{time()}",
)
Expand Down
4 changes: 2 additions & 2 deletions distributed/system_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(

self.update()

def recent(self) -> dict[str, Any]:
def recent(self) -> dict[str, float]:
return {k: v[-1] for k, v in self.quantities.items()}

def get_process_memory(self) -> int:
Expand Down Expand Up @@ -224,7 +224,7 @@ def __repr__(self) -> str:
"N/A" if WINDOWS else self.quantities["num_fds"][-1],
)

def range_query(self, start: int) -> dict[str, list]:
def range_query(self, start: int) -> dict[str, list[float | None]]:
if start >= self.count:
return {k: [] for k in self.quantities}

Expand Down
9 changes: 4 additions & 5 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,7 +1050,7 @@ async def get_metrics(self) -> dict:

self.digests_total_since_heartbeat.clear()

out = dict(
out: dict = dict(
task_counts=self.state.task_counter.current_count(by_prefix=False),
bandwidth={
"total": self.bandwidth,
Expand Down Expand Up @@ -1348,9 +1348,8 @@ async def gather(self, who_has: dict[Key, list[str]]) -> dict[Key, object]:
else:
return {"status": "OK"}

def get_monitor_info(
self, recent: bool = False, start: int = 0
) -> dict[str, float]:
# FIXME: Improve typing
def get_monitor_info(self, recent: bool = False, start: int = 0) -> dict[str, Any]:
result = dict(
range_query=(
self.monitor.recent()
Expand Down Expand Up @@ -2457,7 +2456,7 @@ async def get_profile(
):
now = time() + self.scheduler_delay
if server:
history = self.io_loop.profile
history = self.io_loop.profile # type: ignore[attr-defined]
elif key is None:
history = self.profile_history
else:
Expand Down

0 comments on commit 275db75

Please sign in to comment.