Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inspect multiple boxes and display them in a single figure #317

Merged
merged 4 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions ema_workbench/analysis/prim.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from operator import itemgetter

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
Expand Down Expand Up @@ -396,9 +397,9 @@ def __getattr__(self, name):
else:
raise AttributeError

def inspect(self, i=None, style="table", **kwargs):
def inspect(self, i=None, style="table", ax=None, **kwargs):
"""Write the stats and box limits of the user specified box to
standard out. if i is not provided, the last box will be
standard out. If i is not provided, the last box will be
printed

Parameters
Expand All @@ -409,20 +410,38 @@ def inspect(self, i=None, style="table", **kwargs):
the style of the visualization. 'table' prints the stats and
boxlim. 'graph' creates a figure. 'data' returns a list of
tuples, where each tuple contains the stats and the box_lims.
ax : axes or list of axes instances, optional
used in conjunction with `graph` style, allows you to control the axes on which graph is plotted
if i is list, axes should be list of equal length. If axes is None, each i_j in i will be plotted
in a separate figure.

additional kwargs are passed to the helper function that
generates the table or graph

"""
if style not in {"table", "graph", "data"}:
raise ValueError(f"style must be one of 'table', 'graph', or 'data', not {style}")

if i is None:
i = [self._cur_box]
elif isinstance(i, int):
i = [i]

if isinstance(ax, mpl.axes.Axes):
ax = [ax]

if not all(isinstance(x, int) for x in i):
raise TypeError(f"i must be an integer or list of integers, not {type(i)}")

return [self._inspect(entry, style=style, **kwargs) for entry in i]
if (ax is not None) and style == "graph":
if len(ax) != len(i):
raise ValueError(
f"the number of axes ({len(ax)}) does not match the number of boxes to inspect ({len(i)})"
)
else:
return [self._inspect(i_j, style=style, ax=ax, **kwargs) for i_j, ax in zip(i, ax)]
else:
return [self._inspect(entry, style=style, **kwargs) for entry in i]

def _inspect(self, i=None, style="table", **kwargs):
"""Helper method for inspecting one or more boxes on the
Expand Down Expand Up @@ -450,7 +469,13 @@ def _inspect(self, i=None, style="table", **kwargs):
if style == "table":
return self._inspect_table(i, uncs, qp_values)
elif style == "graph":
return self._inspect_graph(i, uncs, qp_values, **kwargs)
# makes it possible to use _inspect to plot multiple
# boxes into a single figure
try:
ax = kwargs.pop("ax")
except KeyError:
fig, ax = plt.subplots()
return self._inspect_graph(i, uncs, qp_values, ax=ax, **kwargs)
elif style == "data":
return self._inspect_data(i, uncs, qp_values)
else:
Expand Down Expand Up @@ -496,6 +521,7 @@ def _inspect_graph(
ticklabel_formatter="{} ({})",
boxlim_formatter="{: .2g}",
table_formatter="{:.3g}",
ax=None,
):
"""Helper method for visualizing box statistics in
graph form"""
Expand All @@ -507,6 +533,7 @@ def _inspect_graph(
uncs,
self.peeling_trajectory.at[i, "coverage"],
self.peeling_trajectory.at[i, "density"],
ax,
ticklabel_formatter=ticklabel_formatter,
boxlim_formatter=boxlim_formatter,
table_formatter=table_formatter,
Expand Down
22 changes: 15 additions & 7 deletions ema_workbench/analysis/scenario_discovery_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def plot_pair_wise_scatter(
fill_subplots=True,
):
"""helper function for pair wise scatter plotting

