Skip to content

Commit

Permalink
Improved pair plots for scenario discovery (#288)
Browse files Browse the repository at this point in the history
These changes add some interesting plotting options to the pair plots for scenario discovery. New options include contour plots, bivariate histograms, and the option to leave the upper triangle empty to reduce visual clutter. It also adds control over the diagonal, which can be either a CDF or KDE.  

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jan Kwakkel <[email protected]>
Co-authored-by: Ewout ter Hoeven <[email protected]>
  • Loading branch information
4 people authored Nov 15, 2023
1 parent 2dd5a57 commit c9049bb
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 39 deletions.
1 change: 1 addition & 0 deletions ema_workbench/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@
from .plotting import lines, envelopes, kde_over_time, multiple_densities
from .plotting_util import Density, PlotType
from .prim import Prim, run_constrained_prim, pca_preprocess, setup_prim
from .prim_util import DiagKind
from .scenario_discovery_util import RuleInductionType
38 changes: 31 additions & 7 deletions ema_workbench/analysis/prim.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
calculate_qp,
determine_dimres,
is_significant,
DiagKind,
)

# Created on 22 feb. 2013
Expand Down Expand Up @@ -827,7 +828,15 @@ def show_tradeoff(self, cmap=mpl.cm.viridis, annotated=False): # @UndefinedVari
"""
return sdutil.plot_tradeoff(self.peeling_trajectory, cmap=cmap, annotated=annotated)

def show_pairs_scatter(self, i=None, dims=None, cdf=False):
def show_pairs_scatter(
self,
i=None,
dims=None,
diag_kind=DiagKind.KDE,
upper="scatter",
lower="contour",
fill_subplots=True,
):
"""Make a pair wise scatter plot of all the restricted
dimensions with color denoting whether a given point is of
interest or not and the boxlims superimposed on top.
Expand All @@ -837,8 +846,18 @@ def show_pairs_scatter(self, i=None, dims=None, cdf=False):
i : int, optional
dims : list of str, optional
dimensions to show, defaults to all restricted dimensions
cdf : bool, optional
plot diag as cdf or pdf
diag_kind : {DiagKind.KDE, DiagKind.CDF}
Plot diagonal as kernel density estimate ('kde') or
cumulative density function ('cdf').
upper, lower: string, optional
Use either 'scatter', 'contour', or 'hist' (bivariate
histogram) plots for upper and lower triangles. Upper triangle
can also be 'none' to eliminate redundancy. Legend uses
lower triangle style for markers.
fill_subplots: Boolean, optional
if True, subplots are resized to fill their respective axes.
This removes unnecessary whitespace, but may be undesirable
for some variable combinations.
Returns
-------
Expand All @@ -851,17 +870,22 @@ def show_pairs_scatter(self, i=None, dims=None, cdf=False):
if dims is None:
dims = sdutil._determine_restricted_dims(self.box_lims[i], self.prim.box_init)

# x =
# y = self.prim.y[self.yi_initial]
# order = np.argsort(y)
if diag_kind not in diag_kind.__members__:
raise ValueError(
f"diag_kind should be one of DiagKind.KDE or DiagKind.CDF, not {diag_kind}"
)
diag = diag_kind.value

return sdutil.plot_pair_wise_scatter(
self.prim.x.iloc[self.yi_initial, :],
self.prim.y[self.yi_initial],
self.box_lims[i],
self.prim.box_init,
dims,
cdf=cdf,
diag=diag,
upper=upper,
lower=lower,
fill_subplots=fill_subplots,
)

def write_ppt_to_stdout(self):
Expand Down
8 changes: 8 additions & 0 deletions ema_workbench/analysis/prim_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ class PRIMObjectiveFunctions(Enum):
ORIGINAL = "original"


class DiagKind(Enum):
KDE = "kde"
"""constant for plotting diagonal in pairs_scatter as kde"""

CDF = "cdf"
"""constant for plotting diagonal in pairs_scatter as cdf"""


