diff --git a/ema_workbench/analysis/__init__.py b/ema_workbench/analysis/__init__.py index 595862021..17217f923 100644 --- a/ema_workbench/analysis/__init__.py +++ b/ema_workbench/analysis/__init__.py @@ -15,4 +15,5 @@ from .plotting import lines, envelopes, kde_over_time, multiple_densities from .plotting_util import Density, PlotType from .prim import Prim, run_constrained_prim, pca_preprocess, setup_prim +from .prim_util import DiagKind from .scenario_discovery_util import RuleInductionType diff --git a/ema_workbench/analysis/prim.py b/ema_workbench/analysis/prim.py index e924f23ba..bc2d3a84c 100644 --- a/ema_workbench/analysis/prim.py +++ b/ema_workbench/analysis/prim.py @@ -42,6 +42,7 @@ calculate_qp, determine_dimres, is_significant, + DiagKind, ) # Created on 22 feb. 2013 @@ -827,7 +828,15 @@ def show_tradeoff(self, cmap=mpl.cm.viridis, annotated=False): # @UndefinedVari """ return sdutil.plot_tradeoff(self.peeling_trajectory, cmap=cmap, annotated=annotated) - def show_pairs_scatter(self, i=None, dims=None, cdf=False): + def show_pairs_scatter( + self, + i=None, + dims=None, + diag_kind=DiagKind.KDE, + upper="scatter", + lower="contour", + fill_subplots=True, + ): """Make a pair wise scatter plot of all the restricted dimensions with color denoting whether a given point is of interest or not and the boxlims superimposed on top. @@ -837,8 +846,18 @@ def show_pairs_scatter(self, i=None, dims=None, cdf=False): i : int, optional dims : list of str, optional dimensions to show, defaults to all restricted dimensions - cdf : bool, optional - plot diag as cdf or pdf + diag_kind : {DiagKind.KDE, DiagKind.CDF} + Plot diagonal as kernel density estimate ('kde') or + cumulative density function ('cdf'). + upper, lower: string, optional + Use either 'scatter', 'contour', or 'hist' (bivariate + histogram) plots for upper and lower triangles. Upper triangle + can also be 'none' to eliminate redundancy. Legend uses + lower triangle style for markers. + fill_subplots: Boolean, optional + if True, subplots are resized to fill their respective axes. + This removes unnecessary whitespace, but may be undesirable + for some variable combinations. Returns ------- @@ -851,9 +870,11 @@ def show_pairs_scatter(self, i=None, dims=None, cdf=False): if dims is None: dims = sdutil._determine_restricted_dims(self.box_lims[i], self.prim.box_init) - # x = - # y = self.prim.y[self.yi_initial] - # order = np.argsort(y) + if diag_kind not in diag_kind.__members__: + raise ValueError( + f"diag_kind should be one of DiagKind.KDE or DiagKind.CDF, not {diag_kind}" + ) + diag = diag_kind.value return sdutil.plot_pair_wise_scatter( self.prim.x.iloc[self.yi_initial, :], @@ -861,7 +882,10 @@ def show_pairs_scatter(self, i=None, dims=None, cdf=False): self.box_lims[i], self.prim.box_init, dims, - cdf=cdf, + diag=diag, + upper=upper, + lower=lower, + fill_subplots=fill_subplots, ) def write_ppt_to_stdout(self): diff --git a/ema_workbench/analysis/prim_util.py b/ema_workbench/analysis/prim_util.py index 65d419725..f6b3c709c 100644 --- a/ema_workbench/analysis/prim_util.py +++ b/ema_workbench/analysis/prim_util.py @@ -30,6 +30,14 @@ class PRIMObjectiveFunctions(Enum): ORIGINAL = "original" +class DiagKind(Enum): + KDE = "kde" + """constant for plotting diagonal in pairs_scatter as kde""" + + CDF = "cdf" + """constant for plotting diagonal in pairs_scatter as cdf""" + + def get_quantile(data, quantile): """ quantile calculation modeled on the implementation used in sdtoolkit diff --git a/ema_workbench/analysis/scenario_discovery_util.py b/ema_workbench/analysis/scenario_discovery_util.py index f061606dc..4c362d26a 100644 --- a/ema_workbench/analysis/scenario_discovery_util.py +++ b/ema_workbench/analysis/scenario_discovery_util.py @@ -312,9 +312,18 @@ def _calculate_quasip(x, y, box, Hbox, Tbox): return qp.pvalue -def plot_pair_wise_scatter(x, y, boxlim, box_init, restricted_dims, cdf=False): +def plot_pair_wise_scatter( + x, + y, + boxlim, + box_init, + restricted_dims, + diag="kde", + upper="scatter", + lower="hist", + fill_subplots=True, +): """helper function for pair wise scatter plotting - Parameters ---------- x : DataFrame @@ -326,18 +335,23 @@ def plot_pair_wise_scatter(x, y, boxlim, box_init, restricted_dims, cdf=False): box_init : DataFrame restricted_dims : collection of strings list of uncertainties that define the boxlims - cdf : bool, optional - plot diagonal as pdf or cdf, defaults to kde approx. of pdf - - + diag : string, optional + Plot diagonal as kernel density estimate ('kde') or + cumulative density function ('cdf'). + upper, lower: string, optional + Use either 'scatter', 'contour', or 'hist' (bivariate + histogram) plots for upper and lower triangles. Upper triangle + can also be 'none' to eliminate redundancy. Legend uses + lower triangle style for markers. + fill_subplots: Boolean, optional + if True, subplots are resized to fill their respective axes. + This removes unnecessary whitespace, but may be undesirable + for some variable combinations. """ x = x[restricted_dims] data = x.copy() - # TODO:: have option to change - # diag to CDF, gives you effectively the - # regional sensitivity analysis results categorical_columns = data.select_dtypes("category").columns.values categorical_mappings = {} for column in categorical_columns: @@ -357,59 +371,122 @@ def plot_pair_wise_scatter(x, y, boxlim, box_init, restricted_dims, cdf=False): # replace column with codes data[column] = data[column].cat.codes + # add outcome of interest to DataFrame data["y"] = y # ensures cases of interest are plotted on top data.sort_values("y", inplace=True) - grid = sns.pairplot( - data=data, - hue="y", - vars=x.columns.values, - diag_kind="kde", - diag_kws={"cumulative": cdf, "common_norm": False, "fill": False}, - ) + # main plot body - cats = set(categorical_columns) - for row, ylabel in zip(grid.axes, grid.y_vars): - ylim = boxlim[ylabel] + grid = sns.PairGrid( + data=data, hue="y", vars=x.columns.values, diag_sharey=False + ) # enables different plots in upper and lower triangles - if ylabel in cats: - y = -0.2 - height = len(ylim[0]) - 0.6 # 2 * 0.2 - else: - y = ylim[0] - height = ylim[1] - ylim[0] + # upper triangle + if upper == "contour": + # draw contours twice to get different fill and line alphas, more interpretable + grid.map_upper( + sns.kdeplot, fill=True, alpha=0.8, bw_adjust=1.2, levels=5, common_norm=False, cut=0 + ) # cut = 0 + grid.map_upper( + sns.kdeplot, fill=False, alpha=1, bw_adjust=1.2, levels=5, common_norm=False, cut=0 + ) + elif upper == "hist": + grid.map_upper(sns.histplot) + elif upper == "scatter": + grid.map_upper(sns.scatterplot) + elif upper == "none": + None + else: + raise NotImplementedError( + f"upper = {upper} not implemented. Use either 'scatter', 'contour', 'hist' (bivariate histogram) or None plots for upper triangle." + ) + + # lower triangle + if lower == "contour": + # draw contours twice to get different fill and line alphas, more interpretable + grid.map_lower( + sns.kdeplot, fill=True, alpha=0.8, bw_adjust=1.2, levels=5, common_norm=False, cut=0 + ) # cut = 0 + grid.map_lower( + sns.kdeplot, fill=False, alpha=1, bw_adjust=1.2, levels=5, common_norm=False, cut=0 + ) + elif lower == "hist": + grid.map_lower(sns.histplot) + elif lower == "scatter": + grid.map_lower(sns.scatterplot) + elif lower == "none": + raise ValueError(f"Lower triangle cannot be none.") + else: + raise NotImplementedError( + f"lower = {lower} not implemented. Use either 'scatter', 'contour' or 'hist' (bivariate histogram) plots for lower triangle." + ) + + # diagonal + if diag == "cdf": + grid.map_diag(sns.ecdfplot) + elif diag == "kde": + grid.map_diag(sns.kdeplot, fill=False, common_norm=False, cut=0) + else: + raise NotImplementedError( + f"diag = {diag} not implemented. Use either 'kde' (kernel density estimate) or 'cdf' (cumulative density function)." + ) + # draw box + pad = 0.1 + + cats = set(categorical_columns) + for row, ylabel in zip(grid.axes, grid.y_vars): for ax, xlabel in zip(row, grid.x_vars): if ylabel == xlabel: continue + xrange = ax.get_xlim()[1] - ax.get_xlim()[0] + yrange = ax.get_ylim()[1] - ax.get_ylim()[0] + + ylim = boxlim[ylabel] + + if ylabel in cats: + height = (len(ylim[0]) - 1) + pad * yrange + y = -yrange * pad / 2 + else: + y = ylim[0] + height = ylim[1] - ylim[0] + if xlabel in cats: xlim = boxlim.at[0, xlabel] - x = -0.2 - width = len(xlim) - 0.6 # 2 * 0.2 + width = (len(xlim) - 1) + pad * xrange + x = -xrange * pad / 2 else: xlim = boxlim[xlabel] x = xlim[0] width = xlim[1] - xlim[0] xy = x, y - box = patches.Rectangle(xy, width, height, edgecolor="red", facecolor="none", lw=3) - ax.add_patch(box) + box = patches.Rectangle( + xy, width, height, edgecolor="red", facecolor="none", lw=3, zorder=100 + ) + if ax.has_data(): # keeps box from being drawn in upper triangle if empty + ax.add_patch(box) + else: + ax.set_axis_off() # do the yticklabeling for categorical rows for row, ylabel in zip(grid.axes, grid.y_vars): if ylabel in cats: ax = row[0] labels = [] - for entry in ax.get_yticklabels(): - _, value = entry.get_position() + locs = [] + mapping = categorical_mappings[ylabel] + for i in range(-1, len(mapping) + 1): + locs.append(i) try: - label = categorical_mappings[ylabel][value] + label = categorical_mappings[ylabel][i] except KeyError: label = "" labels.append(label) + ax.set_yticks(locs) ax.set_yticklabels(labels) # do the xticklabeling for categorical columns @@ -427,6 +504,29 @@ def plot_pair_wise_scatter(x, y, boxlim, box_init, restricted_dims, cdf=False): labels.append(label) ax.set_xticks(locs) ax.set_xticklabels(labels, rotation=90) + + # fit subplot to data ranges, with some padding for aesthetics + if fill_subplots == True: + for axis in grid.axes: + for subplot in axis: + if subplot.get_xlabel() != "": + upper = data[subplot.get_xlabel()].max() + lower = data[subplot.get_xlabel()].min() + + pad_rel = (upper - lower) * 0.1 # padding relative to range of data points + + subplot.set_xlim(lower - pad_rel, upper + pad_rel) + + if subplot.get_ylabel() != "": + upper = data[subplot.get_ylabel()].max() + lower = data[subplot.get_ylabel()].min() + + pad_rel = (upper - lower) * 0.1 # padding relative to range of data points + + subplot.set_ylim(lower - pad_rel, upper + pad_rel) + + grid.add_legend() + return grid diff --git a/ema_workbench/examples/sd_prim_byrant_and_lempert.py b/ema_workbench/examples/sd_prim_bryant_and_lempert.py similarity index 100% rename from ema_workbench/examples/sd_prim_byrant_and_lempert.py rename to ema_workbench/examples/sd_prim_bryant_and_lempert.py