def pcoordplot(
    summary: dict,
    targetval: str,
    compare: Optional[List[str]] = None,
    colors: Optional[Dict[Hashable, str]] = None,
    aggmode: AggregationMode = AggregationMode.MEDIAN_PERC_90,
    keyword_replacements: Optional[Dict[str, str]] = None,
    axes: Optional[plt.Axes] = None,
    fontsize: int = DEFAULT_FONTSIZE,
    dontcompare: Optional[str] = None,
) -> Optional[Figure]:
    """Draw a parallel coordinate plot of a summary onto a matplotlib axes.

    Every entry of ``summary`` becomes one smoothed line that starts at the
    target value (x = 0) and passes through one vertical coordinate per
    variable in ``compare`` (x = 1, 2, ...). At every level of every compare
    variable a small pie marker shows the colors of all lines passing through
    that level.

    Args:
        summary: Mapping whose keys are sequences holding one value per entry
            of ``compare`` and whose values are dicts of aggregated value
            columns. All entries are read using the column names of the first
            entry.
        targetval: Name of the target variable; used for the y-axis label and
            as the prefix of the internal ``":y"`` column.
        compare: Names of the variables that get their own vertical coordinate.
            At least one is required.
        colors: Optional color specification. String keys are split on ``","``
            before being handed to ``get_colors``.
        aggmode: Aggregation mode; selects the center value column through
            ``VALUE_MEASURE_C`` and the aggregation used in the pie markers
            through ``VALUE_MEASURES``.
        keyword_replacements: Replacements forwarded to ``replace_keywords``
            when rendering labels.
        axes: Axes to draw on. When omitted a new figure with a single axes is
            created.
        fontsize: Font size of all annotations; also the pie marker radius.
        dontcompare: Accepted for call compatibility; not used by this
            function.

    Returns:
        The figure owning the axes that were drawn on.

    Raises:
        ValueError: If ``compare`` is empty.
    """
    compare = [] if compare is None else list(compare)
    if not compare:
        # Without compare variables there is no coordinate to draw and the
        # interpolation below would have nothing to connect.
        raise ValueError("A parallel coordinate plot requires at least one variable to compare.")
    if keyword_replacements is None:
        keyword_replacements = {}

    figure: Optional[Figure] = None
    if axes is None:
        figure = plt.figure(figsize=(len(compare) * 4, 4))
        subplots = figure.subplots()
        assert isinstance(subplots, plt.Axes)
        axes = subplots
    else:
        figure = axes.get_figure()
    assert figure is not None

    # Convert summary to dataframe to make it easier to work with.
    valuecols = list(next(iter(summary.values())).keys())
    centercol = next(c for c in valuecols if c.endswith(VALUE_MEASURE_C[aggmode]))
    dataframe = pd.DataFrame(
        [list(k) + [v[vv] for vv in valuecols] for k, v in summary.items()], columns=compare + valuecols
    )
    dataframe = dataframe.sort_values(by=centercol)

    # Compute the range of the targetval.
    valnumeric = is_numeric_dtype(dataframe[centercol])
    val_unique = dataframe[centercol].unique()
    minval = dataframe[centercol].min() if valnumeric else 0
    maxval = dataframe[centercol].max() if valnumeric else len(val_unique) - 1
    valrange = maxval - minval

    # Compute vertical positions for each compare variable. Levels are spread
    # evenly over the lower 70% of the target range.
    position_index: Dict[str, Dict[Hashable, float]] = {}
    for comp in compare:
        levels = sorted(dataframe[comp].unique())
        last = len(levels) - 1
        position_index[comp] = {
            # A variable with a single level has no spread; pin it to the bottom
            # instead of dividing by zero.
            value: ((float(i) / last) if last > 0 else 0.0) * valrange * 0.7 + minval
            for i, value in enumerate(levels)
        }
        dataframe[comp + ":y"] = dataframe[comp].map(position_index[comp])

    # Compute vertical position for targetval and determine the color map.
    split_colors = (
        dict((tuple(k.split(",") if isinstance(k, str) else k), v) for (k, v) in colors.items())
        if colors is not None
        else None
    )
    color_map: Dict[Hashable, Any] = {}
    if valnumeric:
        dataframe[targetval + ":y"] = dataframe[centercol]
        # NOTE(review): these entries are only read below when targetval also
        # appears in compare.
        position_index[targetval] = {
            minval: 0.0,
            (maxval + minval) * 0.5: 0.5,
            maxval: 1.0,
        }
        color_keys = list(colors.keys()) if colors is not None else [0.0, 1.0]
        colors_renamed = list(zip(color_keys, get_colors(color_keys, colors=split_colors)))
        cmap = LinearSegmentedColormap.from_list("custom", colors_renamed)
        norm = Normalize(vmin=minval, vmax=maxval)
        sm = ScalarMappable(cmap=cmap, norm=norm)
        # The mappable applies `norm` itself, so it must be given raw values;
        # normalizing by hand first would normalize twice.
        color_map = {value: sm.to_rgba(value) for value in val_unique}
    else:
        position_index[targetval] = {value: i for i, value in enumerate(sorted(val_unique))}
        # Stored under the targetval key because that is the column the line
        # drawing loop reads.
        dataframe[targetval + ":y"] = dataframe[centercol].map(position_index[targetval])
        color_map = dict(zip(val_unique, get_colors(val_unique, colors=split_colors)))

    # Determine the color of each instance based on the value of targetval.
    dataframe[centercol + ":color"] = dataframe[centercol].map(color_map)

    # Draw all the lines. The knots and the dense evaluation grid are the same
    # for every row, so they are computed once.
    lx = np.arange(len(compare) + 1)
    sx = np.linspace(0, len(compare), 100)
    # A quadratic spline needs three knots; fall back to a straight line when
    # there is only a single compare variable.
    degree = min(2, len(lx) - 1)
    for _, row in dataframe.iterrows():
        ly = np.array([row[targetval + ":y"]] + [row[comp + ":y"] for comp in compare])
        spline = make_interp_spline(lx, ly, k=degree)
        sy = spline(sx)
        axes.plot(sx, sy, color=row[centercol + ":color"], linewidth=DEFAULT_LINEWIDTH, alpha=0.5)

    # Draw the coordinates and annotations.
    for i, comp in enumerate(compare):
        axes.plot([i + 1, i + 1], [minval, maxval], color="black", linewidth=1)
        if comp != targetval:
            axes.annotate(
                "\n".join(wrap(replace_keywords(comp, keyword_replacements), width=20)),
                xy=(i + 1, maxval),
                xytext=(-fontsize * 0.5, 0),
                textcoords="offset points",
                fontsize=fontsize,
                horizontalalignment="right",
                verticalalignment="top",
            )
            for value, y in position_index[comp].items():
                axes.annotate(
                    str(value),
                    xy=(i + 1, y),
                    xytext=(-fontsize * 1.5, 0),
                    textcoords="offset points",
                    fontsize=fontsize,
                    horizontalalignment="right",
                    verticalalignment="center",
                )

    # Draw cross sections for each compare variable value.
    aggop = VALUE_MEASURES[aggmode][VALUE_MEASURE_C[aggmode]]
    for i, comp in enumerate(compare):
        for value, y in position_index[comp].items():
            section = dataframe[dataframe[comp] == value]
            section_colors = section[centercol + ":color"].values
            if len(section_colors) == 0:
                # No instance passes through this level; nothing to draw.
                continue

            theta = 360 / len(section_colors)
            transform = ScaledTranslation(i + 1, y, axes.transData)
            r = fontsize
            center = (0.0, 0.0)

            if valnumeric:
                # Outer ring: the white wedge's angle is the aggregated target
                # value of this cross section relative to the overall range.
                aggval = (section[centercol].agg(aggop) - minval) / valrange if valrange else 0.0
                circle = Circle(
                    xy=center,
                    radius=r + 3,
                    edgecolor="black",
                    facecolor="black",
                    zorder=2,
                    transform=transform,
                    linewidth=1,
                )
                axes.add_patch(circle)
                wedge = Wedge(
                    center=center,
                    r=r + 3,
                    theta1=0,
                    theta2=aggval * 360,
                    edgecolor="black",
                    facecolor="white",
                    zorder=2,
                    transform=transform,
                )
                axes.add_patch(wedge)

            # Inner pie: one equally sized slice per instance in this section.
            for j, color in enumerate(section_colors):
                # Colors are either RGBA sequences (numeric target) or color
                # strings (categorical target); only the former need converting.
                facecolor = color if isinstance(color, str) else tuple(float(x) for x in color)
                wedge = Wedge(
                    center=center,
                    r=r,
                    theta1=j * theta,
                    theta2=(j + 1) * theta,
                    edgecolor=None,
                    facecolor=facecolor,
                    zorder=2,
                    transform=transform,
                )
                axes.add_patch(wedge)

            circle = Circle(
                xy=center, radius=r, edgecolor="black", facecolor="none", zorder=3, transform=transform, linewidth=1
            )
            axes.add_patch(circle)

    axes.set_ylabel(replace_keywords(targetval, keyword_replacements), fontsize=fontsize, wrap=True)
    axes.get_xaxis().set_ticks([])

    axes.set_xlim(0.0, len(compare) + 0.1)

    return figure
targetval=target[0], + compare=self.compare, + colors=colors, + aggmode=self._summode, + keyword_replacements=keyword_replacements, + axes=axes[i], + fontsize=self.fontsize, + dontcompare=self._dontcompare[i], + ) + + else: + raise ValueError("Unknown plot type: %s" % plottype) + if self._legend: self._figure.subplots_adjust(bottom=0.25) lines, labels = self._figure.axes[0].get_legend_handles_labels()