Skip to content

Commit

Permalink
Add parallel coordinate plot type to aggregate plots.
Browse files Browse the repository at this point in the history
  • Loading branch information
bojan-karlas committed Jul 11, 2024
1 parent d4b5e74 commit ed5b43e
Showing 1 changed file with 216 additions and 11 deletions.
227 changes: 216 additions & 11 deletions experiments/datascope/experiments/reports/aggregate_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,24 @@
from enum import Enum
from functools import partial
from itertools import combinations, product

from matplotlib.colors import LinearSegmentedColormap
from matplotlib.cm import ScalarMappable
from matplotlib.figure import Figure
from matplotlib.patches import Wedge, Circle
from matplotlib.pyplot import Normalize
from matplotlib.ticker import EngFormatter, PercentFormatter
from matplotlib.transforms import Bbox
from matplotlib.transforms import Bbox, ScaledTranslation
from numpy.typing import NDArray
from scipy.sparse import lil_matrix
from scipy.sparse.csgraph import connected_components
from typing import Callable, Optional, Dict, Any, List, Union, TypeVar, Tuple

from pandas import DataFrame, MultiIndex
from pandas.api.types import is_numeric_dtype
from pandas.core.groupby.generic import DataFrameGroupBy
from pandas.core.groupby.groupby import GroupBy
from scipy.interpolate import make_interp_spline
from scipy.sparse import lil_matrix
from scipy.sparse.csgraph import connected_components
from textwrap import wrap
from typing import Callable, Optional, Dict, Any, List, Union, TypeVar, Tuple, Hashable

from ..bench import Report, Study, result, attribute

Expand Down Expand Up @@ -44,6 +51,7 @@ class PlotType(str, Enum):
BAR = "bar"
LINE = "line"
DOT = "dot"
PCOORD = "pcoord"


class TickFormat(str, Enum):
Expand All @@ -68,8 +76,8 @@ class AggregationMode(str, Enum):
},
AggregationMode.MEDIAN_PERC_90: {
"median": "median",
"95perc-l": partial(np.percentile, q=5),
"95perc-h": partial(np.percentile, q=95),
"90perc-l": partial(np.percentile, q=5),
"90perc-h": partial(np.percentile, q=95),
},
AggregationMode.MEDIAN_PERC_95: {
"median": "median",
Expand All @@ -78,8 +86,8 @@ class AggregationMode(str, Enum):
},
AggregationMode.MEDIAN_PERC_99: {
"median": "median",
"95perc-l": partial(np.percentile, q=0.5),
"95perc-h": partial(np.percentile, q=99.5),
"99perc-l": partial(np.percentile, q=0.5),
"99perc-h": partial(np.percentile, q=99.5),
},
}

Expand Down Expand Up @@ -260,7 +268,7 @@ def replace_keywords(source: str, keyword_replacements: Dict[str, str]) -> str:
return source.replace("_", " ").title()