Parameters
----------
x : DataFrame
Expand Down Expand Up @@ -530,16 +531,19 @@ def plot_pair_wise_scatter(
return grid


def _setup_figure(uncs):
def _setup_figure(uncs, ax):
"""

helper function for creating the basic layout for the figures that
show the box lims.

Parameters
----------
uncs : list of str
ax : axes instance

"""
nr_unc = len(uncs)
fig = plt.figure()
ax = fig.add_subplot(111)

# create the shaded grey background
rect = mpl.patches.Rectangle(
Expand All @@ -551,7 +555,6 @@ def _setup_figure(uncs):
ax.yaxis.set_ticks(list(range(nr_unc)))
ax.xaxis.set_ticks([0, 0.25, 0.5, 0.75, 1])
ax.set_yticklabels(uncs[::-1])
return fig, ax


def plot_box(
Expand All @@ -561,6 +564,7 @@ def plot_box(
uncs,
coverage,
density,
ax,
ticklabel_formatter="{} ({})",
boxlim_formatter="{: .2g}",
table_formatter="{:.3g}",
Expand All @@ -579,6 +583,7 @@ def plot_box(
ticklabel_formatter : str
boxlim_formatter : str
table_formatter : str
ax : Axes instance

Returns
-------
Expand All @@ -587,8 +592,9 @@ def plot_box(

"""
norm_box_lim = _normalize(boxlim, box_init, uncs)
fig = plt.gcf()

fig, ax = _setup_figure(uncs)
_setup_figure(uncs, ax)
for j, u in enumerate(uncs):
# we want to have the most restricted dimension
# at the top of the figure
Expand Down Expand Up @@ -842,7 +848,8 @@ def plot_boxes(x, boxes, together):
norm_box_lims = [_normalize(box_lim, box_init, uncs) for box_lim in boxes]

if together:
fig, ax = _setup_figure(uncs)
fig, ax = plt.subplots()
_setup_figure(uncs, ax)

for i, u in enumerate(uncs):
colors = itertools.cycle(COLOR_LIST)
Expand All @@ -862,7 +869,8 @@ def plot_boxes(x, boxes, together):
colors = itertools.cycle(COLOR_LIST)

for j, norm_box_lim in enumerate(norm_box_lims):
fig, ax = _setup_figure(uncs)
fig, ax = plt.subplots()
_setup_figure(uncs, ax)
ax.set_title(f"box {j}")
color = next(colors)

Expand Down
8 changes: 7 additions & 1 deletion ema_workbench/examples/sd_prim_flu.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ def classify(data):
box_1 = prim_obj.find_box()
box_1.show_ppt()
box_1.show_tradeoff()
box_1.inspect(5, style="graph", boxlim_formatter="{: .2f}")
# box_1.inspect([5, 6], style="graph", boxlim_formatter="{: .2f}")

fig, axes = plt.subplots(nrows=2, ncols=1)

box_1.inspect([5, 6], style="graph", boxlim_formatter="{: .2f}", ax=axes)
plt.show()

box_1.inspect(5)
box_1.select(5)
box_1.write_ppt_to_stdout()
Expand Down
Binary file removed test/data/test.tar.gz
Binary file not shown.
11 changes: 11 additions & 0 deletions test/test_analysis/test_prim.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from ema_workbench.analysis import prim
from ema_workbench.analysis.prim import PrimBox
Expand Down Expand Up @@ -69,9 +70,19 @@ def test_inspect(self):
box.inspect(1)
box.inspect()
box.inspect(style="graph")
box.inspect(style="data")

box.inspect([0, 1])

fig, axes = plt.subplots(2)
box.inspect([0, 1], ax=axes, style="graph")

fig, ax = plt.subplots()
box.inspect(0, ax=ax, style="graph")

with pytest.raises(ValueError):
fig, axes = plt.subplots(3)
box.inspect([0, 1], ax=axes, style="graph")
with pytest.raises(ValueError):
box.inspect(style="some unknown style")
with pytest.raises(TypeError):
Expand Down
3 changes: 2 additions & 1 deletion test/test_analysis/test_scenario_discovery_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ def test_plot_box(self):

qp_values = {"a": [0.05, 0.9], "c": [0.05, -1]}

sdutil.plot_box(boxlim, qp_values, box_init, restricted_dims, 1, 1)
fig, ax = plt.subplots()
sdutil.plot_box(boxlim, qp_values, box_init, restricted_dims, 1, 1, ax)
quaquel marked this conversation as resolved.
Show resolved Hide resolved
plt.draw()
plt.close("all")

Expand Down
Loading