Skip to content

Commit 77e765a

Browse files
committed
classify/curves_plotly.py show threshold on hover, add _add_no_skill_line helper
no skill line no longer included in legend by default but annotated in plot print time taken in errors/warnings in test_import_time
1 parent ee40029 commit 77e765a

File tree

4 files changed

+150
-78
lines changed

4 files changed

+150
-78
lines changed

assets/svg/precision-recall-curve-plotly-multiple.svg

+1-1
Loading

pymatviz/classify/curves_plotly.py

+119-51
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Plotly-based classification metrics visualization."""
22

3-
from typing import Any, TypeAlias
3+
from typing import Any, Literal, TypeAlias
44

55
import numpy as np
6+
import pandas as pd
67
import plotly.graph_objects as go
78
import sklearn.metrics as skm
89
from numpy.typing import ArrayLike
@@ -18,14 +19,23 @@ def _standardize_input(
1819
targets: ArrayLike | str,
1920
probs_positive: Predictions,
2021
df: Any = None,
22+
*,
23+
strict: bool = False,
2124
) -> tuple[ArrayLike, dict[str, dict[str, Any]]]:
2225
"""Standardize input into tuple of (targets, {name: {probs_positive,
2326
**trace_kwargs}}).
2427
25-
Handles three input formats for probs_positive:
26-
1. Basic: array of probabilities
27-
2. dict of arrays: {"name": probabilities}
28-
3. dict of dicts: {"name": {"probs_positive": np.array, **trace_kwargs}}
28+
Args:
29+
targets: Ground truth binary labels
30+
probs_positive: Either:
31+
- Predicted probabilities for positive class, or
32+
- dict of form {"name": probabilities}, or
33+
- dict of form {"name": {"probs_positive": np.array, **trace_kwargs}}
34+
df: Optional DataFrame containing targets and probs_positive columns
35+
strict: If True, check that probabilities are in [0, 1].
36+
37+
Returns:
38+
tuple[ArrayLike, dict[str, dict[str, Any]]]: targets, curves_dict
2939
"""
3040
if df is not None:
3141
if not isinstance(targets, str):
@@ -50,32 +60,90 @@ def _standardize_input(
5060
else:
5161
curves_dict = {"": {"probs_positive": probs_positive}}
5262

53-
for trace_dict in curves_dict.values():
54-
curve_probs = np.asarray(trace_dict["probs_positive"])
55-
min_prob, max_prob = curve_probs.min(), curve_probs.max()
56-
if not (0 <= min_prob <= max_prob <= 1):
57-
raise ValueError(
58-
f"Probabilities must be in [0, 1], got range {(min_prob, max_prob)}"
59-
)
63+
if strict:
64+
for trace_dict in curves_dict.values():
65+
curve_probs = np.asarray(trace_dict["probs_positive"])
66+
curve_probs_no_nan = curve_probs[~np.isnan(curve_probs)]
67+
min_prob, max_prob = curve_probs_no_nan.min(), curve_probs_no_nan.max()
68+
if not (0 <= min_prob <= max_prob <= 1):
69+
raise ValueError(
70+
f"Probabilities must be in [0, 1], got range {(min_prob, max_prob)}"
71+
)
6072

6173
return targets, curves_dict
6274

6375

76+
def _add_no_skill_line(
77+
fig: go.Figure, y_values: ArrayLike, scatter_kwargs: dict[str, Any] | None = None
78+
) -> None:
79+
"""Add no-skill baseline line to figure.
80+
81+
Args:
82+
fig (go.Figure): Plotly figure to add line to
83+
y_values (ArrayLike): Y-values for no-skill line (constant or linear)
84+
scatter_kwargs (dict[str, Any] | None): Options for no-skill baseline.
85+
Commonly needed keys:
86+
- show_legend: bool = True
87+
- annotation: dict = None (plotly annotation dict to label the line)
88+
All other keys are passed to fig.add_scatter()
89+
"""
90+
if scatter_kwargs is False:
91+
return
92+
93+
scatter_kwargs = scatter_kwargs or {}
94+
annotation = scatter_kwargs.pop("annotation", {})
95+
96+
no_skill_line = dict(color="gray", width=1, dash="dash")
97+
no_skill_defaults = dict(
98+
x=np.linspace(0, 1, 100),
99+
y=y_values,
100+
name="No skill",
101+
line=no_skill_line,
102+
showlegend=False,
103+
hovertemplate=(
104+
"<b>No skill</b><br>"
105+
f"{fig.layout.xaxis.title.text}: %{{x:.3f}}<br>"
106+
f"{fig.layout.yaxis.title.text}: %{{y:.3f}}<br>"
107+
"<extra></extra>"
108+
),
109+
)
110+
fig.add_scatter(**no_skill_defaults | scatter_kwargs)
111+
112+
if annotation is not None:
113+
anno_defaults = dict(
114+
x=0.5,
115+
y=0.5,
116+
text="No skill",
117+
showarrow=False,
118+
font=dict(color="gray"),
119+
yshift=10,
120+
)
121+
fig.add_annotation(anno_defaults | annotation)
122+
123+
64124
def roc_curve_plotly(
65125
targets: ArrayLike | str,
66126
probs_positive: Predictions,
67-
df: Any = None,
127+
df: pd.DataFrame | None = None,
128+
*,
129+
no_skill: dict[str, Any] | Literal[False] | None = None,
68130
**kwargs: Any,
69131
) -> go.Figure:
70132
"""Plot the receiver operating characteristic (ROC) curve using Plotly.
71133
72134
Args:
73-
targets: Ground truth binary labels
74-
probs_positive: Either:
135+
targets (ArrayLike | str): Ground truth binary labels
136+
probs_positive (Predictions): Either:
75137
- Predicted probabilities for positive class, or
76138
- dict of form {"name": probabilities}, or
77139
- dict of form {"name": {"probs_positive": np.array, **trace_kwargs}}
78-
df: Optional DataFrame containing targets and probs_positive columns
140+
df (pd.DataFrame | None): Optional DataFrame containing targets and
141+
probs_positive columns
142+
no_skill (dict[str, Any] | False): Options for no-skill baseline
143+
or False to hide it. Commonly needed keys:
144+
- show_legend: bool = True
145+
- annotation: dict = None (plotly annotation dict to label the line)
146+
All other keys are passed to fig.add_scatter()
79147
**kwargs: Additional keywords passed to fig.add_scatter()
80148
81149
Returns:
@@ -90,7 +158,7 @@ def roc_curve_plotly(
90158
curve_probs = np.asarray(trace_kwargs.pop("probs_positive"))
91159

92160
no_nan = ~np.isnan(targets) & ~np.isnan(curve_probs)
93-
fpr, tpr, _ = skm.roc_curve(targets[no_nan], curve_probs[no_nan])
161+
fpr, tpr, thresholds = skm.roc_curve(targets[no_nan], curve_probs[no_nan])
94162
roc_auc = skm.roc_auc_score(targets[no_nan], curve_probs[no_nan])
95163

96164
roc_str = f"AUC={roc_auc:.2f}"
@@ -106,8 +174,10 @@ def roc_curve_plotly(
106174
f"<b>{display_name}</b><br>"
107175
"FPR: %{x:.3f}<br>"
108176
"TPR: %{y:.3f}<br>"
177+
"Threshold: %{customdata.threshold:.3f}<br>"
109178
"<extra></extra>"
110179
),
180+
"customdata": [dict(threshold=thr) for thr in thresholds],
111181
"meta": dict(roc_auc=roc_auc),
112182
}
113183
fig.add_scatter(**trace_defaults | kwargs | trace_kwargs)
@@ -116,18 +186,10 @@ def roc_curve_plotly(
116186
fig.data = sorted(fig.data, key=lambda tr: tr.meta.get("roc_auc"), reverse=True)
117187

118188
# Random baseline (has 100 points so whole line is hoverable, not just end points)
119-
rand_baseline = dict(color="gray", width=2, dash="dash")
120-
fig.add_scatter(
121-
x=np.linspace(0, 1, 100),
122-
y=np.linspace(0, 1, 100),
123-
name="Random",
124-
line=rand_baseline,
125-
hovertemplate=(
126-
"<b>Random</b><br>"
127-
"FPR: %{x:.3f}<br>"
128-
"TPR: %{y:.3f}<br>"
129-
"<extra></extra>"
130-
),
189+
_add_no_skill_line(
190+
fig,
191+
y_values=np.linspace(0, 1, 100),
192+
scatter_kwargs=dict(annotation=dict(textangle=0)) | (no_skill or {}),
131193
)
132194

133195
fig.layout.legend.update(yanchor="bottom", y=0, xanchor="right", x=0.99)
@@ -142,18 +204,26 @@ def roc_curve_plotly(
142204
def precision_recall_curve_plotly(
143205
targets: ArrayLike | str,
144206
probs_positive: Predictions,
145-
df: Any = None,
207+
df: pd.DataFrame | None = None,
208+
*,
209+
no_skill: dict[str, Any] | None = None,
146210
**kwargs: Any,
147211
) -> go.Figure:
148212
"""Plot the precision-recall curve using Plotly.
149213
150214
Args:
151-
targets: Ground truth binary labels
152-
probs_positive: Either:
215+
targets (ArrayLike | str): Ground truth binary labels
216+
probs_positive (Predictions): Either:
153217
- Predicted probabilities for positive class, or
154218
- dict of form {"name": probabilities}, or
155219
- dict of form {"name": {"probs_positive": np.array, **trace_kwargs}}
156-
df: Optional DataFrame containing targets and probs_positive columns
220+
df (pd.DataFrame | None): Optional DataFrame containing targets and
221+
probs_positive columns
222+
no_skill (dict[str, Any] | None): options for no-skill baseline or None
223+
to hide it. Commonly needed keys:
224+
- show_legend: bool = True
225+
- annotation: dict = None (plotly annotation dict to label the line)
226+
All other keys are passed to fig.add_scatter()
157227
**kwargs: Additional keywords passed to fig.add_scatter()
158228
159229
Returns:
@@ -166,18 +236,23 @@ def precision_recall_curve_plotly(
166236
for idx, (name, trace_kwargs) in enumerate(curves_dict.items()):
167237
# Extract required data and optional trace kwargs
168238
curve_probs = np.asarray(trace_kwargs.pop("probs_positive"))
169-
170239
no_nan = ~np.isnan(targets) & ~np.isnan(curve_probs)
171-
precision, recall, _ = skm.precision_recall_curve(
240+
prec_curve, recall_curve, thresholds = skm.precision_recall_curve(
172241
targets[no_nan], curve_probs[no_nan]
173242
)
243+
# f1 scores for each threshold
244+
f1_curve = 2 * (prec_curve * recall_curve) / (prec_curve + recall_curve)
245+
f1_curve = np.nan_to_num(f1_curve) # Handle division by zero
174246
f1_score = skm.f1_score(targets[no_nan], np.round(curve_probs[no_nan]))
175247

248+
# append final value since threshold has N-1 elements
249+
thresholds = [*thresholds, 1.0]
250+
176251
metrics_str = f"F1={f1_score:.2f}"
177252
display_name = f"{name} · {metrics_str}" if name else metrics_str
178253
trace_defaults = {
179-
"x": recall,
180-
"y": precision,
254+
"x": recall_curve,
255+
"y": prec_curve,
181256
"name": display_name,
182257
"line": dict(
183258
width=2, dash=PLOTLY_LINE_STYLES[idx % len(PLOTLY_LINE_STYLES)]
@@ -186,9 +261,14 @@ def precision_recall_curve_plotly(
186261
f"<b>{display_name}</b><br>"
187262
"Recall: %{x:.3f}<br>"
188263
"Prec: %{y:.3f}<br>"
189-
"F1: {f1_score:.3f}<br>"
264+
"F1: %{customdata.f1:.3f}<br>"
265+
"Threshold: %{customdata.threshold:.3f}<br>"
190266
"<extra></extra>"
191267
),
268+
"customdata": [
269+
dict(threshold=thr, f1=f1)
270+
for thr, f1 in zip(thresholds, f1_curve, strict=True)
271+
],
192272
"meta": dict(f1_score=f1_score),
193273
}
194274
fig.add_scatter(**trace_defaults | kwargs | trace_kwargs)
@@ -197,19 +277,7 @@ def precision_recall_curve_plotly(
197277
fig.data = sorted(fig.data, key=lambda tr: tr.meta.get("f1_score"), reverse=True)
198278

199279
# No-skill baseline (has 100 points so whole line is hoverable, not just end points)
200-
no_skill_line = dict(color="gray", width=2, dash="dash")
201-
fig.add_scatter(
202-
x=np.linspace(0, 1, 100),
203-
y=np.full_like(np.linspace(0, 1, 100), 0.5),
204-
name="No skill",
205-
line=no_skill_line,
206-
hovertemplate=(
207-
"<b>No skill</b><br>"
208-
"Recall: %{x:.3f}<br>"
209-
"Prec: %{y:.3f}<br>"
210-
"<extra></extra>"
211-
),
212-
)
280+
_add_no_skill_line(fig, y_values=np.full(100, 0.5), scatter_kwargs=no_skill)
213281

214282
fig.layout.legend.update(yanchor="bottom", y=0, xanchor="left", x=0)
215283
fig.layout.update(xaxis_title="Recall", yaxis_title="Precision")

tests/.pytest-split-durations

+8-8
Original file line numberDiff line numberDiff line change
@@ -449,14 +449,14 @@
449449
"tests/test_rdf.py::test_element_pair_rdfs_reference_line": 0.019186166988220066,
450450
"tests/test_rdf.py::test_element_pair_rdfs_subplot_layout": 0.013891042035538703,
451451
"tests/test_readme.py::test_no_missing_images": 0.0013329579960554838,
452-
"tests/test_relevance.py::test_precision_recall_curve[None-y_binary0-y_proba0-None]": 0.013538249972043559,
453-
"tests/test_relevance.py::test_precision_recall_curve[None-y_binary0-y_proba0-ax1]": 0.0019967920379713178,
454-
"tests/test_relevance.py::test_precision_recall_curve[df1-y_binary-y_proba-None]": 0.010434875992359594,
455-
"tests/test_relevance.py::test_precision_recall_curve[df1-y_binary-y_proba-ax1]": 0.0031556260073557496,
456-
"tests/test_relevance.py::test_roc_curve[None-y_binary0-y_proba0-None]": 0.012826499965740368,
457-
"tests/test_relevance.py::test_roc_curve[None-y_binary0-y_proba0-ax1]": 0.0016057499451562762,
458-
"tests/test_relevance.py::test_roc_curve[df1-y_binary-y_proba-None]": 0.011317623982904479,
459-
"tests/test_relevance.py::test_roc_curve[df1-y_binary-y_proba-ax1]": 0.0020046669815201312,
452+
"tests/classify/test_curves_matplotlib.py::test_precision_recall_curve[None-y_binary0-y_proba0-None]": 0.013538249972043559,
453+
"tests/classify/test_curves_matplotlib.py::test_precision_recall_curve[None-y_binary0-y_proba0-ax1]": 0.0019967920379713178,
454+
"tests/classify/test_curves_matplotlib.py::test_precision_recall_curve[df1-y_binary-y_proba-None]": 0.010434875992359594,
455+
"tests/classify/test_curves_matplotlib.py::test_precision_recall_curve[df1-y_binary-y_proba-ax1]": 0.0031556260073557496,
456+
"tests/classify/test_curves_matplotlib.py::test_roc_curve[None-y_binary0-y_proba0-None]": 0.012826499965740368,
457+
"tests/classify/test_curves_matplotlib.py::test_roc_curve[None-y_binary0-y_proba0-ax1]": 0.0016057499451562762,
458+
"tests/classify/test_curves_matplotlib.py::test_roc_curve[df1-y_binary-y_proba-None]": 0.011317623982904479,
459+
"tests/classify/test_curves_matplotlib.py::test_roc_curve[df1-y_binary-y_proba-ax1]": 0.0020046669815201312,
460460
"tests/test_sankey.py::test_sankey_from_2_df_cols[False]": 0.0017317909805569798,
461461
"tests/test_sankey.py::test_sankey_from_2_df_cols[True]": 0.008406835026107728,
462462
"tests/test_sankey.py::test_sankey_from_2_df_cols[percent]": 0.001626958983251825,

tests/performance_tests/test_import_time.py

+22-18
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,22 @@
1818

1919
# Last update: 2024-10-23
2020
REF_IMPORT_TIME: dict[str, float] = {
21-
"pymatviz": 2084.25,
22-
"pymatviz.coordination": 2342.41,
23-
"pymatviz.cumulative": 2299.73,
24-
"pymatviz.histogram": 2443.11,
25-
"pymatviz.phonons": 2235.57,
26-
"pymatviz.powerups": 2172.71,
27-
"pymatviz.ptable": 2286.77,
28-
"pymatviz.rainclouds": 2702.03,
29-
"pymatviz.rdf": 2331.98,
30-
"pymatviz.relevance": 2256.29,
31-
"pymatviz.sankey": 2313.12,
32-
"pymatviz.scatter": 2312.48,
33-
"pymatviz.structure_viz": 2330.39,
34-
"pymatviz.sunburst": 2395.04,
35-
"pymatviz.uncertainty": 2317.87,
36-
"pymatviz.xrd": 2242.09,
21+
"pymatviz": 2084,
22+
"pymatviz.coordination": 2342,
23+
"pymatviz.cumulative": 2299,
24+
"pymatviz.histogram": 2443,
25+
"pymatviz.phonons": 2235,
26+
"pymatviz.powerups": 2172,
27+
"pymatviz.ptable": 2286,
28+
"pymatviz.rainclouds": 2702,
29+
"pymatviz.rdf": 2331,
30+
"pymatviz.classify": 2256,
31+
"pymatviz.sankey": 2313,
32+
"pymatviz.scatter": 2312,
33+
"pymatviz.structure_viz": 2330,
34+
"pymatviz.sunburst": 2395,
35+
"pymatviz.uncertainty": 2317,
36+
"pymatviz.xrd": 2242,
3737
}
3838

3939

@@ -96,9 +96,13 @@ def test_import_time(grace_percent: float = 0.20, hard_percent: float = 0.50) ->
9696

9797
if current_time > grace_threshold:
9898
if current_time > hard_threshold:
99-
pytest.fail(f"{module_name} import too slow! {hard_threshold=:.2f} ms")
99+
pytest.fail(
100+
f"{module_name} import too slow! took {current_time:.0f} ms, "
101+
f"{hard_threshold=:.0f} ms"
102+
)
100103
else:
101104
warnings.warn(
102-
f"{module_name} import slightly slower: {grace_threshold=:.2f} ms",
105+
f"{module_name} import slightly slower: took {current_time:.0f} "
106+
f"ms, {grace_threshold=:.0f} ms",
103107
stacklevel=2,
104108
)

0 commit comments

Comments
 (0)