Skip to content

Commit

Permalink
Inspect multiple boxes in prim and plot them in a single figure
Browse files Browse the repository at this point in the history
closes #124
  • Loading branch information
quaquel committed Nov 20, 2023
1 parent e28ee4a commit ec46ce6
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 7 deletions.
24 changes: 21 additions & 3 deletions ema_workbench/analysis/prim.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def __getattr__(self, name):
else:
raise AttributeError

def inspect(self, i=None, style="table", **kwargs):
def inspect(self, i=None, style="table", ax=None, **kwargs):
"""Write the stats and box limits of the user specified box to
standard out. If i is not provided, the last box will be
printed
Expand All @@ -410,20 +410,38 @@ def inspect(self, i=None, style="table", **kwargs):
the style of the visualization. 'table' prints the stats and
boxlim. 'graph' creates a figure. 'data' returns a list of
tuples, where each tuple contains the stats and the box_lims.
ax : axes or list of axes instances, optional
used in conjunction with `graph` style, allows you to control the axes on which graph is plotted
if i is list, axes should be list of equal length. If axes is None, each i_j in i will be plotted
in a separate figure.
additional kwargs are passed to the helper function that
generates the table or graph
"""
if style not in {"table", "graph", "data"}:
raise ValueError(f"style must be one of 'table', 'graph', or 'data', not {style}")

if i is None:
i = [self._cur_box]
elif isinstance(i, int):
i = [i]

if isinstance(ax, mpl.axes.Axes):
ax = [ax]

if not all(isinstance(x, int) for x in i):
raise TypeError(f"i must be an integer or list of integers, not {type(i)}")

return [self._inspect(entry, style=style, **kwargs) for entry in i]
if (ax is not None) and style == "graph":
if len(ax) != len(i):
raise ValueError(
f"the number of axes ({len(ax)}) does not match the number of boxes to inspect ({len(i)})"
)
else:
return [self._inspect(i_j, style=style, ax=ax, **kwargs) for i_j, ax in zip(i, ax)]
else:
return [self._inspect(entry, style=style, **kwargs) for entry in i]

def _inspect(self, i=None, style="table", **kwargs):
"""Helper method for inspecting one or more boxes on the
Expand Down Expand Up @@ -515,10 +533,10 @@ def _inspect_graph(
uncs,
self.peeling_trajectory.at[i, "coverage"],
self.peeling_trajectory.at[i, "density"],
ax,
ticklabel_formatter=ticklabel_formatter,
boxlim_formatter=boxlim_formatter,
table_formatter=table_formatter,
ax=ax,
)

def inspect_tradeoff(self):
Expand Down
2 changes: 1 addition & 1 deletion ema_workbench/analysis/scenario_discovery_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,10 +564,10 @@ def plot_box(
uncs,
coverage,
density,
ax,
ticklabel_formatter="{} ({})",
boxlim_formatter="{: .2g}",
table_formatter="{:.3g}",
ax=None,
):
"""Helper function for parallel coordinate style visualization
of a box
Expand Down
3 changes: 1 addition & 2 deletions ema_workbench/examples/sd_prim_flu.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def classify(data):

fig, axes = plt.subplots(nrows=2, ncols=1)

for i, ax in zip([5, 6], axes):
box_1._inspect(i, style="graph", boxlim_formatter="{: .2f}", ax=ax)
box_1.inspect([5, 6], style="graph", boxlim_formatter="{: .2f}", ax=axes)
plt.show()

box_1.inspect(5)
Expand Down
3 changes: 2 additions & 1 deletion test/test_analysis/test_scenario_discovery_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ def test_plot_box(self):

qp_values = {"a": [0.05, 0.9], "c": [0.05, -1]}

sdutil.plot_box(boxlim, qp_values, box_init, restricted_dims, 1, 1)
fig, ax = plt.subplots()
sdutil.plot_box(boxlim, qp_values, box_init, restricted_dims, 1, 1, ax)
plt.draw()
plt.close("all")

Expand Down

0 comments on commit ec46ce6

Please sign in to comment.