diff --git a/experiments/datascope/experiments/reports/aggregate_plot.py b/experiments/datascope/experiments/reports/aggregate_plot.py index 142b0e6..a53329f 100644 --- a/experiments/datascope/experiments/reports/aggregate_plot.py +++ b/experiments/datascope/experiments/reports/aggregate_plot.py @@ -3,17 +3,17 @@ import pandas as pd import re +from collections.abc import Sequence, Mapping from enum import Enum from functools import partial from itertools import combinations, product - from matplotlib.colors import LinearSegmentedColormap from matplotlib.cm import ScalarMappable from matplotlib.figure import Figure from matplotlib.patches import Wedge, Circle from matplotlib.pyplot import Normalize from matplotlib.ticker import EngFormatter, PercentFormatter -from matplotlib.transforms import Bbox, ScaledTranslation +from matplotlib.transforms import Bbox, ScaledTranslation, Affine2DBase from numpy.typing import NDArray from pandas import DataFrame, MultiIndex from pandas.api.types import is_numeric_dtype @@ -27,8 +27,8 @@ from ..bench import Report, Study, result, attribute -COLOR_NAMES = ["blue", "red", "yellow", "green", "purple", "brown", "cyan", "pink"] -COLORS = ["#2BBFD9", "#DF362A", "#FAC802", "#8AB365", "#C670D2", "#AE7E1E", "#008F6B", "#6839CC"] +COLOR_NAMES = ["red", "blue", "yellow", "green", "purple", "brown", "cyan", "pink"] +COLORS = ["#DF362A", "#2BBFD9", "#FAC802", "#8AB365", "#C670D2", "#AE7E1E", "#008F6B", "#6839CC"] LABELS = { "random": "Random", "shapley-tmc": "Shapley TMC", @@ -262,13 +262,13 @@ def column_mapper(x: Tuple) -> Tuple: return dataframe -def replace_keywords(source: str, keyword_replacements: Dict[str, str]) -> str: +def replace_keywords(source: str, keyword_replacements: Mapping[str, str]) -> str: for k, v in sorted(keyword_replacements.items(), key=lambda x: len(x[0]), reverse=True): source = re.sub("(? List[str]: +def get_colors(keys: Sequence[Tuple[Hashable, ...]], colors: Optional[Mapping[Tuple[Hashable, ...], str]]) -> List[str]: available_default_colors = [ COLORS[i] for i in range(len(COLORS)) @@ -422,7 +422,9 @@ def lineplot( labels.append(label) centercol = VALUE_MEASURE_C[aggmode] - split_colors = dict((tuple(k.split(",")), v) for (k, v) in colors.items()) if colors is not None else None + split_colors: Optional[Mapping[Tuple[Hashable, ...], str]] = ( + dict((tuple(k.split(",")), v) for (k, v) in colors.items()) if colors is not None else None + ) comp_colors = get_colors(comparison, colors=split_colors) texts = [] ymin, ymax = np.inf, -np.inf @@ -590,7 +592,9 @@ def barplot( ) comparison = [item[0] for item in summary_items] - split_colors = dict((tuple(k.split(",")), v) for (k, v) in colors.items()) if colors is not None else None + split_colors: Optional[Mapping[Tuple[Hashable, ...], str]] = ( + dict((tuple(k.split(",")), v) for (k, v) in colors.items()) if colors is not None else None + ) comp_colors = get_colors(comparison, colors=split_colors) texts = [] ymin, ymax = np.inf, -np.inf @@ -720,7 +724,9 @@ def dotplot( ) comparison = [item[0] for item in summary_items] - split_colors = dict((tuple(k.split(",")), v) for (k, v) in colors.items()) if colors is not None else None + split_colors: Optional[Mapping[Tuple[Hashable, ...], str]] = ( + dict((tuple(k.split(",")), v) for (k, v) in colors.items()) if colors is not None else None + ) comp_colors = get_colors(comparison, colors=split_colors) texts = [] @@ -795,10 +801,10 @@ def dotplot( def pcoordplot( summary: dict, targetval: str, - compare: Optional[List[str]] = None, - colors: Optional[Dict[Hashable, str]] = None, + compare: Optional[Sequence[str]] = None, + colors: Optional[Mapping[Hashable, str]] = None, aggmode: AggregationMode = AggregationMode.MEDIAN_PERC_90, - keyword_replacements: Optional[Dict[str, str]] = None, + keyword_replacements: Optional[Mapping[str, str]] = None, axes: Optional[plt.Axes] = None, fontsize: int = DEFAULT_FONTSIZE, dontcompare: Optional[str] = None, @@ -824,7 +830,7 @@ def pcoordplot( valuecols = list(next(iter(summary.values())).keys()) centercol = next(c for c in valuecols if c.endswith(VALUE_MEASURE_C[aggmode])) dataframe = pd.DataFrame( - [list(k) + [v[vv] for vv in valuecols] for k, v in summary.items()], columns=compare + valuecols + [list(k) + [v[vv] for vv in valuecols] for k, v in summary.items()], columns=list(compare) + valuecols ) dataframe = dataframe.sort_values(by=centercol) @@ -846,12 +852,12 @@ def pcoordplot( dataframe[comp + ":y"] = dataframe[comp].map(position_index[comp]) # Compute vertical position for targetval and determine the color map. - split_colors = ( - dict((tuple(k.split(",") if isinstance(k, str) else k), v) for (k, v) in colors.items()) + split_colors: Optional[Mapping[Tuple[Hashable, ...], str]] = ( + dict((tuple(k.split(",")) if isinstance(k, str) else ((k,)), v) for (k, v) in colors.items()) if colors is not None else None ) - color_map: Dict[Hashable, str] = {} + color_map: Dict[Hashable, Union[str, tuple[float, ...]]] = {} if valnumeric: dataframe[targetval + ":y"] = dataframe[centercol] position_index[targetval] = { @@ -860,12 +866,18 @@ def pcoordplot( maxval: 1.0, } # colors_normalized = {(value - minval) / (maxval - minval): color for value, color in colors.items()} - color_keys = list(colors.keys()) if colors is not None else [0.0, 1.0] - colors_renamed = list(zip(color_keys, get_colors(color_keys, colors=split_colors))) + if colors is not None and not all(isinstance(k, float) for k in colors.keys()): + raise ValueError("Colors must be specified as floats when targetval is numeric.") + color_keys: Sequence[float] = list(colors.keys()) if colors is not None else [0.0, 1.0] # type: ignore + color_keys_tuples: Sequence[tuple[float, ...]] = list((k,) for k in color_keys) + + colors_renamed: Sequence[Tuple[float, str]] = list( + zip(color_keys, get_colors(color_keys_tuples, colors=split_colors)) + ) cmap = LinearSegmentedColormap.from_list("custom", colors_renamed) norm = Normalize(vmin=minval, vmax=maxval) sm = ScalarMappable(cmap=cmap, norm=norm) - color_map = {value: sm.to_rgba((value - minval) / (maxval - minval)) for value in val_unique} + color_map = {value: tuple(sm.to_rgba((value - minval) / (maxval - minval))) for value in val_unique} else: position_index[targetval] = {value: i for i, value in enumerate(sorted(val_unique))} dataframe[centercol + ":y"] = dataframe[centercol].map(position_index[targetval]) @@ -881,7 +893,7 @@ def pcoordplot( spline = make_interp_spline(lx, ly, k=2) sx = np.linspace(0, len(compare), 100) - sy = spline(lx) + sy = spline(sx) axes.plot(sx, sy, color=row[centercol + ":color"], linewidth=DEFAULT_LINEWIDTH, alpha=0.5) @@ -917,6 +929,7 @@ def pcoordplot( aggop = VALUE_MEASURES[aggmode][VALUE_MEASURE_C[aggmode]] aggval = (dataframe[dataframe[comp] == value][centercol].agg(aggop) - minval) / valrange theta = 360 / len(section_colors) + assert isinstance(axes.transData, Affine2DBase) transform = ScaledTranslation(i + 1, y, axes.transData) r = fontsize center = (0.0, 0.0) @@ -1437,7 +1450,7 @@ def generate(self) -> None: summary=self._summary, targetval=target[0], compare=self.compare, - colors=colors, + colors=colors, # type: ignore aggmode=self._summode, keyword_replacements=keyword_replacements, axes=axes[i],