Skip to content

Commit

Permalink
MemorySampler to show memory breakdown
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jun 27, 2022
1 parent c82bba5 commit 34d1579
Showing 1 changed file with 75 additions and 38 deletions.
113 changes: 75 additions & 38 deletions distributed/diagnostics/memory_sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import uuid
from collections.abc import AsyncIterator, Iterator
from collections.abc import AsyncIterator, Collection, Iterator
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -49,17 +49,27 @@ class MemorySampler:
ms.plot()
"""

samples: dict[str, list[tuple[float, int]]]
# {label: [[timestamp, nbytes, nbytes, ...], ...]}
samples: dict[str, list[list[float]]]
# {label: [measure, measure, ...]
measures: dict[str, list[str]]

def __init__(self):
self.samples = {}
self.measures = {}

def sample(
self,
label: str | None = None,
*,
client: Client | None = None,
measure: str = "process",
measure: str
| Collection[str] = (
"managed_in_memory",
"unmanaged_old",
"unmanaged_recent",
"managed_spilled",
),
interval: float = 0.5,
) -> Any:
"""Context manager that records memory usage in the cluster.
Expand All @@ -72,13 +82,13 @@ def sample(
==========
label: str, optional
Tag to record the samples under in the self.samples dict.
Default: automatically generate a random label
Default: automatically generate a unique label
client: Client, optional
client used to connect to the scheduler.
Default: use the global client
measure: str, optional
One of the measures from :class:`distributed.scheduler.MemoryState`.
Default: sample process memory
measure: str or Collection[str], optional
One or more measures from :class:`distributed.scheduler.MemoryState`.
Default: sample process + spilled memory, broken down
interval: float, optional
sampling interval, in seconds.
Default: 0.5
Expand All @@ -88,39 +98,54 @@ def sample(

client = get_client()

measures = [measure] if isinstance(measure, str) else list(measure)

if not label:
for i in range(len(self.samples) + 1):
label = f"Samples {i}"
if label not in self.samples:
break
assert label

self.samples[label] = []
self.measures[label] = measures

if client.asynchronous:
return self._sample_async(label, client, measure, interval)
return self._sample_async(client, label, measures, interval)
else:
return self._sample_sync(label, client, measure, interval)
return self._sample_sync(client, label, measures, interval)

@contextmanager
def _sample_sync(
self, label: str | None, client: Client, measure: str, interval: float
self, client: Client, label: str, measures: list[str], interval: float
) -> Iterator[None]:
key = client.sync(
client.scheduler.memory_sampler_start,
client=client.id,
measure=measure,
measures=measures,
interval=interval,
)
try:
yield
finally:
samples = client.sync(client.scheduler.memory_sampler_stop, key=key)
self.samples[label or key] = samples
samples = client.sync(
client.scheduler.memory_sampler_stop,
key=key,
)
self.samples[label] = samples

@asynccontextmanager
async def _sample_async(
self, label: str | None, client: Client, measure: str, interval: float
self, client: Client, label: str, measures: list[str], interval: float
) -> AsyncIterator[None]:
key = await client.scheduler.memory_sampler_start(
client=client.id, measure=measure, interval=interval
client=client.id, measures=measures, interval=interval
)
try:
yield
finally:
samples = await client.scheduler.memory_sampler_stop(key=key)
self.samples[label or key] = samples
self.samples[label] = samples

def to_pandas(self, *, align: bool = False) -> pd.DataFrame:
"""Return the data series as a pandas.Dataframe.
Expand All @@ -134,28 +159,33 @@ def to_pandas(self, *, align: bool = False) -> pd.DataFrame:
"""
import pandas as pd

ss = {}
for (label, s_list) in self.samples.items():
dfs = []
for label, s_list in self.samples.items():
assert s_list # There's always at least one sample
s = pd.DataFrame(s_list).set_index(0)[1]
s.index = pd.to_datetime(s.index, unit="s")
s.name = label
df = pd.DataFrame(s_list).set_index(0)
df.index = pd.to_datetime(df.index, unit="s")
df.columns = pd.MultiIndex.from_tuples(
[(label, measure) for measure in self.measures[label]],
names=["label", "measure"],
)
if align:
# convert datetime to timedelta from the first sample
s.index -= s.index[0]
ss[label] = s
df.index -= df.index[0]
dfs.append(df)

df = pd.DataFrame(ss)

if len(ss) > 1:
# Forward-fill NaNs in the middle of a series created either by overlapping
# sampling time range or by align=True. Do not ffill series beyond their
# last sample.
df = df.ffill().where(~pd.isna(df.bfill()))
if len(dfs) == 1:
return dfs[0]

df = pd.concat(dfs, axis=1).sort_index()
# Forward-fill NaNs in the middle of a series created either by overlapping
# sampling time range or by align=True. Do not ffill series beyond their
# last sample.
df = df.ffill().where(~pd.isna(df.bfill()))
return df

def plot(self, *, align: bool = False, **kwargs: Any) -> Any:
def plot(
self, *, align: bool = False, kind: str | None = None, **kwargs: Any
) -> Any:
"""Plot data series collected so far
Parameters
Expand All @@ -170,7 +200,10 @@ def plot(self, *, align: bool = False, **kwargs: Any) -> Any:
Output of :meth:`pandas.DataFrame.plot`
"""
df = self.to_pandas(align=align) / 2**30
if not kind:
kind = "line" if len(self.samples) > 1 else "area"
return df.plot(
kind=kind,
xlabel="time",
ylabel="Cluster memory (GiB)",
**kwargs,
Expand All @@ -181,7 +214,8 @@ class MemorySamplerExtension:
"""Scheduler extension - server side of MemorySampler"""

scheduler: Scheduler
samples: dict[str, list[tuple[float, int]]]
# {unique key: [[timestamp, nbytes, nbytes, ...], ...]}
samples: dict[str, list[list[float]]]

def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
Expand All @@ -190,19 +224,22 @@ def __init__(self, scheduler: Scheduler):
self.scheduler.handlers["memory_sampler_stop"] = self.stop
self.samples = {}

def start(self, client: str, measure: str, interval: float) -> str:
def start(self, client: str, measures: list[str], interval: float) -> str:
"""Start periodically sampling memory"""
assert not measure.startswith("_")
assert isinstance(getattr(self.scheduler.memory, measure), int)
mem = self.scheduler.memory
for measure in measures:
assert not measure.startswith("_")
assert isinstance(getattr(mem, measure), int)

key = str(uuid.uuid4())
self.samples[key] = []

def sample():
if client in self.scheduler.clients:
ts = datetime.now().timestamp()
nbytes = getattr(self.scheduler.memory, measure)
self.samples[key].append((ts, nbytes))
mem = self.scheduler.memory
nbytes = [getattr(mem, measure) for measure in measures]
self.samples[key].append([ts] + nbytes)
else:
self.stop(key)

Expand All @@ -216,7 +253,7 @@ def sample():

return key

def stop(self, key: str) -> list[tuple[float, int]]:
def stop(self, key: str) -> list[list[float]]:
"""Stop sampling and return the samples"""
pc = self.scheduler.periodic_callbacks.pop("MemorySampler-" + key)
pc.stop()
Expand Down

0 comments on commit 34d1579

Please sign in to comment.