def get_colors(keys: List[Tuple[str, ...]], colors: Optional[Dict[Tuple[str, ...], str]]) -> List[str]:
def get_colors(keys: List[Tuple[Hashable, ...]], colors: Optional[Dict[Tuple[Hashable, ...], str]]) -> List[str]:
available_default_colors = [
COLORS[i]
for i in range(len(COLORS))
Expand Down Expand Up @@ -784,6 +792,184 @@ def dotplot(
return figure


def pcoordplot(
summary: dict,
targetval: str,
compare: Optional[List[str]] = None,
colors: Optional[Dict[Hashable, str]] = None,
aggmode: AggregationMode = AggregationMode.MEDIAN_PERC_90,
keyword_replacements: Optional[Dict[str, str]] = None,
axes: Optional[plt.Axes] = None,
fontsize: int = DEFAULT_FONTSIZE,
dontcompare: Optional[str] = None,
) -> Optional[Figure]:
if compare is None:
compare = []

if keyword_replacements is None:
keyword_replacements = {}
if dontcompare is None:
dontcompare = ""
figure: Optional[Figure] = None
if axes is None:
figure = plt.figure(figsize=(len(compare) * 4, 4))
subplots = figure.subplots()
assert isinstance(subplots, plt.Axes)
axes = subplots
else:
figure = axes.get_figure()
assert figure is not None

# Convert summary to dataframe to make it easier to work with.
valuecols = list(next(iter(summary.values())).keys())
centercol = next(c for c in valuecols if c.endswith(VALUE_MEASURE_C[aggmode]))
dataframe = pd.DataFrame(
[list(k) + [v[vv] for vv in valuecols] for k, v in summary.items()], columns=compare + valuecols
)
dataframe = dataframe.sort_values(by=centercol)

# Compute the range of the targetval.
valnumeric = is_numeric_dtype(dataframe[centercol])
val_unique = dataframe[centercol].unique()
minval = dataframe[centercol].min() if valnumeric else 0
maxval = dataframe[centercol].max() if valnumeric else len(val_unique) - 1
valrange = maxval - minval

# Compute vertical positions for each compare variable.
position_index: Dict[str, Dict[Hashable, float]] = {}
for comp in compare:
num_values = len(dataframe[comp].unique())
position_index[comp] = {
value: (float(i) / (num_values - 1)) * valrange * 0.7 + minval
for i, value in enumerate(sorted(dataframe[comp].unique()))
}
dataframe[comp + ":y"] = dataframe[comp].map(position_index[comp])

# Compute vertical position for targetval and determine the color map.
split_colors = (
dict((tuple(k.split(",") if isinstance(k, str) else k), v) for (k, v) in colors.items())
if colors is not None
else None
)
color_map: Dict[Hashable, str] = {}
if valnumeric:
dataframe[targetval + ":y"] = dataframe[centercol]
position_index[targetval] = {
minval: 0.0,
(maxval + minval) * 0.5: 0.5,
maxval: 1.0,
}
# colors_normalized = {(value - minval) / (maxval - minval): color for value, color in colors.items()}
color_keys = list(colors.keys()) if colors is not None else [0.0, 1.0]
colors_renamed = list(zip(color_keys, get_colors(color_keys, colors=split_colors)))
cmap = LinearSegmentedColormap.from_list("custom", colors_renamed)
norm = Normalize(vmin=minval, vmax=maxval)
sm = ScalarMappable(cmap=cmap, norm=norm)
color_map = {value: sm.to_rgba((value - minval) / (maxval - minval)) for value in val_unique}
else:
position_index[targetval] = {value: i for i, value in enumerate(sorted(val_unique))}
dataframe[centercol + ":y"] = dataframe[centercol].map(position_index[targetval])
color_map = {value: color for value, color in zip(val_unique, get_colors(val_unique, colors=split_colors))}

# Determine the color of each instance based on the value of targetval.
dataframe[centercol + ":color"] = dataframe[centercol].map(color_map)

# Draw all the lines.
for i, row in dataframe.iterrows():
lx = np.arange(len(compare) + 1)
ly = np.array([row[targetval + ":y"]] + [row[comp + ":y"] for comp in compare])

spline = make_interp_spline(lx, ly, k=2)
sx = np.linspace(0, len(compare), 100)
sy = spline(lx)

axes.plot(sx, sy, color=row[centercol + ":color"], linewidth=DEFAULT_LINEWIDTH, alpha=0.5)

# Draw the coordinates and annotations.
for i, comp in enumerate(compare):
axes.plot([i + 1, i + 1], [minval, maxval], color="black", linewidth=1)
if comp != targetval:
axes.annotate(
"\n".join(wrap(replace_keywords(comp, keyword_replacements), width=20)),
xy=(i + 1, maxval),
xytext=(-fontsize * 0.5, 0),
textcoords="offset points",
fontsize=fontsize,
horizontalalignment="right",
verticalalignment="top",
)
for value, y in position_index[comp].items():
axes.annotate(
str(value),
xy=(i + 1, y),
xytext=(-fontsize * 1.5, 0),
textcoords="offset points",
fontsize=fontsize,
horizontalalignment="right",
verticalalignment="center",
)

# Draw cross sections for each compare variable value.
for i, comp in enumerate(compare):
for value, y in position_index[comp].items():
section_colors = dataframe[dataframe[comp] == value][centercol + ":color"].values

aggop = VALUE_MEASURES[aggmode][VALUE_MEASURE_C[aggmode]]
aggval = (dataframe[dataframe[comp] == value][centercol].agg(aggop) - minval) / valrange
theta = 360 / len(section_colors)
transform = ScaledTranslation(i + 1, y, axes.transData)
r = fontsize
center = (0.0, 0.0)

if valnumeric:
circle = Circle(
xy=center,
radius=r + 3,
edgecolor="black",
facecolor="black",
zorder=2,
transform=transform,
linewidth=1,
)
axes.add_patch(circle)
wedge = Wedge(
center=center,
r=r + 3,
theta1=0,
theta2=aggval * 360,
edgecolor="black",
facecolor="white",
zorder=2,
transform=transform,
)
axes.add_patch(wedge)

for j in range(len(section_colors)):
wedge = Wedge(
center=center,
r=r,
theta1=j * theta,
theta2=(j + 2) * theta,
edgecolor=None,
facecolor=tuple(float(x) for x in section_colors[j]),
zorder=2,
transform=transform,
)
axes.add_patch(wedge)

circle = Circle(
xy=center, radius=r, edgecolor="black", facecolor="none", zorder=3, transform=transform, linewidth=1
)
axes.add_patch(circle)

axes.set_ylabel(replace_keywords(targetval, keyword_replacements), fontsize=fontsize, wrap=True)
axes.get_xaxis().set_ticks([])

axes.set_xlim(0.0, len(compare) + 0.1)

return figure


NONE_SYMBOL = "-"
DEFAULT_PLOTSIZE = [10, 8]

Expand Down Expand Up @@ -1218,7 +1404,7 @@ def generate(self) -> None:
dontcompare=self._dontcompare[i],
)

else:
elif plottype == PlotType.DOT:
if self._summary is None:
raise ValueError("A dot plot can only be generated from a summary.")

Expand All @@ -1243,6 +1429,25 @@ def generate(self) -> None:
dontcompare=self._dontcompare[i],
)

elif plottype == PlotType.PCOORD:
if self._summary is None:
raise ValueError("A parallel coordinate plot can only be generated from a summary.")

pcoordplot(
summary=self._summary,
targetval=target[0],
compare=self.compare,
colors=colors,
aggmode=self._summode,
keyword_replacements=keyword_replacements,
axes=axes[i],
fontsize=self.fontsize,
dontcompare=self._dontcompare[i],
)

else:
raise ValueError("Unknown plot type: %s" % plottype)

if self._legend:
self._figure.subplots_adjust(bottom=0.25)
lines, labels = self._figure.axes[0].get_legend_handles_labels()
Expand Down

0 comments on commit ed5b43e

Please sign in to comment.