Commit 6551803

Merge branch 'main' into WSMR/missing_data

crusaderky committed May 13, 2022
2 parents 0feda68 + 4b81f06
Showing 20 changed files with 220 additions and 203 deletions.
4 changes: 1 addition & 3 deletions .pre-commit-config.yaml
@@ -10,9 +10,7 @@ repos:
       - id: isort
         language_version: python3
   - repo: https://github.com/asottile/pyupgrade
-    # Do not upgrade: there's a bug in Cython that causes sum(... for ...) to fail;
-    # it needs sum([... for ...])
-    rev: v2.13.0
+    rev: v2.32.0
     hooks:
       - id: pyupgrade
         args:
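The pin removed above existed because newer pyupgrade rewrites sum([... for ...])
into sum(... for ...), which previously tripped a Cython bug; that rewrite is what
the many sum() hunks later in this diff show. A minimal standalone sketch of the
equivalence (illustrative, not part of the commit):

    # Both forms compute the same total; the generator form never
    # materializes the intermediate list, so peak memory stays O(1).
    nums = range(1_000_000)
    assert sum([n * n for n in nums]) == sum(n * n for n in nums)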
2 changes: 1 addition & 1 deletion distributed/client.py
@@ -52,7 +52,7 @@
 from tornado.ioloop import PeriodicCallback

 from distributed import cluster_dump, preloading
-from distributed import versions as version_module  # type: ignore
+from distributed import versions as version_module
 from distributed.batched import BatchedSend
 from distributed.cfexecutor import ClientExecutor
 from distributed.core import (
13 changes: 9 additions & 4 deletions distributed/comm/asyncio_tcp.py
@@ -6,6 +6,7 @@
 import os
 import socket
 import struct
+import sys
 import weakref
 from itertools import islice
 from typing import Any
@@ -776,10 +777,14 @@ class _ZeroCopyWriter:
     # (which would be very large), and set a limit on the number of buffers to
     # pass to sendmsg.
     if hasattr(socket.socket, "sendmsg"):
-        try:
-            SENDMSG_MAX_COUNT = os.sysconf("SC_IOV_MAX")  # type: ignore
-        except Exception:
-            SENDMSG_MAX_COUNT = 16  # Should be supported on all systems
+        # Note: can't use WINDOWS constant as it upsets mypy
+        if sys.platform == "win32":
+            SENDMSG_MAX_COUNT = 16  # No os.sysconf available
+        else:
+            try:
+                SENDMSG_MAX_COUNT = os.sysconf("SC_IOV_MAX")
+            except Exception:
+                SENDMSG_MAX_COUNT = 16  # Should be supported on all systems
     else:
         SENDMSG_MAX_COUNT = 1  # sendmsg not supported, use send instead
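For context, os.sysconf("SC_IOV_MAX") reports how many buffers a single sendmsg()
call may scatter-gather, which is why _ZeroCopyWriter caps its buffer list. A
standalone sketch of batching buffers under that limit (hypothetical helper, not
the module's actual code):

    import os
    from itertools import islice

    try:
        SENDMSG_MAX_COUNT = os.sysconf("SC_IOV_MAX")
    except (AttributeError, ValueError, OSError):
        SENDMSG_MAX_COUNT = 16  # conservative fallback

    def iter_sendmsg_batches(buffers, limit=SENDMSG_MAX_COUNT):
        # Yield successive groups of at most `limit` buffers; the caller
        # would pass each group to socket.sendmsg().
        it = iter(buffers)
        while batch := list(islice(it, limit)):
            yield batch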
2 changes: 1 addition & 1 deletion distributed/comm/registry.py
@@ -15,7 +15,7 @@ def __call__(self, **kwargs: str) -> Iterable[importlib.metadata.EntryPoint]:
 if sys.version_info >= (3, 10):
     # py3.10 importlib.metadata type annotations are not in mypy yet
     # https://github.com/python/typeshed/pull/7331
-    _entry_points: _EntryPoints = importlib.metadata.entry_points  # type: ignore[assignment]
+    _entry_points: _EntryPoints = importlib.metadata.entry_points
 else:

     def _entry_points(
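The version branch exists because importlib.metadata.entry_points() grew keyword
filtering in Python 3.10, while older interpreters return a plain mapping. A hedged
usage sketch (the group name is assumed for illustration only):

    import sys
    import importlib.metadata

    if sys.version_info >= (3, 10):
        eps = importlib.metadata.entry_points(group="distributed.comm.backends")
    else:
        eps = importlib.metadata.entry_points().get("distributed.comm.backends", ())
    for ep in eps:
        print(ep.name, ep.value)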
2 changes: 1 addition & 1 deletion distributed/comm/ucx.py
@@ -40,7 +40,7 @@
 except ImportError:
     pass
 else:
-    ucp = None  # type: ignore
+    ucp = None

 device_array = None
 pre_existing_cuda_context = False
2 changes: 1 addition & 1 deletion distributed/compatibility.py
@@ -9,7 +9,7 @@

 LINUX = sys.platform == "linux"
 MACOS = sys.platform == "darwin"
-WINDOWS = sys.platform.startswith("win")
+WINDOWS = sys.platform == "win32"


 if sys.version_info >= (3, 9):
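The switch from sys.platform.startswith("win") to sys.platform == "win32" is what
the "can't use WINDOWS constant as it upsets mypy" notes in this diff refer to:
mypy understands direct equality checks on sys.platform and prunes the dead
branch, but it cannot see through a derived boolean constant. A small illustration
(hypothetical module, not from the commit):

    import sys

    if sys.platform == "win32":
        # mypy type-checks this branch as Windows-only, so Windows-specific
        # imports need no "# type: ignore".
        import msvcrt  # noqa: F401
    else:
        # ...and POSIX-only modules type-check here.
        import resource  # noqa: F401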
4 changes: 2 additions & 2 deletions distributed/dashboard/components/scheduler.py
@@ -2794,7 +2794,7 @@ def _get_timeseries(self, restrict_to_existing=False):
         back = None
         # Remove any periods of zero compute at the front or back of the timeseries
         if len(self.plugin.compute):
-            agg = sum([np.array(v[front:]) for v in self.plugin.compute.values()])
+            agg = sum(np.array(v[front:]) for v in self.plugin.compute.values())
             front2 = len(agg) - len(np.trim_zeros(agg, trim="f"))
             front += front2
             back = len(np.trim_zeros(agg, trim="b")) - len(agg) or None
@@ -3192,7 +3192,7 @@ def update(self):
             "names": ["Scheduler", "Workers"],
             "values": [
                 s._tick_interval_observed,
-                sum([w.metrics["event_loop_interval"] for w in s.workers.values()])
+                sum(w.metrics["event_loop_interval"] for w in s.workers.values())
                 / (len(s.workers) or 1),
             ],
         }
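In _get_timeseries above, np.trim_zeros with trim="f"/"b" yields the number of
leading and trailing all-zero samples to cut. A worked example of that index
arithmetic (illustrative values):

    import numpy as np

    agg = np.array([0, 0, 3, 5, 0, 2, 0, 0])
    front2 = len(agg) - len(np.trim_zeros(agg, trim="f"))        # 2 leading zeros
    back = len(np.trim_zeros(agg, trim="b")) - len(agg) or None  # -2: drop 2 at the end
    assert (front2, back) == (2, -2)
    assert list(agg[front2:back]) == [3, 5, 0, 2]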
2 changes: 1 addition & 1 deletion distributed/http/utils.py
@@ -38,7 +38,7 @@ def get_handlers(server, modules: list[str], prefix="/"):
     _routes = []
     for module_name in modules:
         module = importlib.import_module(module_name)
-        _routes.extend(module.routes)  # type: ignore
+        _routes.extend(module.routes)

     routes = []
6 changes: 3 additions & 3 deletions distributed/pytest_resourceleaks.py
@@ -55,7 +55,6 @@ def test1():
 import psutil
 import pytest

-from distributed.compatibility import WINDOWS
 from distributed.metrics import time


@@ -155,10 +154,11 @@ def format(self, before: int, after: int) -> str:

 class FDChecker(ResourceChecker, name="fds"):
     def measure(self) -> int:
-        if WINDOWS:
+        # Note: can't use WINDOWS constant as it upsets mypy
+        if sys.platform == "win32":
             # Don't use num_handles(); you'll get tens of thousands of reported leaks
             return 0
-        return psutil.Process().num_fds()  # type: ignore
+        return psutil.Process().num_fds()

     def has_leak(self, before: int, after: int) -> bool:
         return after > before
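FDChecker returns 0 on Windows because psutil.Process.num_fds() is POSIX-only and
the Windows counterpart, num_handles(), counts kernel handles too noisily for leak
detection. A standalone probe in the same spirit (a sketch, not the plugin code):

    import sys

    import psutil

    def open_descriptor_count() -> int:
        if sys.platform == "win32":
            # num_handles() exists here but fluctuates by the thousands,
            # so report 0 exactly as the plugin does.
            return 0
        return psutil.Process().num_fds()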
68 changes: 32 additions & 36 deletions distributed/scheduler.py
@@ -744,14 +744,14 @@ def __repr__(self) -> str:

     @property
     def nbytes_total(self) -> int:
-        return sum([tg.nbytes_total for tg in self.groups])
+        return sum(tg.nbytes_total for tg in self.groups)

     def __len__(self) -> int:
         return sum(map(len, self.groups))

     @property
     def duration(self) -> float:
-        return sum([tg.duration for tg in self.groups])
+        return sum(tg.duration for tg in self.groups)

     @property
     def types(self) -> set[str]:
@@ -1400,7 +1400,9 @@ def new_task(
     # State Transitions #
     #####################

-    def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs):
+    def _transition(
+        self, key: str, finish: str, stimulus_id: str, *args, **kwargs
+    ) -> tuple[dict, dict, dict]:
         """Transition a key from its current state to the finish state
         Examples
@@ -1432,9 +1434,9 @@ def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs):
         if self.transition_counter_max:
             assert self.transition_counter < self.transition_counter_max

-        recommendations = {}  # type: ignore
-        worker_msgs = {}  # type: ignore
-        client_msgs = {}  # type: ignore
+        recommendations: dict = {}
+        worker_msgs: dict = {}
+        client_msgs: dict = {}

         if self.plugins:
             dependents = set(ts.dependents)
@@ -1444,47 +1446,41 @@ def _transition(self, key, stimulus_id, *args, **kwargs):
             if func is not None:
                 recommendations, client_msgs, worker_msgs = func(
                     self, key, stimulus_id, *args, **kwargs
-                )  # type: ignore
+                )

             elif "released" not in (start, finish):
                 assert not args and not kwargs, (args, kwargs, start, finish)
-                a_recs: dict
-                a_cmsgs: dict
-                a_wmsgs: dict
-                a: tuple = self._transition(key, "released", stimulus_id)
-                a_recs, a_cmsgs, a_wmsgs = a
+                a_recs, a_cmsgs, a_wmsgs = self._transition(
+                    key, "released", stimulus_id
+                )

                 v = a_recs.get(key, finish)
                 func = self._TRANSITIONS_TABLE["released", v]
-                b_recs: dict
-                b_cmsgs: dict
-                b_wmsgs: dict
-                b: tuple = func(self, key, stimulus_id)  # type: ignore
-                b_recs, b_cmsgs, b_wmsgs = b
+                b_recs, b_cmsgs, b_wmsgs = func(self, key, stimulus_id)

                 recommendations.update(a_recs)
                 for c, new_msgs in a_cmsgs.items():
-                    msgs = client_msgs.get(c)  # type: ignore
+                    msgs = client_msgs.get(c)
                     if msgs is not None:
                         msgs.extend(new_msgs)
                     else:
                         client_msgs[c] = new_msgs
                 for w, new_msgs in a_wmsgs.items():
-                    msgs = worker_msgs.get(w)  # type: ignore
+                    msgs = worker_msgs.get(w)
                     if msgs is not None:
                         msgs.extend(new_msgs)
                     else:
                         worker_msgs[w] = new_msgs

                 recommendations.update(b_recs)
                 for c, new_msgs in b_cmsgs.items():
-                    msgs = client_msgs.get(c)  # type: ignore
+                    msgs = client_msgs.get(c)
                     if msgs is not None:
                         msgs.extend(new_msgs)
                     else:
                         client_msgs[c] = new_msgs
                 for w, new_msgs in b_wmsgs.items():
-                    msgs = worker_msgs.get(w)  # type: ignore
+                    msgs = worker_msgs.get(w)
                     if msgs is not None:
                         msgs.extend(new_msgs)
                     else:
@@ -1953,7 +1949,7 @@ def transition_processing_memory(
             assert not ts.exception_blame
             assert ts.state == "processing"

-        ws = self.workers.get(worker)  # type: ignore
+        ws = self.workers.get(worker)
         if ws is None:
             recommendations[key] = "released"
             return recommendations, client_msgs, worker_msgs
@@ -2280,7 +2276,7 @@ def transition_processing_erred(
         traceback=None,
         exception_text: str = None,
         traceback_text: str = None,
-        worker: str = None,  # type: ignore
+        worker: str = None,
         **kwargs,
     ):
         ws: WorkerState
@@ -3455,7 +3451,7 @@ def heartbeat_worker(
     ) -> dict[str, Any]:
         address = self.coerce_address(address, resolve_address)
         address = normalize_address(address)
-        ws: WorkerState = self.workers.get(address)  # type: ignore
+        ws = self.workers.get(address)
         if ws is None:
             return {"status": "missing"}

@@ -4773,7 +4769,7 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None):
     def handle_worker_status_change(
         self, status: str, worker: str, stimulus_id: str
     ) -> None:
-        ws: WorkerState = self.workers.get(worker)  # type: ignore
+        ws = self.workers.get(worker)
         if not ws:
             return
         prev_status = ws.status
@@ -5285,9 +5281,9 @@ async def gather_on_worker(
             )
             return set(who_has)

-        ws: WorkerState = self.workers.get(worker_address)  # type: ignore
+        ws = self.workers.get(worker_address)

-        if ws is None:
+        if not ws:
             logger.warning(f"Worker {worker_address} lost during replication")
             return set(who_has)
         elif result["status"] == "OK":
@@ -5339,8 +5335,8 @@ async def delete_worker_data(
             )
             return

-        ws: WorkerState = self.workers.get(worker_address)  # type: ignore
-        if ws is None:
+        ws = self.workers.get(worker_address)
+        if not ws:
             return

         for key in keys:
@@ -5912,9 +5908,9 @@ def workers_to_close(
             groups = groupby(key, self.workers.values())

             limit_bytes = {
-                k: sum([ws.memory_limit for ws in v]) for k, v in groups.items()
+                k: sum(ws.memory_limit for ws in v) for k, v in groups.items()
             }
-            group_bytes = {k: sum([ws.nbytes for ws in v]) for k, v in groups.items()}
+            group_bytes = {k: sum(ws.nbytes for ws in v) for k, v in groups.items()}

             limit = sum(limit_bytes.values())
             total = sum(group_bytes.values())
@@ -6866,8 +6862,8 @@ def profile_to_figure(state):
             tasks_timings=tasks_timings,
             address=self.address,
             nworkers=len(self.workers),
-            threads=sum([ws.nthreads for ws in self.workers.values()]),
-            memory=format_bytes(sum([ws.memory_limit for ws in self.workers.values()])),
+            threads=sum(ws.nthreads for ws in self.workers.values()),
+            memory=format_bytes(sum(ws.memory_limit for ws in self.workers.values())),
             code=code,
             dask_version=dask.__version__,
             distributed_version=distributed.__version__,
@@ -7101,8 +7097,8 @@ def adaptive_target(self, target_duration=None):
         cpu = max(1, cpu)

         # add more workers if more than 60% of memory is used
-        limit = sum([ws.memory_limit for ws in self.workers.values()])
-        used = sum([ws.nbytes for ws in self.workers.values()])
+        limit = sum(ws.memory_limit for ws in self.workers.values())
+        used = sum(ws.nbytes for ws in self.workers.values())
         memory = 0
         if used > 0.6 * limit and limit > 0:
             memory = 2 * len(self.workers)
@@ -7514,7 +7510,7 @@ def validate_task_state(ts: TaskState) -> None:

     if ts.actor:
         if ts.state == "memory":
-            assert sum([ts in ws.actors for ws in ts.who_has]) == 1
+            assert sum(ts in ws.actors for ws in ts.who_has) == 1
         if ts.state == "processing":
             assert ts.processing_on
             assert ts in ts.processing_on.actors
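The get/extend/else-assign loops in _transition above merge per-recipient message
lists; for reference, the same merge can be written with dict.setdefault (an
equivalent sketch, not what this commit does):

    def merge_msgs(into: dict, new: dict) -> None:
        # Append each recipient's new messages, creating the list on first use.
        for key, new_msgs in new.items():
            into.setdefault(key, []).extend(new_msgs)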
2 changes: 1 addition & 1 deletion distributed/shuffle/shuffle_extension.py
@@ -197,7 +197,7 @@ def get_output_partition(self, i: int) -> pd.DataFrame:
             self.output_partitions_left > 0
         ), f"No outputs remaining, but requested output partition {i} on {self.worker.address}."

-        sync(self.worker.loop, self.multi_file.flush)  # type: ignore
+        sync(self.worker.loop, self.multi_file.flush)
         try:
             df = self.multi_file.read(i)
             with self.time("cpu"):
5 changes: 2 additions & 3 deletions distributed/spill.py
@@ -16,8 +16,7 @@

 logger = logging.getLogger(__name__)
 has_zict_210 = parse_version(zict.__version__) >= parse_version("2.1.0")
-# At the moment of writing, zict 2.2.0 has not been released yet. Support git tip.
-has_zict_220 = parse_version(zict.__version__) >= parse_version("2.2.0.dev2")
+has_zict_220 = parse_version(zict.__version__) >= parse_version("2.2.0")


 class SpilledSize(NamedTuple):
@@ -31,7 +30,7 @@ class SpilledSize(NamedTuple):
     def __add__(self, other: SpilledSize) -> SpilledSize:  # type: ignore
         return SpilledSize(self.memory + other.memory, self.disk + other.disk)

-    def __sub__(self, other: SpilledSize) -> SpilledSize:  # type: ignore
+    def __sub__(self, other: SpilledSize) -> SpilledSize:
         return SpilledSize(self.memory - other.memory, self.disk - other.disk)
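The zict pin tightens from the pre-release "2.2.0.dev2" to the final "2.2.0"; PEP
440 orders dev releases before the release they precede, which is what made the
old git-tip pin work. A quick standalone check:

    from packaging.version import parse as parse_version

    # ">= 2.2.0.dev2" accepted development snapshots; ">= 2.2.0" now
    # requires the final release.
    assert parse_version("2.2.0.dev2") < parse_version("2.2.0")
    assert parse_version("2.2.0") >= parse_version("2.2.0")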
18 changes: 10 additions & 8 deletions distributed/system.py
@@ -17,6 +17,7 @@ def memory_limit() -> int:
     limit = psutil.virtual_memory().total

     # Check cgroups if available
+    # Note: can't use LINUX and WINDOWS constants as they upset mypy
     if sys.platform == "linux":
         try:
             with open("/sys/fs/cgroup/memory/memory.limit_in_bytes") as f:
@@ -27,14 +28,15 @@
             pass

     # Check rlimit if available
-    try:
-        import resource
-
-        hard_limit = resource.getrlimit(resource.RLIMIT_RSS)[1]  # type: ignore
-        if hard_limit > 0:
-            limit = min(limit, hard_limit)
-    except (ImportError, OSError):
-        pass
+    if sys.platform != "win32":
+        try:
+            import resource
+
+            hard_limit = resource.getrlimit(resource.RLIMIT_RSS)[1]
+            if hard_limit > 0:
+                limit = min(limit, hard_limit)
+        except (ImportError, OSError):
+            pass

     return limit
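The rlimit probe now runs only off-Windows, where the resource module exists. Note
that getrlimit returns a (soft, hard) pair and reports RLIM_INFINITY (-1) when no
cap is set, hence the "> 0" guard. A standalone POSIX-only sketch:

    import sys

    if sys.platform != "win32":
        import resource

        soft, hard = resource.getrlimit(resource.RLIMIT_RSS)
        # RLIM_INFINITY surfaces as -1, so only a positive hard limit
        # actually constrains resident-set size.
        print(f"RSS hard limit: {hard if hard > 0 else 'unlimited'}")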