Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Add the option to save a figure in plot setting params #351

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -1513,6 +1513,9 @@ def plot_perf_over_time(
The settings of a pair of color and label for each plot.
args, kwargs (Any):
Arguments for the ax.plot.

Note:
You might need to run `export DISPLAY=:0.0` if you are using non-GUI based environment.
"""

if not hasattr(metrics, metric_name):
Expand Down
48 changes: 36 additions & 12 deletions autoPyTorch/utils/results_visualizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, NamedTuple, Optional, Tuple

import matplotlib.pyplot as plt

Expand Down Expand Up @@ -71,8 +71,7 @@ def extract_dicts(
return colors, labels


@dataclass(frozen=True)
class PlotSettingParams:
class PlotSettingParams(NamedTuple):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is due to the fact that dataclass does not allow to have dict or list.

"""
Parameters for the plot environment.

Expand All @@ -93,12 +92,28 @@ class PlotSettingParams:
The range of x axis.
ylim (Tuple[float, float]):
The range of y axis.
grid (bool):
Whether to have grid lines.
If users would like to define lines in detail,
they need to deactivate it.
legend (bool):
Whether to have legend in the figure.
legend_loc (str):
The location of the legend.
legend_kwargs (Dict[str, Any]):
The kwargs for ax.legend.
Ref: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html
title (Optional[str]):
The title of the figure.
title_kwargs (Dict[str, Any]):
The kwargs for ax.set_title except title label.
Ref: https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.axes.Axes.set_title.html
show (bool):
Whether to show the plot.
If figname is not None, the save will be prioritized.
figname (Optional[str]):
Name of a figure to save. If None, no figure will be saved.
savefig_kwargs (Dict[str, Any]):
The kwargs for plt.savefig except filename.
Ref: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html
args, kwargs (Any):
Arguments for the ax.plot.
"""
Expand All @@ -108,12 +123,16 @@ class PlotSettingParams:
xlabel: Optional[str] = None
ylabel: Optional[str] = None
title: Optional[str] = None
title_kwargs: Dict[str, Any] = {}
xlim: Optional[Tuple[float, float]] = None
ylim: Optional[Tuple[float, float]] = None
grid: bool = True
legend: bool = True
legend_loc: str = 'best'
legend_kwargs: Dict[str, Any] = {}
show: bool = False
figname: Optional[str] = None
figsize: Optional[Tuple[int, int]] = None
savefig_kwargs: Dict[str, Any] = {}


class ScaleChoices(Enum):
Expand Down Expand Up @@ -201,17 +220,22 @@ def _set_plot_args(

ax.set_xscale(plot_setting_params.xscale)
ax.set_yscale(plot_setting_params.yscale)
if plot_setting_params.xscale == 'log' or plot_setting_params.yscale == 'log':
ax.grid(True, which='minor', color='gray', linestyle=':')

ax.grid(True, which='major', color='black')
if plot_setting_params.grid:
if plot_setting_params.xscale == 'log' or plot_setting_params.yscale == 'log':
ax.grid(True, which='minor', color='gray', linestyle=':')

ax.grid(True, which='major', color='black')

if plot_setting_params.legend:
ax.legend(loc=plot_setting_params.legend_loc)
ax.legend(**plot_setting_params.legend_kwargs)

if plot_setting_params.title is not None:
ax.set_title(plot_setting_params.title)
if plot_setting_params.show:
ax.set_title(plot_setting_params.title, **plot_setting_params.title_kwargs)

if plot_setting_params.figname is not None:
plt.savefig(plot_setting_params.figname, **plot_setting_params.savefig_kwargs)
elif plot_setting_params.show:
ravinkohli marked this conversation as resolved.
Show resolved Hide resolved
plt.show()

@staticmethod
Expand Down
11 changes: 5 additions & 6 deletions examples/40_advanced/example_plot_over_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,20 @@
xlabel='Runtime',
ylabel='Accuracy',
title='Toy Example',
show=False # If you would like to show, make it True
figname='example_plot_over_time.png',
savefig_kwargs={'bbox_inches': 'tight'},
show=False # If you would like to show, make it True and set figname=None
)

############################################################################
# Plot with the Specified Setting Parameters
# ==========================================
_, ax = plt.subplots()
# _, ax = plt.subplots() <=== You can feed it to post-process the figure.

# You might need to run `export DISPLAY=:0.0` if you are using non-GUI based environment.
api.plot_perf_over_time(
ax=ax, # You do not have to provide.
metric_name=metric_name,
plot_setting_params=params,
marker='*',
markersize=10
)

# plt.show() might cause issue depending on environments
plt.savefig('example_plot_over_time.png')
30 changes: 8 additions & 22 deletions test/test_utils/test_results_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,9 @@ def test_extract_results_from_run_history():
time=1.0,
status=StatusType.CAPPED,
)
with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError):
SearchResults(metric=accuracy, scoring_functions=[], run_history=run_history)

assert excinfo._excinfo[0] == ValueError


def test_raise_error_in_update_and_sort_by_time():
cs = ConfigurationSpace()
Expand All @@ -179,7 +177,7 @@ def test_raise_error_in_update_and_sort_by_time():
sr = SearchResults(metric=accuracy, scoring_functions=[], run_history=RunHistory())
er = EnsembleResults(metric=accuracy, ensemble_performance_history=[])

with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(RuntimeError):
sr._update(
config=config,
run_key=RunKey(config_id=0, instance_id=0, seed=0),
Expand All @@ -189,19 +187,13 @@ def test_raise_error_in_update_and_sort_by_time():
)
)

assert excinfo._excinfo[0] == RuntimeError

with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(RuntimeError):
sr._sort_by_endtime()

assert excinfo._excinfo[0] == RuntimeError

with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(RuntimeError):
er._update(data={})

assert excinfo._excinfo[0] == RuntimeError

with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(RuntimeError):
er._sort_by_endtime()


Expand Down Expand Up @@ -244,11 +236,9 @@ def test_raise_error_in_get_start_time():
status=StatusType.CAPPED,
)

with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError):
get_start_time(run_history)

assert excinfo._excinfo[0] == ValueError


def test_search_results_sort_by_endtime():
run_history = RunHistory()
Expand Down Expand Up @@ -364,11 +354,9 @@ def test_metric_results(metric, scores, ensemble_ends_later):
def test_search_results_sprint_statistics():
api = BaseTask()
for method in ['get_search_results', 'sprint_statistics', 'get_incumbent_results']:
with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(RuntimeError):
getattr(api, method)()

assert excinfo._excinfo[0] == RuntimeError

run_history_data = json.load(open(os.path.join(os.path.dirname(__file__),
'runhistory.json'),
mode='r'))['data']
Expand Down Expand Up @@ -420,11 +408,9 @@ def test_check_run_history(run_history):
manager = ResultsManager()
manager.run_history = run_history

with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(RuntimeError):
manager._check_run_history()

assert excinfo._excinfo[0] == RuntimeError


@pytest.mark.parametrize('include_traditional', (True, False))
@pytest.mark.parametrize('metric', (accuracy, log_loss))
Expand Down
48 changes: 37 additions & 11 deletions test/test_utils/test_results_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,46 @@ def test_extract_dicts(cl_settings, with_ensemble):

@pytest.mark.parametrize('params', (
PlotSettingParams(show=True),
PlotSettingParams(show=False)
PlotSettingParams(show=False),
PlotSettingParams(show=True, figname='dummy')
))
def test_plt_show_in_set_plot_args(params): # TODO
plt.show = MagicMock()
plt.savefig = MagicMock()
_, ax = plt.subplots(nrows=1, ncols=1)
viz = ResultsVisualizer()

viz._set_plot_args(ax, params)
assert plt.show._mock_called == params.show
# if figname is not None, show will not be called. (due to the matplotlib design)
assert plt.show._mock_called == (params.figname is None and params.show)
plt.close()


@pytest.mark.parametrize('params', (
PlotSettingParams(),
PlotSettingParams(figname='fig')
))
def test_plt_savefig_in_set_plot_args(params): # TODO
plt.savefig = MagicMock()
_, ax = plt.subplots(nrows=1, ncols=1)
viz = ResultsVisualizer()

viz._set_plot_args(ax, params)
assert plt.savefig._mock_called == (params.figname is not None)
plt.close()


@pytest.mark.parametrize('params', (
PlotSettingParams(grid=True),
PlotSettingParams(grid=False)
))
def test_ax_grid_in_set_plot_args(params): # TODO
_, ax = plt.subplots(nrows=1, ncols=1)
ax.grid = MagicMock()
viz = ResultsVisualizer()

viz._set_plot_args(ax, params)
assert ax.grid._mock_called == params.grid
plt.close()


Expand All @@ -77,10 +108,9 @@ def test_raise_value_error_in_set_plot_args(params): # TODO
_, ax = plt.subplots(nrows=1, ncols=1)
viz = ResultsVisualizer()

with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError):
viz._set_plot_args(ax, params)

assert excinfo._excinfo[0] == ValueError
plt.close()


Expand Down Expand Up @@ -119,13 +149,11 @@ def test_raise_error_in_plot_perf_over_time_in_base_task(metric_name):
api = BaseTask()

if metric_name == 'unknown':
with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError):
api.plot_perf_over_time(metric_name)
assert excinfo._excinfo[0] == ValueError
else:
with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(RuntimeError):
api.plot_perf_over_time(metric_name)
assert excinfo._excinfo[0] == RuntimeError


@pytest.mark.parametrize('metric_name', ('balanced_accuracy', 'accuracy'))
Expand Down Expand Up @@ -175,16 +203,14 @@ def test_raise_error_get_perf_and_time(params):
results = np.linspace(-1, 1, 10)
cum_times = np.linspace(0, 1, 10)

with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError):
_get_perf_and_time(
cum_results=results,
cum_times=cum_times,
plot_setting_params=params,
worst_val=np.inf
)

assert excinfo._excinfo[0] == ValueError


@pytest.mark.parametrize('params', (
PlotSettingParams(n_points=20, xscale='linear', yscale='linear'),
Expand Down