Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove internal type shed for plotly #2715

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ax/analysis/old/helpers/cross_validation_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def diagonal_trace(min_: float, max_: float, visible: bool = True) -> dict[str,
max_: maximum to be used for ending point of line.
visible: if True, trace is set to visible.
"""
# pyre-fixme[7]: Expected `Dict[str, typing.Any]` but got `Scatter`.
return go.Scatter(
x=[min_, max_],
y=[min_, max_],
Expand Down
1 change: 1 addition & 0 deletions ax/analysis/old/helpers/layout_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,5 @@ def layout_format(
width=530,
height=500,
)
# pyre-fixme[7]: Expected `Type[Figure]` but got `Layout`.
return layout
2 changes: 2 additions & 0 deletions ax/analysis/old/helpers/scatter_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def error_dot_plot_trace_from_df(

trace.update(visible=visible)
trace.update(showlegend=False)
# pyre-fixme[7]: Expected `Dict[str, typing.Any]` but got `Scatter`.
return trace


Expand Down Expand Up @@ -309,4 +310,5 @@ def error_scatter_trace_from_df(

trace.update(visible=visible)
trace.update(showlegend=True)
# pyre-fixme[7]: Expected `Dict[str, typing.Any]` but got `Scatter`.
return trace
1 change: 0 additions & 1 deletion ax/analysis/old/tests/test_cross_validation_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def setUp(self) -> None:

def test_cross_validation_plot(self) -> None:
plot = CrossValidationPlot(experiment=self.exp, model=self.model).get_fig()
# pyre-ignore [16]
x_range = plot.layout.updatemenus[0].buttons[0].args[1]["xaxis.range"]
y_range = plot.layout.updatemenus[0].buttons[0].args[1]["yaxis.range"]
self.assertTrue((len(x_range) == 2) and (x_range[0] < x_range[1]))
Expand Down
1 change: 1 addition & 0 deletions ax/plot/bandit_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,5 @@ def plot_bandit_rollout(experiment: Experiment) -> AxPlotConfig:
del bandit["index"] # Have to delete index or figure creation causes error
fig = go.Figure(data=bandits, layout=layout)

# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
4 changes: 4 additions & 0 deletions ax/plot/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,8 @@ def plot_contour(
AxPlotConfig: contour plot of objective vs. parameter values
"""
return AxPlotConfig(
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got
# `Figure`.
data=plot_contour_plotly(
model=model,
param_x=param_x,
Expand Down Expand Up @@ -925,6 +927,8 @@ def interact_contour(
AxPlotConfig: interactive plot of objective vs. parameters
"""
return AxPlotConfig(
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got
# `Figure`.
data=interact_contour_plotly(
model=model,
metric_name=metric_name,
Expand Down
8 changes: 7 additions & 1 deletion ax/plot/diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def _diagonal_trace(min_: float, max_: float, visible: bool = True) -> dict[str,
visible: if True, trace is set to visible.

"""
# pyre-fixme[7]: Expected `Dict[str, typing.Any]` but got `Scatter`.
return go.Scatter(
x=[min_, max_],
y=[min_, max_],
Expand Down Expand Up @@ -479,6 +480,7 @@ def interact_empirical_model_validation(batch: BatchTrial, data: Data) -> AxPlot

fig = _obs_vs_pred_dropdown_plot(data=plot_data, rel=False)
fig["layout"]["title"] = "Cross-validation"
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)


Expand Down Expand Up @@ -541,6 +543,8 @@ def interact_cross_validation(
Returns an AxPlotConfig
"""
return AxPlotConfig(
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got
# `Figure`.
data=interact_cross_validation_plotly(
cv_results=cv_results,
show_context=show_context,
Expand Down Expand Up @@ -601,7 +605,7 @@ def tile_cross_validation(
y_raw.append(arm.y[metric])
se_raw.append(arm.se[metric])
min_, max_ = _get_min_max_with_errors(y_raw, y_hat, se_raw, se_hat)
fig.append_trace( # pyre-ignore[16]
fig.append_trace(
_diagonal_trace(min_, max_), int(np.floor(i / 2)) + 1, i % 2 + 1
)
fig.append_trace(
Expand Down Expand Up @@ -645,6 +649,7 @@ def tile_cross_validation(
title="Predicted Outcome", mirror=True, linecolor="black", linewidth=0.5
)

# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)


Expand Down Expand Up @@ -687,4 +692,5 @@ def interact_batch_comparison(
ylabel=y_label,
)
fig["layout"]["title"] = "Repeated arms across trials"
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
13 changes: 11 additions & 2 deletions ax/plot/feature_importances.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def plot_feature_importance_plotly(df: pd.DataFrame, title: str) -> go.Figure:
)

for idx, item in enumerate(data):
fig.append_trace(item, idx + 1, 1) # pyre-ignore[16]
fig.append_trace(item, idx + 1, 1)
fig.layout.showlegend = False
fig.layout.margin = go.layout.Margin(
l=8 * min(max(len(idx) for idx in df.index), 75) # noqa E741
Expand All @@ -53,7 +53,10 @@ def plot_feature_importance(df: pd.DataFrame, title: str) -> AxPlotConfig:
"""Wrapper method to convert plot_feature_importance_plotly to
AxPlotConfig"""
return AxPlotConfig(
data=plot_feature_importance_plotly(df, title), plot_type=AxPlotTypes.GENERIC
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got
# `Figure`.
data=plot_feature_importance_plotly(df, title),
plot_type=AxPlotTypes.GENERIC,
)


Expand Down Expand Up @@ -90,6 +93,8 @@ def plot_feature_importance_by_metric(model: ModelBridge) -> AxPlotConfig:
"""Wrapper method to convert plot_feature_importance_by_metric_plotly to
AxPlotConfig"""
return AxPlotConfig(
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got
# `Figure`.
data=plot_feature_importance_by_metric_plotly(model),
plot_type=AxPlotTypes.GENERIC,
)
Expand Down Expand Up @@ -282,6 +287,8 @@ def plot_feature_importance_by_feature(
"""Wrapper method to convert `plot_feature_importance_by_feature_plotly` to
AxPlotConfig"""
return AxPlotConfig(
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got
# `Figure`.
data=plot_feature_importance_by_feature_plotly(
model=model,
sensitivity_values=sensitivity_values,
Expand Down Expand Up @@ -328,6 +335,8 @@ def plot_relative_feature_importance(model: ModelBridge) -> AxPlotConfig:
"""Wrapper method to convert plot_relative_feature_importance_plotly to
AxPlotConfig"""
return AxPlotConfig(
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got
# `Figure`.
data=plot_relative_feature_importance_plotly(model),
plot_type=AxPlotTypes.GENERIC,
)
3 changes: 2 additions & 1 deletion ax/plot/marginal_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,13 @@ def plot_marginal_effects(model: ModelBridge, metric: str) -> AxPlotConfig:
shared_yaxes=True,
)
for idx, item in enumerate(data):
fig.append_trace(item, 1, idx + 1) # pyre-ignore[16]
fig.append_trace(item, 1, idx + 1)
fig.layout.showlegend = False
# fig.layout.margin = go.layout.Margin(l=2, r=2)
fig.layout.title = "Marginal Effects by Factor"
fig.layout.yaxis = {
"title": "% higher than experiment average",
"hoverformat": ".{}f".format(DECIMALS),
}
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
1 change: 1 addition & 0 deletions ax/plot/parallel_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,5 @@ def plot_parallel_coordinates(
experiment=experiment, ignored_names=ignored_names
)

# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
6 changes: 6 additions & 0 deletions ax/plot/pareto_frontier.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ def scatter_plot_with_pareto_frontier(
minimize: bool = True,
) -> AxPlotConfig:
return AxPlotConfig(
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got
# `Figure`.
data=scatter_plot_with_pareto_frontier_plotly(
Y=Y,
Y_pareto=Y_pareto,
Expand Down Expand Up @@ -411,6 +413,7 @@ def plot_pareto_frontier(
)

fig = go.Figure(data=[trace], layout=layout)
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)


Expand Down Expand Up @@ -519,6 +522,7 @@ def plot_multiple_pareto_frontiers(
)

fig = go.Figure(data=traces, layout=layout)
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)


Expand Down Expand Up @@ -625,6 +629,7 @@ def interact_pareto_frontier(
)

fig = go.Figure(data=traces, layout=layout)
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)


Expand Down Expand Up @@ -767,6 +772,7 @@ def interact_multiple_pareto_frontier(
)

fig = go.Figure(data=traces, layout=layout)
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)


Expand Down
14 changes: 10 additions & 4 deletions ax/plot/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ def _error_scatter_trace(
trace.update(legendgroup=legendgroup)
if showlegend is not None:
trace.update(showlegend=showlegend)
# pyre-fixme[7]: Expected `Dict[str, typing.Any]` but got `Scatter`.
return trace


Expand Down Expand Up @@ -535,6 +536,7 @@ def plot_multiple_metrics(
legend={"x": 1 + layout_offset_x},
)
fig = go.Figure(data=traces, layout=layout)
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)


Expand Down Expand Up @@ -772,6 +774,7 @@ def plot_objective_vs_constraints(
)

fig = go.Figure(data=plot_data, layout=layout)
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)


Expand Down Expand Up @@ -867,7 +870,7 @@ def lattice_multiple_metrics(
visible=True,
show_arm_details_on_hover=show_arm_details_on_hover,
)
fig.append_trace(obs_insample_trace, j, i) # pyre-ignore[16]
fig.append_trace(obs_insample_trace, j, i)
fig.append_trace(predicted_insample_trace, j, i)

# iterate over models here
Expand Down Expand Up @@ -1068,6 +1071,7 @@ def lattice_multiple_metrics(
for xaxis in boxplot_xaxes:
fig["layout"][xaxis]["showticklabels"] = False

# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)


Expand Down Expand Up @@ -1291,6 +1295,7 @@ def plot_fitted(
)

fig = go.Figure(data=traces, layout=layout)
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)


Expand Down Expand Up @@ -1403,9 +1408,7 @@ def tile_fitted(
"zerolinecolor": "red",
}
for d in data:
fig.append_trace( # pyre-ignore[16]
d, int(np.floor(i / ncols)) + 1, i % ncols + 1
)
fig.append_trace(d, int(np.floor(i / ncols)) + 1, i % ncols + 1)

order_options = [
{"args": [name_order_args], "label": "Name", "method": "relayout"},
Expand Down Expand Up @@ -1469,6 +1472,7 @@ def tile_fitted(
},
)

# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
fig = resize_subtitles(figure=fig, size=10)
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)

Expand Down Expand Up @@ -1660,6 +1664,8 @@ def interact_fitted(
"""

return AxPlotConfig(
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got
# `Figure`.
data=interact_fitted_plotly(
model=model,
generator_runs_dict=generator_runs_dict,
Expand Down
4 changes: 4 additions & 0 deletions ax/plot/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,8 @@ def plot_slice(
AxPlotConfig: plot of objective vs. parameter value
"""
return AxPlotConfig(
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got
# `Figure`.
data=plot_slice_plotly(
model=model,
param_name=param_name,
Expand Down Expand Up @@ -570,6 +572,8 @@ def interact_slice(
"""

return AxPlotConfig(
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got
# `Figure`.
data=interact_slice_plotly(
model=model,
generator_runs_dict=generator_runs_dict,
Expand Down
1 change: 1 addition & 0 deletions ax/plot/table_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,5 @@ def transpose(m):
)
fig = go.Figure(data=[trace], layout=layout)
# pyre-fixme[7]: Expected `Tuple[DataFrame]` but got `AxPlotConfig`.
# pyre-fixme[6]: For 1st argument expected `Dict[str, typing.Any]` but got `Figure`.
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
2 changes: 0 additions & 2 deletions ax/plot/tests/test_contours.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,4 @@ def test_Contours(self) -> None:
plot = interact_contour_plotly(
model, list(model.metric_names)[0], parameters_to_use=parameters_to_use
)
# pyre-ignore[16]: `plotly.graph_objs.graph_objs.Figure`
# has no attribute `layout`.
self.assertEqual(len(plot.layout.updatemenus[0].buttons), i)
1 change: 0 additions & 1 deletion ax/plot/tests/test_diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def test_cross_validation(self) -> None:
plot = interact_cross_validation_plotly(
cv, label_dict=label_dict, autoset_axis_limits=autoset_axis_limits
)
# pyre-ignore [16]
x_range = plot.layout.updatemenus[0].buttons[0].args[1]["xaxis.range"]
y_range = plot.layout.updatemenus[0].buttons[0].args[1]["yaxis.range"]
if autoset_axis_limits:
Expand Down
1 change: 0 additions & 1 deletion ax/plot/tests/test_feature_importances.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def test_FeatureImportances(self) -> None:
model=model, caption=DUMMY_CAPTION
)
self.assertIsInstance(plot, go.Figure)
# pyre-fixme[16]: `Figure` has no attribute `layout`.
self.assertEqual(len(plot.layout.annotations), 1)
self.assertEqual(plot.layout.annotations[0].text, DUMMY_CAPTION)
plot = plot_feature_importance_by_feature(model=model)
Expand Down
2 changes: 1 addition & 1 deletion ax/plot/tests/test_traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_TracesAutoAxes(self) -> None:
optimization_direction=optimization_direction,
autoset_axis_limits=True,
)
self.assertIsNone(plot.layout.xaxis.range) # pyre-ignore
self.assertIsNone(plot.layout.xaxis.range)
if optimization_direction == "minimize":
self.assertAlmostEqual(plot.layout.yaxis.range[0], 0.525)
self.assertAlmostEqual(plot.layout.yaxis.range[1], 6.225)
Expand Down
Loading