Skip to content

Commit

Permalink
Add Hypothesis.style, a default style to apply for plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
wookayin committed Nov 7, 2023
1 parent eeafada commit 18a031e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
5 changes: 5 additions & 0 deletions expt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ class Hypothesis(Iterable[Run]):
"""
name: str
runs: RunList
style: Dict[str, Any] # TODO: Use some typing. TODO: Add tests.
config: Optional[RunConfig] = None

@typechecked
Expand All @@ -509,13 +510,16 @@ def __init__(
name: str,
runs: Union[Run, Iterable[Run]],
*,
style: Optional[Dict[str, Any]] = None,
config: Union[RunConfig, Literal['auto'], None] = 'auto',
):
"""Create a new Hypothesis object.
Args:
name: The name of the hypothesis. Should be unique within an Experiment.
runs: The underlying runs that this hypothesis consists of.
style: (optional) A dict that represents preferred style for plotting.
These will be passed as kwargs to plot().
config: A config dict that describes the configuration of the hypothesis.
A config is optional, where `config` is explicitly set to be `None`.
If config exists (not None), it should represent the config this
Expand Down Expand Up @@ -549,6 +553,7 @@ def __init__(
r.name for r in self.runs if r.config is None))

self.config = config
self.style = {**style} if style is not None else {}

def __iter__(self) -> Iterator[Run]:
return iter(self.runs)
Expand Down
40 changes: 26 additions & 14 deletions expt/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,10 @@ def _should_include_column(col_name: str) -> bool:
if suptitle is None and (ax is None and grid is None):
suptitle = self._parent.name

# Merge with hypothesis's default style
if self._parent.style:
kwargs = {**self._parent.style, **kwargs}

return self._do_plot(
y,
representative, # type: ignore
Expand Down Expand Up @@ -950,7 +954,8 @@ def __call__(
y = [yi for yi in y if yi != kwargs['x']]
kwargs['y'] = y

# Line style for each hypothesis.
# Assign line style for each hypothesis
# TODO: avoid conflicts as much as we can against hypothosis.style
axes_cycle = matplotlib.rcParams['axes.prop_cycle']() # type: ignore
axes_props = list(itertools.islice(axes_cycle, len(self._hypotheses)))
for key in list(axes_props[0].keys()):
Expand Down Expand Up @@ -995,37 +1000,44 @@ def __call__(
given_ax_or_grid = ('ax' in kwargs) or (grid is not None)

for i, (name, hypo) in enumerate(self._hypotheses.items()):
h_kwargs = kwargs.copy()
if grid is not None:
h_kwargs.pop('ax', None) # i=0: ax, i>0: grid

if isinstance(y, str):
# display different hypothesis over subplots:
kwargs['label'] = hypothesis_labels[i]
kwargs['subplots'] = False
h_kwargs['label'] = hypothesis_labels[i]
h_kwargs['subplots'] = False
if 'title' not in kwargs:
kwargs['title'] = y # column name
h_kwargs['title'] = y # column name

else:
# display multiple columns over subplots:
if y is not None:
kwargs['label'] = [f'{y_i} ({name})' for y_i in y]
if kwargs.get('prettify_labels', False):
kwargs['label'] = util.prettify_labels(kwargs['label'])
kwargs['subplots'] = True
h_kwargs['label'] = [f'{y_i} ({name})' for y_i in y]
if h_kwargs.get('prettify_labels', False):
h_kwargs['label'] = util.prettify_labels(h_kwargs['label'])
h_kwargs['subplots'] = True

h_kwargs.update(axes_props[i]) # e.g. color, linestyle, etc.

kwargs.update(axes_props[i]) # e.g. color, linestyle, etc.
# Hypothesis' own style should take more priority
h_kwargs.update(hypo.style)

# exclude the hypothesis if it has no runs in it
if hypo.empty():
warnings.warn(f"Hypothesis `{hypo.name}` has no data, "
"ignoring it", UserWarning)
continue

kwargs['tight_layout'] = False
kwargs['ignore_unknown'] = True
kwargs['suptitle'] = '' # no suptitle for each hypo
h_kwargs['tight_layout'] = False
h_kwargs['ignore_unknown'] = True
h_kwargs['suptitle'] = '' # no suptitle for each hypo

grid = hypo.plot(*args, grid=grid, **kwargs) # on the same ax(es)?
grid = hypo.plot(*args, grid=grid, **h_kwargs) # on the same ax(es)?
assert grid is not None

kwargs.pop('ax', None) # From now on, grid.axes will be used
assert grid is not None # True if len(hypothesis) > 0

# corner case: if there is only one column, use it as a label
if len(grid.axes_active) == 1 and isinstance(y, str):
Expand Down

0 comments on commit 18a031e

Please sign in to comment.