def get_quantile(data, quantile):
"""
quantile calculation modeled on the implementation used in sdtoolkit
Expand Down
164 changes: 132 additions & 32 deletions ema_workbench/analysis/scenario_discovery_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,18 @@ def _calculate_quasip(x, y, box, Hbox, Tbox):
return qp.pvalue


def plot_pair_wise_scatter(x, y, boxlim, box_init, restricted_dims, cdf=False):
def plot_pair_wise_scatter(
x,
y,
boxlim,
box_init,
restricted_dims,
diag="kde",
upper="scatter",
lower="hist",
fill_subplots=True,
):
"""helper function for pair wise scatter plotting
Parameters
----------
x : DataFrame
Expand All @@ -326,18 +335,23 @@ def plot_pair_wise_scatter(x, y, boxlim, box_init, restricted_dims, cdf=False):
box_init : DataFrame
restricted_dims : collection of strings
list of uncertainties that define the boxlims
cdf : bool, optional
plot diagonal as pdf or cdf, defaults to kde approx. of pdf
diag : string, optional
Plot diagonal as kernel density estimate ('kde') or
cumulative density function ('cdf').
upper, lower: string, optional
Use either 'scatter', 'contour', or 'hist' (bivariate
histogram) plots for upper and lower triangles. Upper triangle
can also be 'none' to eliminate redundancy. Legend uses
lower triangle style for markers.
fill_subplots: Boolean, optional
if True, subplots are resized to fill their respective axes.
This removes unnecessary whitespace, but may be undesirable
for some variable combinations.
"""

x = x[restricted_dims]
data = x.copy()

# TODO:: have option to change
# diag to CDF, gives you effectively the
# regional sensitivity analysis results
categorical_columns = data.select_dtypes("category").columns.values
categorical_mappings = {}
for column in categorical_columns:
Expand All @@ -357,59 +371,122 @@ def plot_pair_wise_scatter(x, y, boxlim, box_init, restricted_dims, cdf=False):
# replace column with codes
data[column] = data[column].cat.codes

# add outcome of interest to DataFrame
data["y"] = y

# ensures cases of interest are plotted on top
data.sort_values("y", inplace=True)

grid = sns.pairplot(
data=data,
hue="y",
vars=x.columns.values,
diag_kind="kde",
diag_kws={"cumulative": cdf, "common_norm": False, "fill": False},
)
# main plot body

cats = set(categorical_columns)
for row, ylabel in zip(grid.axes, grid.y_vars):
ylim = boxlim[ylabel]
grid = sns.PairGrid(
data=data, hue="y", vars=x.columns.values, diag_sharey=False
) # enables different plots in upper and lower triangles

if ylabel in cats:
y = -0.2
height = len(ylim[0]) - 0.6 # 2 * 0.2
else:
y = ylim[0]
height = ylim[1] - ylim[0]
# upper triangle
if upper == "contour":
# draw contours twice to get different fill and line alphas, more interpretable
grid.map_upper(
sns.kdeplot, fill=True, alpha=0.8, bw_adjust=1.2, levels=5, common_norm=False, cut=0
) # cut = 0
grid.map_upper(
sns.kdeplot, fill=False, alpha=1, bw_adjust=1.2, levels=5, common_norm=False, cut=0
)
elif upper == "hist":
grid.map_upper(sns.histplot)
elif upper == "scatter":
grid.map_upper(sns.scatterplot)
elif upper == "none":
None
else:
raise NotImplementedError(
f"upper = {upper} not implemented. Use either 'scatter', 'contour', 'hist' (bivariate histogram) or None plots for upper triangle."
)

# lower triangle
if lower == "contour":
# draw contours twice to get different fill and line alphas, more interpretable
grid.map_lower(
sns.kdeplot, fill=True, alpha=0.8, bw_adjust=1.2, levels=5, common_norm=False, cut=0
) # cut = 0
grid.map_lower(
sns.kdeplot, fill=False, alpha=1, bw_adjust=1.2, levels=5, common_norm=False, cut=0
)
elif lower == "hist":
grid.map_lower(sns.histplot)
elif lower == "scatter":
grid.map_lower(sns.scatterplot)
elif lower == "none":
raise ValueError(f"Lower triangle cannot be none.")
else:
raise NotImplementedError(
f"lower = {lower} not implemented. Use either 'scatter', 'contour' or 'hist' (bivariate histogram) plots for lower triangle."
)

# diagonal
if diag == "cdf":
grid.map_diag(sns.ecdfplot)
elif diag == "kde":
grid.map_diag(sns.kdeplot, fill=False, common_norm=False, cut=0)
else:
raise NotImplementedError(
f"diag = {diag} not implemented. Use either 'kde' (kernel density estimate) or 'cdf' (cumulative density function)."
)

# draw box
pad = 0.1

cats = set(categorical_columns)
for row, ylabel in zip(grid.axes, grid.y_vars):
for ax, xlabel in zip(row, grid.x_vars):
if ylabel == xlabel:
continue

xrange = ax.get_xlim()[1] - ax.get_xlim()[0]
yrange = ax.get_ylim()[1] - ax.get_ylim()[0]

ylim = boxlim[ylabel]

if ylabel in cats:
height = (len(ylim[0]) - 1) + pad * yrange
y = -yrange * pad / 2
else:
y = ylim[0]
height = ylim[1] - ylim[0]

if xlabel in cats:
xlim = boxlim.at[0, xlabel]
x = -0.2
width = len(xlim) - 0.6 # 2 * 0.2
width = (len(xlim) - 1) + pad * xrange
x = -xrange * pad / 2
else:
xlim = boxlim[xlabel]
x = xlim[0]
width = xlim[1] - xlim[0]

xy = x, y
box = patches.Rectangle(xy, width, height, edgecolor="red", facecolor="none", lw=3)
ax.add_patch(box)
box = patches.Rectangle(
xy, width, height, edgecolor="red", facecolor="none", lw=3, zorder=100
)
if ax.has_data(): # keeps box from being drawn in upper triangle if empty
ax.add_patch(box)
else:
ax.set_axis_off()

# do the yticklabeling for categorical rows
for row, ylabel in zip(grid.axes, grid.y_vars):
if ylabel in cats:
ax = row[0]
labels = []
for entry in ax.get_yticklabels():
_, value = entry.get_position()
locs = []
mapping = categorical_mappings[ylabel]
for i in range(-1, len(mapping) + 1):
locs.append(i)
try:
label = categorical_mappings[ylabel][value]
label = categorical_mappings[ylabel][i]
except KeyError:
label = ""
labels.append(label)
ax.set_yticks(locs)
ax.set_yticklabels(labels)

# do the xticklabeling for categorical columns
Expand All @@ -427,6 +504,29 @@ def plot_pair_wise_scatter(x, y, boxlim, box_init, restricted_dims, cdf=False):
labels.append(label)
ax.set_xticks(locs)
ax.set_xticklabels(labels, rotation=90)

# fit subplot to data ranges, with some padding for aesthetics
if fill_subplots == True:
for axis in grid.axes:
for subplot in axis:
if subplot.get_xlabel() != "":
upper = data[subplot.get_xlabel()].max()
lower = data[subplot.get_xlabel()].min()

pad_rel = (upper - lower) * 0.1 # padding relative to range of data points

subplot.set_xlim(lower - pad_rel, upper + pad_rel)

if subplot.get_ylabel() != "":
upper = data[subplot.get_ylabel()].max()
lower = data[subplot.get_ylabel()].min()

pad_rel = (upper - lower) * 0.1 # padding relative to range of data points

subplot.set_ylim(lower - pad_rel, upper + pad_rel)

grid.add_legend()

return grid


Expand Down

0 comments on commit c9049bb

Please sign in to comment.