forked from facebook/Ax
New parallel coordinates plot (facebook#2590)
Summary:
Pull Request resolved: facebook#2590

An improved version of the parallel coordinates analysis I implemented earlier this year, now refactored with AnalysisCards. Main improvements include:

* Ability to infer what metric to use based on the OptimizationConfig if one is not provided
* Compatibility with ChoiceParameters and FixedParameters
* Truncation of long parameter and metric names where appropriate

NOTE: This analysis introduces a number of helper functions in parallel_coordinates.py -- as we add more analyses these should be refactored out into analysis/plotly/utils.py or analysis/utils.py as appropriate.

Differential Revision: D59927703
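For context, a minimal usage sketch assembled from this diff and its tests (the branin experiment stub comes from the test file below; everything else follows the API shown in this commit):

    from ax.analysis.plotly.parallel_coordinates.parallel_coordinates import (
        ParallelCoordinatesPlot,
    )
    from ax.utils.testing.core_stubs import get_branin_experiment

    # An experiment with at least one completed trial and attached data.
    experiment = get_branin_experiment(with_completed_trial=True)

    # Pass a metric explicitly, or omit it to fall back to the experiment's
    # (single, non-scalarized) objective via _select_metric.
    card = ParallelCoordinatesPlot(metric_name="branin").compute(experiment=experiment)

    card.df    # one row per arm: arm_name, each parameter value, and the observed metric mean
    card.blob  # populated from the Plotly figure (compute() constructs the card with blob=fig)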
1 parent 6d5b07f · commit 193e6e6
Showing 3 changed files with 268 additions and 1 deletion.
ax/analysis/plotly/parallel_coordinates/parallel_coordinates.py (175 additions, 0 deletions)
@@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

import numpy as np
import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.core.experiment import Experiment
from ax.core.objective import MultiObjective, ScalarizedObjective
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.modelbridge.generation_strategy import GenerationStrategy
from plotly import graph_objects as go


class ParallelCoordinatesPlot(PlotlyAnalysis):
    """
    Plotly Parcoords plot for a single metric, with one line per arm and dimensions
    for each parameter in the search space. This plot is useful for understanding how
    thoroughly the search space is explored as well as for identifying whether there
    is any clustering of either good or bad parameterizations.

    The DataFrame computed will contain one row per arm and the following columns:
        - arm_name: The name of the arm
        - METRIC_NAME: The observed mean of the metric specified
        - **PARAMETER_NAME: The value of said parameter for the arm, for each parameter
    """

    def __init__(self, metric_name: Optional[str] = None) -> None:
        """
        Args:
            metric_name: The name of the metric to plot. If not specified the
                objective will be used. Note that the metric cannot be inferred for
                multi-objective or scalarized-objective experiments.
        """

        self.metric_name = metric_name

    def compute(
        self,
        experiment: Optional[Experiment] = None,
        generation_strategy: Optional[GenerationStrategy] = None,
    ) -> PlotlyAnalysisCard:
        if experiment is None:
            raise UserInputError("ParallelCoordinatesPlot requires an Experiment")

        metric_name = self.metric_name or _select_metric(experiment=experiment)

        df = _prepare_data(experiment=experiment, metric=metric_name)
        fig = _prepare_plot(df=df, metric_name=metric_name)

        return PlotlyAnalysisCard(
            name=self.__class__.__name__,
            title=f"Parallel Coordinates for {metric_name}",
            subtitle="View arm parameterizations with their respective metric values",
            level=AnalysisCardLevel.HIGH,
            df=df,
            blob=fig,
        )


def _prepare_data(experiment: Experiment, metric: str) -> pd.DataFrame:
    data_df = experiment.lookup_data().df
    filtered_df = data_df.loc[data_df["metric_name"] == metric]

    if filtered_df.empty:
        raise ValueError(f"No data found for metric {metric}")

    records = [
        {
            "arm_name": arm.name,
            **arm.parameters,
            metric: _find_mean_by_arm_name(df=filtered_df, arm_name=arm.name),
        }
        for trial in experiment.trials.values()
        for arm in trial.arms
    ]

    return pd.DataFrame.from_records(records)


def _prepare_plot(df: pd.DataFrame, metric_name: str) -> go.Figure:
    # Parcoords requires that the dimensions are specified on continuous scales, so
    # ChoiceParameters and FixedParameters must be preprocessed to allow for
    # appropriate plotting.
    parameter_dimensions = [
        _get_parameter_dimension(series=df[col])
        for col in df.columns
        if col != "arm_name" and col != metric_name
    ]

    return go.Figure(
        go.Parcoords(
            line={
                "color": df[metric_name],
                "showscale": True,
            },
            dimensions=[
                *parameter_dimensions,
                {
                    "label": _truncate_label(label=metric_name),
                    "values": df[metric_name].tolist(),
                },
            ],
        )
    )


def _select_metric(experiment: Experiment) -> str:
    if experiment.optimization_config is not None:
        objective = experiment.optimization_config.objective
        if isinstance(objective, MultiObjective):
            raise UnsupportedError(
                "Cannot infer metric to plot from MultiObjective, please "
                "specify a metric"
            )
        if isinstance(objective, ScalarizedObjective):
            raise UnsupportedError(
                "Cannot infer metric to plot from ScalarizedObjective, please "
                "specify a metric"
            )
        return experiment.optimization_config.objective.metric.name
    else:
        raise ValueError(
            "Cannot infer metric to plot from Experiment without OptimizationConfig"
        )


def _find_mean_by_arm_name(
    df: pd.DataFrame,
    arm_name: str,
) -> float:
    # Given a dataframe with arm_name and mean columns, find the mean for a given
    # arm_name. If an arm_name is not found (as can happen if the arm is still
    # running or has failed) return NaN.
    series = df.loc[df["arm_name"] == arm_name]["mean"]

    if series.empty:
        return np.nan

    return series.item()


def _get_parameter_dimension(series: pd.Series) -> Dict[str, Any]:
    # For numeric parameters allow Plotly to infer tick attributes. Note: booleans
    # are considered numeric, but in this case we want to treat them as categorical.
    if pd.api.types.is_numeric_dtype(series) and not pd.api.types.is_bool_dtype(series):
        return {
            "tickvals": None,
            "ticktext": None,
            "label": _truncate_label(label=str(series.name)),
            "values": series.tolist(),
        }

    # For non-numeric parameters, sort, map onto an integer scale, and provide
    # corresponding tick attributes.
    mapping = {v: k for k, v in enumerate(sorted(series.unique()))}

    return {
        "tickvals": [_truncate_label(label=str(val)) for val in mapping.values()],
        "ticktext": [_truncate_label(label=str(key)) for key in mapping.keys()],
        "label": _truncate_label(label=str(series.name)),
        "values": series.map(mapping).tolist(),
    }


def _truncate_label(label: str, n: int = 8) -> str:
    if len(label) > n:
        return label[:n] + "..."
    return label
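The preprocessing in _get_parameter_dimension (sort the non-numeric values, map them onto an integer scale, and supply matching tickvals/ticktext) is the usual Plotly pattern for showing categorical data on a Parcoords axis. A self-contained sketch of that pattern, independent of Ax and with invented values, assuming only the plotly package:

    import plotly.graph_objects as go

    values = ["bar", "foo", "baz", "foo"]  # hypothetical ChoiceParameter values, one per arm
    mapping = {v: i for i, v in enumerate(sorted(set(values)))}  # {"bar": 0, "baz": 1, "foo": 2}

    fig = go.Figure(
        go.Parcoords(
            dimensions=[
                {
                    "label": "choice",
                    "values": [mapping[v] for v in values],  # plotted on an integer scale
                    "tickvals": list(mapping.values()),      # tick positions 0, 1, 2
                    "ticktext": list(mapping.keys()),        # rendered as bar, baz, foo
                },
                # A made-up numeric column standing in for the metric dimension.
                {"label": "metric", "values": [0.1, 0.7, 0.4, 0.9]},
            ],
        )
    )
    fig.show()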
ax/analysis/plotly/parallel_coordinates/tests/test_parallel_coordinates.py (89 additions, 0 deletions)
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.plotly.parallel_coordinates.parallel_coordinates import (
    _get_parameter_dimension,
    _select_metric,
    ParallelCoordinatesPlot,
)
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
    get_branin_experiment,
    get_experiment_with_multi_objective,
    get_experiment_with_scalarized_objective_and_outcome_constraint,
)


class TestParallelCoordinatesPlot(TestCase):
    def test_compute(self) -> None:
        analysis = ParallelCoordinatesPlot("branin")
        experiment = get_branin_experiment(with_completed_trial=True)

        with self.assertRaisesRegex(UserInputError, "requires an Experiment"):
            analysis.compute()

        card = analysis.compute(experiment=experiment)
        self.assertEqual(card.name, "ParallelCoordinatesPlot")
        self.assertEqual(card.title, "Parallel Coordinates for branin")
        self.assertEqual(
            card.subtitle,
            "View arm parameterizations with their respective metric values",
        )
        self.assertEqual(card.level, AnalysisCardLevel.HIGH)
        self.assertEqual({*card.df.columns}, {"arm_name", "branin", "x1", "x2"})
        self.assertIsNotNone(card.blob)
        self.assertEqual(card.blob_annotation, "plotly")

        analysis_no_metric = ParallelCoordinatesPlot()
        _ = analysis_no_metric.compute(experiment=experiment)

    def test_select_metric(self) -> None:
        experiment = get_branin_experiment()
        experiment_no_optimization_config = get_branin_experiment(
            has_optimization_config=False
        )
        experiment_multi_objective = get_experiment_with_multi_objective()
        experiment_scalarized_objective = (
            get_experiment_with_scalarized_objective_and_outcome_constraint()
        )

        self.assertEqual(_select_metric(experiment=experiment), "branin")

        with self.assertRaisesRegex(ValueError, "OptimizationConfig"):
            _select_metric(experiment=experiment_no_optimization_config)

        with self.assertRaisesRegex(UnsupportedError, "MultiObjective"):
            _select_metric(experiment=experiment_multi_objective)

        with self.assertRaisesRegex(UnsupportedError, "ScalarizedObjective"):
            _select_metric(experiment=experiment_scalarized_objective)

    def test_get_parameter_dimension(self) -> None:
        range_series = pd.Series([0, 1, 2, 3], name="range")
        range_dimension = _get_parameter_dimension(series=range_series)
        self.assertEqual(
            range_dimension,
            {
                "tickvals": None,
                "ticktext": None,
                "label": "range",
                "values": range_series.tolist(),
            },
        )

        choice_series = pd.Series(["foo", "bar", "baz"], name="choice")
        choice_dimension = _get_parameter_dimension(series=choice_series)
        self.assertEqual(
            choice_dimension,
            {
                "tickvals": ["0", "1", "2"],
                "ticktext": ["bar", "baz", "foo"],
                "label": "choice",
                "values": [2, 0, 1],
            },
        )