
Commit

classify/curves_plotly.py show threshold on hover, add _add_no_skill_line helper

no-skill line no longer included in legend by default but annotated in the plot

print time taken in errors/warnings in test_import_time
janosh committed Nov 27, 2024
1 parent ee40029 commit 77e765a
Showing 4 changed files with 150 additions and 78 deletions.
2 changes: 1 addition & 1 deletion assets/svg/precision-recall-curve-plotly-multiple.svg
(SVG diff not rendered)
170 changes: 119 additions & 51 deletions pymatviz/classify/curves_plotly.py
@@ -1,8 +1,9 @@
"""Plotly-based classification metrics visualization."""

from typing import Any, TypeAlias
from typing import Any, Literal, TypeAlias

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import sklearn.metrics as skm
from numpy.typing import ArrayLike
@@ -18,14 +19,23 @@ def _standardize_input(
targets: ArrayLike | str,
probs_positive: Predictions,
df: Any = None,
*,
strict: bool = False,
) -> tuple[ArrayLike, dict[str, dict[str, Any]]]:
"""Standardize input into tuple of (targets, {name: {probs_positive,
**trace_kwargs}}).
Handles three input formats for probs_positive:
1. Basic: array of probabilities
2. dict of arrays: {"name": probabilities}
3. dict of dicts: {"name": {"probs_positive": np.array, **trace_kwargs}}
Args:
targets: Ground truth binary labels
probs_positive: Either:
- Predicted probabilities for positive class, or
- dict of form {"name": probabilities}, or
- dict of form {"name": {"probs_positive": np.array, **trace_kwargs}}
df: Optional DataFrame containing targets and probs_positive columns
strict: If True, check that probabilities are in [0, 1].
Returns:
tuple[ArrayLike, dict[str, dict[str, Any]]]: targets, curves_dict
"""
if df is not None:
if not isinstance(targets, str):
@@ -50,32 +60,90 @@ def _standardize_input(
else:
curves_dict = {"": {"probs_positive": probs_positive}}

for trace_dict in curves_dict.values():
curve_probs = np.asarray(trace_dict["probs_positive"])
min_prob, max_prob = curve_probs.min(), curve_probs.max()
if not (0 <= min_prob <= max_prob <= 1):
raise ValueError(
f"Probabilities must be in [0, 1], got range {(min_prob, max_prob)}"
)
if strict:
for trace_dict in curves_dict.values():
curve_probs = np.asarray(trace_dict["probs_positive"])
curve_probs_no_nan = curve_probs[~np.isnan(curve_probs)]
min_prob, max_prob = curve_probs_no_nan.min(), curve_probs_no_nan.max()
if not (0 <= min_prob <= max_prob <= 1):
raise ValueError(
f"Probabilities must be in [0, 1], got range {(min_prob, max_prob)}"
)

return targets, curves_dict
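
For illustration, a minimal sketch of the three accepted probs_positive formats (the arrays and model names here are hypothetical, not from this commit):

import numpy as np

y_true = np.array([0, 1, 1, 0])
probs_a = np.array([0.1, 0.8, 0.7, 0.3])
probs_b = np.array([0.2, 0.9, 0.6, 0.4])

# 1. bare array of positive-class probabilities
_standardize_input(y_true, probs_a)
# 2. dict of arrays keyed by model name
_standardize_input(y_true, {"model A": probs_a, "model B": probs_b})
# 3. dict of dicts; extra keys become trace kwargs for fig.add_scatter()
_standardize_input(
    y_true, {"model A": {"probs_positive": probs_a, "line": dict(color="red")}}
)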


def _add_no_skill_line(
fig: go.Figure, y_values: ArrayLike, scatter_kwargs: dict[str, Any] | Literal[False] | None = None
) -> None:
"""Add no-skill baseline line to figure.
Args:
fig (go.Figure): Plotly figure to add line to
y_values (ArrayLike): Y-values for no-skill line (constant or linear)
scatter_kwargs (dict[str, Any] | Literal[False] | None): Options for the
no-skill baseline, or False to skip drawing it. Commonly needed keys:
- showlegend: bool = False
- annotation: dict | None = {} (plotly annotation dict to label the line;
pass None to omit the label)
All other keys are passed to fig.add_scatter()
"""
if scatter_kwargs is False:
return

scatter_kwargs = scatter_kwargs or {}
annotation = scatter_kwargs.pop("annotation", {})

no_skill_line = dict(color="gray", width=1, dash="dash")
no_skill_defaults = dict(
x=np.linspace(0, 1, 100),
y=y_values,
name="No skill",
line=no_skill_line,
showlegend=False,
hovertemplate=(
"<b>No skill</b><br>"
f"{fig.layout.xaxis.title.text}: %{{x:.3f}}<br>"
f"{fig.layout.yaxis.title.text}: %{{y:.3f}}<br>"
"<extra></extra>"
),
)
fig.add_scatter(**no_skill_defaults | scatter_kwargs)

if annotation is not None:
anno_defaults = dict(
x=0.5,
y=0.5,
text="No skill",
showarrow=False,
font=dict(color="gray"),
yshift=10,
)
fig.add_annotation(anno_defaults | annotation)
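
A short sketch of calling the helper directly; the figure's axis titles must be set beforehand since the hovertemplate reads them (the options below are hypothetical):

fig = go.Figure(layout=dict(xaxis_title="Recall", yaxis_title="Precision"))
_add_no_skill_line(
    fig,
    y_values=np.full(100, 0.5),  # constant 0.5 baseline, as for precision-recall
    scatter_kwargs=dict(showlegend=True, annotation=dict(textangle=-30)),
)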


def roc_curve_plotly(
targets: ArrayLike | str,
probs_positive: Predictions,
df: Any = None,
df: pd.DataFrame | None = None,
*,
no_skill: dict[str, Any] | Literal[False] | None = None,
**kwargs: Any,
) -> go.Figure:
"""Plot the receiver operating characteristic (ROC) curve using Plotly.
Args:
targets: Ground truth binary labels
probs_positive: Either:
targets (ArrayLike | str): Ground truth binary labels
probs_positive (Predictions): Either:
- Predicted probabilities for positive class, or
- dict of form {"name": probabilities}, or
- dict of form {"name": {"probs_positive": np.array, **trace_kwargs}}
df: Optional DataFrame containing targets and probs_positive columns
df (pd.DataFrame | None): Optional DataFrame containing targets and
probs_positive columns
no_skill (dict[str, Any] | Literal[False] | None): Options for the no-skill
baseline or False to hide it. Commonly needed keys:
- showlegend: bool = False
- annotation: dict | None = {} (plotly annotation dict to label the line;
pass None to omit the label)
All other keys are passed to fig.add_scatter()
**kwargs: Additional keywords passed to fig.add_scatter()
Returns:
@@ -90,7 +158,7 @@ def roc_curve_plotly(
curve_probs = np.asarray(trace_kwargs.pop("probs_positive"))

no_nan = ~np.isnan(targets) & ~np.isnan(curve_probs)
fpr, tpr, _ = skm.roc_curve(targets[no_nan], curve_probs[no_nan])
fpr, tpr, thresholds = skm.roc_curve(targets[no_nan], curve_probs[no_nan])
roc_auc = skm.roc_auc_score(targets[no_nan], curve_probs[no_nan])

roc_str = f"AUC={roc_auc:.2f}"
Expand All @@ -106,8 +174,10 @@ def roc_curve_plotly(
f"<b>{display_name}</b><br>"
"FPR: %{x:.3f}<br>"
"TPR: %{y:.3f}<br>"
"Threshold: %{customdata.threshold:.3f}<br>"
"<extra></extra>"
),
"customdata": [dict(threshold=thr) for thr in thresholds],
"meta": dict(roc_auc=roc_auc),
}
fig.add_scatter(**trace_defaults | kwargs | trace_kwargs)
@@ -116,18 +186,10 @@ def roc_curve_plotly(
fig.data = sorted(fig.data, key=lambda tr: tr.meta.get("roc_auc"), reverse=True)

# Random baseline (has 100 points so whole line is hoverable, not just end points)
rand_baseline = dict(color="gray", width=2, dash="dash")
fig.add_scatter(
x=np.linspace(0, 1, 100),
y=np.linspace(0, 1, 100),
name="Random",
line=rand_baseline,
hovertemplate=(
"<b>Random</b><br>"
"FPR: %{x:.3f}<br>"
"TPR: %{y:.3f}<br>"
"<extra></extra>"
),
_add_no_skill_line(
fig,
y_values=np.linspace(0, 1, 100),
scatter_kwargs=dict(annotation=dict(textangle=0)) | (no_skill or {}),
)

fig.layout.legend.update(yanchor="bottom", y=0, xanchor="right", x=0.99)
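
A usage sketch for the updated ROC plot (hypothetical data); hovering a curve now also reports the decision threshold stored in customdata:

import numpy as np
from pymatviz.classify.curves_plotly import roc_curve_plotly

y_true = np.array([0, 1, 1, 0, 1, 0])
probs = np.array([0.2, 0.9, 0.6, 0.4, 0.7, 0.1])

# the no-skill diagonal is hidden from the legend by default; opt back in
fig = roc_curve_plotly(y_true, {"model A": probs}, no_skill=dict(showlegend=True))
fig.show()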
@@ -142,18 +204,26 @@ def roc_curve_plotly(
def precision_recall_curve_plotly(
targets: ArrayLike | str,
probs_positive: Predictions,
df: Any = None,
df: pd.DataFrame | None = None,
*,
no_skill: dict[str, Any] | Literal[False] | None = None,
**kwargs: Any,
) -> go.Figure:
"""Plot the precision-recall curve using Plotly.
Args:
targets: Ground truth binary labels
probs_positive: Either:
targets (ArrayLike | str): Ground truth binary labels
probs_positive (Predictions): Either:
- Predicted probabilities for positive class, or
- dict of form {"name": probabilities}, or
- dict of form {"name": {"probs_positive": np.array, **trace_kwargs}}
df: Optional DataFrame containing targets and probs_positive columns
df (pd.DataFrame | None): Optional DataFrame containing targets and
probs_positive columns
no_skill (dict[str, Any] | Literal[False] | None): Options for the no-skill
baseline or False to hide it. Commonly needed keys:
- showlegend: bool = False
- annotation: dict | None = {} (plotly annotation dict to label the line)
All other keys are passed to fig.add_scatter()
**kwargs: Additional keywords passed to fig.add_scatter()
Returns:
@@ -166,18 +236,23 @@ def precision_recall_curve_plotly(
for idx, (name, trace_kwargs) in enumerate(curves_dict.items()):
# Extract required data and optional trace kwargs
curve_probs = np.asarray(trace_kwargs.pop("probs_positive"))

no_nan = ~np.isnan(targets) & ~np.isnan(curve_probs)
precision, recall, _ = skm.precision_recall_curve(
prec_curve, recall_curve, thresholds = skm.precision_recall_curve(
targets[no_nan], curve_probs[no_nan]
)
# f1 scores for each threshold
f1_curve = 2 * (prec_curve * recall_curve) / (prec_curve + recall_curve)
f1_curve = np.nan_to_num(f1_curve) # Handle division by zero
f1_score = skm.f1_score(targets[no_nan], np.round(curve_probs[no_nan]))

# append a final threshold of 1.0 since thresholds has one fewer element than precision/recall
thresholds = [*thresholds, 1.0]

metrics_str = f"F1={f1_score:.2f}"
display_name = f"{name} · {metrics_str}" if name else metrics_str
trace_defaults = {
"x": recall,
"y": precision,
"x": recall_curve,
"y": prec_curve,
"name": display_name,
"line": dict(
width=2, dash=PLOTLY_LINE_STYLES[idx % len(PLOTLY_LINE_STYLES)]
@@ -186,9 +261,14 @@ def precision_recall_curve_plotly(
f"<b>{display_name}</b><br>"
"Recall: %{x:.3f}<br>"
"Prec: %{y:.3f}<br>"
"F1: {f1_score:.3f}<br>"
"F1: %{customdata.f1:.3f}<br>"
"Threshold: %{customdata.threshold:.3f}<br>"
"<extra></extra>"
),
"customdata": [
dict(threshold=thr, f1=f1)
for thr, f1 in zip(thresholds, f1_curve, strict=True)
],
"meta": dict(f1_score=f1_score),
}
fig.add_scatter(**trace_defaults | kwargs | trace_kwargs)
@@ -197,19 +277,7 @@ def precision_recall_curve_plotly(
fig.data = sorted(fig.data, key=lambda tr: tr.meta.get("f1_score"), reverse=True)

# No-skill baseline (has 100 points so whole line is hoverable, not just end points)
no_skill_line = dict(color="gray", width=2, dash="dash")
fig.add_scatter(
x=np.linspace(0, 1, 100),
y=np.full_like(np.linspace(0, 1, 100), 0.5),
name="No skill",
line=no_skill_line,
hovertemplate=(
"<b>No skill</b><br>"
"Recall: %{x:.3f}<br>"
"Prec: %{y:.3f}<br>"
"<extra></extra>"
),
)
_add_no_skill_line(fig, y_values=np.full(100, 0.5), scatter_kwargs=no_skill)

fig.layout.legend.update(yanchor="bottom", y=0, xanchor="left", x=0)
fig.layout.update(xaxis_title="Recall", yaxis_title="Precision")
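
And a matching sketch for the precision-recall plot, reusing the hypothetical arrays from the ROC example above; passing no_skill=False would drop the 0.5 baseline entirely:

from pymatviz.classify.curves_plotly import precision_recall_curve_plotly

fig = precision_recall_curve_plotly(
    y_true, {"model A": probs}, no_skill=dict(annotation=dict(text="random"))
)
fig.show()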
16 changes: 8 additions & 8 deletions tests/.pytest-split-durations
@@ -449,14 +449,14 @@
"tests/test_rdf.py::test_element_pair_rdfs_reference_line": 0.019186166988220066,
"tests/test_rdf.py::test_element_pair_rdfs_subplot_layout": 0.013891042035538703,
"tests/test_readme.py::test_no_missing_images": 0.0013329579960554838,
"tests/test_relevance.py::test_precision_recall_curve[None-y_binary0-y_proba0-None]": 0.013538249972043559,
"tests/test_relevance.py::test_precision_recall_curve[None-y_binary0-y_proba0-ax1]": 0.0019967920379713178,
"tests/test_relevance.py::test_precision_recall_curve[df1-y_binary-y_proba-None]": 0.010434875992359594,
"tests/test_relevance.py::test_precision_recall_curve[df1-y_binary-y_proba-ax1]": 0.0031556260073557496,
"tests/test_relevance.py::test_roc_curve[None-y_binary0-y_proba0-None]": 0.012826499965740368,
"tests/test_relevance.py::test_roc_curve[None-y_binary0-y_proba0-ax1]": 0.0016057499451562762,
"tests/test_relevance.py::test_roc_curve[df1-y_binary-y_proba-None]": 0.011317623982904479,
"tests/test_relevance.py::test_roc_curve[df1-y_binary-y_proba-ax1]": 0.0020046669815201312,
"tests/classify/test_curves_matplotlib.py::test_precision_recall_curve[None-y_binary0-y_proba0-None]": 0.013538249972043559,
"tests/classify/test_curves_matplotlib.py::test_precision_recall_curve[None-y_binary0-y_proba0-ax1]": 0.0019967920379713178,
"tests/classify/test_curves_matplotlib.py::test_precision_recall_curve[df1-y_binary-y_proba-None]": 0.010434875992359594,
"tests/classify/test_curves_matplotlib.py::test_precision_recall_curve[df1-y_binary-y_proba-ax1]": 0.0031556260073557496,
"tests/classify/test_curves_matplotlib.py::test_roc_curve[None-y_binary0-y_proba0-None]": 0.012826499965740368,
"tests/classify/test_curves_matplotlib.py::test_roc_curve[None-y_binary0-y_proba0-ax1]": 0.0016057499451562762,
"tests/classify/test_curves_matplotlib.py::test_roc_curve[df1-y_binary-y_proba-None]": 0.011317623982904479,
"tests/classify/test_curves_matplotlib.py::test_roc_curve[df1-y_binary-y_proba-ax1]": 0.0020046669815201312,
"tests/test_sankey.py::test_sankey_from_2_df_cols[False]": 0.0017317909805569798,
"tests/test_sankey.py::test_sankey_from_2_df_cols[True]": 0.008406835026107728,
"tests/test_sankey.py::test_sankey_from_2_df_cols[percent]": 0.001626958983251825,
40 changes: 22 additions & 18 deletions tests/performance_tests/test_import_time.py
@@ -18,22 +18,22 @@

# Last update: 2024-10-23
REF_IMPORT_TIME: dict[str, float] = {
"pymatviz": 2084.25,
"pymatviz.coordination": 2342.41,
"pymatviz.cumulative": 2299.73,
"pymatviz.histogram": 2443.11,
"pymatviz.phonons": 2235.57,
"pymatviz.powerups": 2172.71,
"pymatviz.ptable": 2286.77,
"pymatviz.rainclouds": 2702.03,
"pymatviz.rdf": 2331.98,
"pymatviz.relevance": 2256.29,
"pymatviz.sankey": 2313.12,
"pymatviz.scatter": 2312.48,
"pymatviz.structure_viz": 2330.39,
"pymatviz.sunburst": 2395.04,
"pymatviz.uncertainty": 2317.87,
"pymatviz.xrd": 2242.09,
"pymatviz": 2084,
"pymatviz.coordination": 2342,
"pymatviz.cumulative": 2299,
"pymatviz.histogram": 2443,
"pymatviz.phonons": 2235,
"pymatviz.powerups": 2172,
"pymatviz.ptable": 2286,
"pymatviz.rainclouds": 2702,
"pymatviz.rdf": 2331,
"pymatviz.classify": 2256,
"pymatviz.sankey": 2313,
"pymatviz.scatter": 2312,
"pymatviz.structure_viz": 2330,
"pymatviz.sunburst": 2395,
"pymatviz.uncertainty": 2317,
"pymatviz.xrd": 2242,
}


@@ -96,9 +96,13 @@ def test_import_time(grace_percent: float = 0.20, hard_percent: float = 0.50) ->

if current_time > grace_threshold:
if current_time > hard_threshold:
pytest.fail(f"{module_name} import too slow! {hard_threshold=:.2f} ms")
pytest.fail(
f"{module_name} import too slow! took {current_time:.0f} ms, "
f"{hard_threshold=:.0f} ms"
)
else:
warnings.warn(
f"{module_name} import slightly slower: {grace_threshold=:.2f} ms",
f"{module_name} import slightly slower: took {current_time:.0f} "
f"ms, {grace_threshold=:.0f} ms",
stacklevel=2,
)

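For context, the grace and hard thresholds compared against above are presumably computed from the reference times roughly like this (a sketch; the actual lines sit outside the shown hunk):

ref_time = REF_IMPORT_TIME[module_name]
grace_threshold = ref_time * (1 + grace_percent)  # warn above +20%
hard_threshold = ref_time * (1 + hard_percent)  # fail above +50%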