Skip to content

Commit

Permalink
Small fixes in aggregate plots.
Browse files Browse the repository at this point in the history
  • Loading branch information
bojan-karlas committed Aug 3, 2024
1 parent fcf2fe5 commit 1587006
Showing 1 changed file with 34 additions and 21 deletions.
55 changes: 34 additions & 21 deletions experiments/datascope/experiments/reports/aggregate_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
import pandas as pd
import re

from collections.abc import Sequence, Mapping
from enum import Enum
from functools import partial
from itertools import combinations, product

from matplotlib.colors import LinearSegmentedColormap
from matplotlib.cm import ScalarMappable
from matplotlib.figure import Figure
from matplotlib.patches import Wedge, Circle
from matplotlib.pyplot import Normalize
from matplotlib.ticker import EngFormatter, PercentFormatter
from matplotlib.transforms import Bbox, ScaledTranslation
from matplotlib.transforms import Bbox, ScaledTranslation, Affine2DBase
from numpy.typing import NDArray
from pandas import DataFrame, MultiIndex
from pandas.api.types import is_numeric_dtype
Expand All @@ -27,8 +27,8 @@

from ..bench import Report, Study, result, attribute

COLOR_NAMES = ["blue", "red", "yellow", "green", "purple", "brown", "cyan", "pink"]
COLORS = ["#2BBFD9", "#DF362A", "#FAC802", "#8AB365", "#C670D2", "#AE7E1E", "#008F6B", "#6839CC"]
COLOR_NAMES = ["red", "blue", "yellow", "green", "purple", "brown", "cyan", "pink"]
COLORS = ["#DF362A", "#2BBFD9", "#FAC802", "#8AB365", "#C670D2", "#AE7E1E", "#008F6B", "#6839CC"]
LABELS = {
"random": "Random",
"shapley-tmc": "Shapley TMC",
Expand Down Expand Up @@ -262,13 +262,13 @@ def column_mapper(x: Tuple) -> Tuple:
return dataframe


def replace_keywords(source: str, keyword_replacements: Dict[str, str]) -> str:
def replace_keywords(source: str, keyword_replacements: Mapping[str, str]) -> str:
for k, v in sorted(keyword_replacements.items(), key=lambda x: len(x[0]), reverse=True):
source = re.sub("(?<![a-zA-Z])%s(?![a-z-Z])" % k, v, source)
return source.replace("_", " ").title()


def get_colors(keys: List[Tuple[Hashable, ...]], colors: Optional[Dict[Tuple[Hashable, ...], str]]) -> List[str]:
def get_colors(keys: Sequence[Tuple[Hashable, ...]], colors: Optional[Mapping[Tuple[Hashable, ...], str]]) -> List[str]:
available_default_colors = [
COLORS[i]
for i in range(len(COLORS))
Expand Down Expand Up @@ -422,7 +422,9 @@ def lineplot(
labels.append(label)

centercol = VALUE_MEASURE_C[aggmode]
split_colors = dict((tuple(k.split(",")), v) for (k, v) in colors.items()) if colors is not None else None
split_colors: Optional[Mapping[Tuple[Hashable, ...], str]] = (
dict((tuple(k.split(",")), v) for (k, v) in colors.items()) if colors is not None else None
)
comp_colors = get_colors(comparison, colors=split_colors)
texts = []
ymin, ymax = np.inf, -np.inf
Expand Down Expand Up @@ -590,7 +592,9 @@ def barplot(
)

comparison = [item[0] for item in summary_items]
split_colors = dict((tuple(k.split(",")), v) for (k, v) in colors.items()) if colors is not None else None
split_colors: Optional[Mapping[Tuple[Hashable, ...], str]] = (
dict((tuple(k.split(",")), v) for (k, v) in colors.items()) if colors is not None else None
)
comp_colors = get_colors(comparison, colors=split_colors)
texts = []
ymin, ymax = np.inf, -np.inf
Expand Down Expand Up @@ -720,7 +724,9 @@ def dotplot(
)

comparison = [item[0] for item in summary_items]
split_colors = dict((tuple(k.split(",")), v) for (k, v) in colors.items()) if colors is not None else None
split_colors: Optional[Mapping[Tuple[Hashable, ...], str]] = (
dict((tuple(k.split(",")), v) for (k, v) in colors.items()) if colors is not None else None
)
comp_colors = get_colors(comparison, colors=split_colors)
texts = []

Expand Down Expand Up @@ -795,10 +801,10 @@ def dotplot(
def pcoordplot(
summary: dict,
targetval: str,
compare: Optional[List[str]] = None,
colors: Optional[Dict[Hashable, str]] = None,
compare: Optional[Sequence[str]] = None,
colors: Optional[Mapping[Hashable, str]] = None,
aggmode: AggregationMode = AggregationMode.MEDIAN_PERC_90,
keyword_replacements: Optional[Dict[str, str]] = None,
keyword_replacements: Optional[Mapping[str, str]] = None,
axes: Optional[plt.Axes] = None,
fontsize: int = DEFAULT_FONTSIZE,
dontcompare: Optional[str] = None,
Expand All @@ -824,7 +830,7 @@ def pcoordplot(
valuecols = list(next(iter(summary.values())).keys())
centercol = next(c for c in valuecols if c.endswith(VALUE_MEASURE_C[aggmode]))
dataframe = pd.DataFrame(
[list(k) + [v[vv] for vv in valuecols] for k, v in summary.items()], columns=compare + valuecols
[list(k) + [v[vv] for vv in valuecols] for k, v in summary.items()], columns=list(compare) + valuecols
)
dataframe = dataframe.sort_values(by=centercol)

Expand All @@ -846,12 +852,12 @@ def pcoordplot(
dataframe[comp + ":y"] = dataframe[comp].map(position_index[comp])

# Compute vertical position for targetval and determine the color map.
split_colors = (
dict((tuple(k.split(",") if isinstance(k, str) else k), v) for (k, v) in colors.items())
split_colors: Optional[Mapping[Tuple[Hashable, ...], str]] = (
dict((tuple(k.split(",")) if isinstance(k, str) else ((k,)), v) for (k, v) in colors.items())
if colors is not None
else None
)
color_map: Dict[Hashable, str] = {}
color_map: Dict[Hashable, Union[str, tuple[float, ...]]] = {}
if valnumeric:
dataframe[targetval + ":y"] = dataframe[centercol]
position_index[targetval] = {
Expand All @@ -860,12 +866,18 @@ def pcoordplot(
maxval: 1.0,
}
# colors_normalized = {(value - minval) / (maxval - minval): color for value, color in colors.items()}
color_keys = list(colors.keys()) if colors is not None else [0.0, 1.0]
colors_renamed = list(zip(color_keys, get_colors(color_keys, colors=split_colors)))
if colors is not None and not all(isinstance(k, float) for k in colors.keys()):
raise ValueError("Colors must be specified as floats when targetval is numeric.")
color_keys: Sequence[float] = list(colors.keys()) if colors is not None else [0.0, 1.0] # type: ignore
color_keys_tuples: Sequence[tuple[float, ...]] = list((k,) for k in color_keys)

colors_renamed: Sequence[Tuple[float, str]] = list(
zip(color_keys, get_colors(color_keys_tuples, colors=split_colors))
)
cmap = LinearSegmentedColormap.from_list("custom", colors_renamed)
norm = Normalize(vmin=minval, vmax=maxval)
sm = ScalarMappable(cmap=cmap, norm=norm)
color_map = {value: sm.to_rgba((value - minval) / (maxval - minval)) for value in val_unique}
color_map = {value: tuple(sm.to_rgba((value - minval) / (maxval - minval))) for value in val_unique}
else:
position_index[targetval] = {value: i for i, value in enumerate(sorted(val_unique))}
dataframe[centercol + ":y"] = dataframe[centercol].map(position_index[targetval])
Expand All @@ -881,7 +893,7 @@ def pcoordplot(

spline = make_interp_spline(lx, ly, k=2)
sx = np.linspace(0, len(compare), 100)
sy = spline(lx)
sy = spline(sx)

axes.plot(sx, sy, color=row[centercol + ":color"], linewidth=DEFAULT_LINEWIDTH, alpha=0.5)

Expand Down Expand Up @@ -917,6 +929,7 @@ def pcoordplot(
aggop = VALUE_MEASURES[aggmode][VALUE_MEASURE_C[aggmode]]
aggval = (dataframe[dataframe[comp] == value][centercol].agg(aggop) - minval) / valrange
theta = 360 / len(section_colors)
assert isinstance(axes.transData, Affine2DBase)
transform = ScaledTranslation(i + 1, y, axes.transData)
r = fontsize
center = (0.0, 0.0)
Expand Down Expand Up @@ -1437,7 +1450,7 @@ def generate(self) -> None:
summary=self._summary,
targetval=target[0],
compare=self.compare,
colors=colors,
colors=colors, # type: ignore
aggmode=self._summode,
keyword_replacements=keyword_replacements,
axes=axes[i],
Expand Down

0 comments on commit 1587006

Please sign in to comment.