"""Plotly-based classification metrics visualization."""

- from typing import Any, TypeAlias
+ from typing import Any, Literal, TypeAlias

import numpy as np
+ import pandas as pd
import plotly.graph_objects as go
import sklearn.metrics as skm
from numpy.typing import ArrayLike
@@ -18,14 +19,23 @@ def _standardize_input(
    targets: ArrayLike | str,
    probs_positive: Predictions,
    df: Any = None,
+     *,
+     strict: bool = False,
) -> tuple[ArrayLike, dict[str, dict[str, Any]]]:
    """Standardize input into tuple of (targets, {name: {probs_positive,
    **trace_kwargs}}).

-     Handles three input formats for probs_positive:
-     1. Basic: array of probabilities
-     2. dict of arrays: {"name": probabilities}
-     3. dict of dicts: {"name": {"probs_positive": np.array, **trace_kwargs}}
+     Args:
+         targets: Ground truth binary labels
+         probs_positive: Either:
+             - Predicted probabilities for positive class, or
+             - dict of form {"name": probabilities}, or
+             - dict of form {"name": {"probs_positive": np.array, **trace_kwargs}}
+         df: Optional DataFrame containing targets and probs_positive columns
+         strict: If True, check that probabilities are in [0, 1].
+
+     Returns:
+         tuple[ArrayLike, dict[str, dict[str, Any]]]: targets, curves_dict
    """
    if df is not None:
        if not isinstance(targets, str):
@@ -50,32 +60,90 @@ def _standardize_input(
    else:
        curves_dict = {"": {"probs_positive": probs_positive}}

-     for trace_dict in curves_dict.values():
-         curve_probs = np.asarray(trace_dict["probs_positive"])
-         min_prob, max_prob = curve_probs.min(), curve_probs.max()
-         if not (0 <= min_prob <= max_prob <= 1):
-             raise ValueError(
-                 f"Probabilities must be in [0, 1], got range {(min_prob, max_prob)}"
-             )
+     if strict:
+         for trace_dict in curves_dict.values():
+             curve_probs = np.asarray(trace_dict["probs_positive"])
+             curve_probs_no_nan = curve_probs[~np.isnan(curve_probs)]
+             min_prob, max_prob = curve_probs_no_nan.min(), curve_probs_no_nan.max()
+             if not (0 <= min_prob <= max_prob <= 1):
+                 raise ValueError(
+                     f"Probabilities must be in [0, 1], got range {(min_prob, max_prob)}"
+                 )

    return targets, curves_dict


+ def _add_no_skill_line(
+     fig: go.Figure,
+     y_values: ArrayLike,
+     scatter_kwargs: dict[str, Any] | Literal[False] | None = None,
+ ) -> None:
+     """Add no-skill baseline line to figure.
+
+     Args:
+         fig (go.Figure): Plotly figure to add line to
+         y_values (ArrayLike): Y-values for no-skill line (constant or linear)
+         scatter_kwargs (dict[str, Any] | Literal[False] | None): Options for the
+             no-skill baseline, or False to hide it. Commonly needed keys:
+             - showlegend: bool = False
+             - annotation: dict | None = {} (plotly annotation dict to label the
+               line; None hides the label)
+             All other keys are passed to fig.add_scatter()
+     """
+     if scatter_kwargs is False:
+         return
+
+     scatter_kwargs = scatter_kwargs or {}
+     annotation = scatter_kwargs.pop("annotation", {})
+
+     no_skill_line = dict(color="gray", width=1, dash="dash")
+     no_skill_defaults = dict(
+         x=np.linspace(0, 1, 100),
+         y=y_values,
+         name="No skill",
+         line=no_skill_line,
+         showlegend=False,
+         hovertemplate=(
+             "<b>No skill</b><br>"
+             f"{fig.layout.xaxis.title.text}: %{{x:.3f}}<br>"
+             f"{fig.layout.yaxis.title.text}: %{{y:.3f}}<br>"
+             "<extra></extra>"
+         ),
+     )
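+     # right-hand dict wins the merge: user scatter_kwargs override the defaults above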
+     fig.add_scatter(**no_skill_defaults | scatter_kwargs)
+
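+     # annotation=None suppresses the label; a user dict overrides the defaults below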
+     if annotation is not None:
+         anno_defaults = dict(
+             x=0.5,
+             y=0.5,
+             text="No skill",
+             showarrow=False,
+             font=dict(color="gray"),
+             yshift=10,
+         )
+         fig.add_annotation(anno_defaults | annotation)
+
+
def roc_curve_plotly(
    targets: ArrayLike | str,
    probs_positive: Predictions,
-     df: Any = None,
+     df: pd.DataFrame | None = None,
+     *,
+     no_skill: dict[str, Any] | Literal[False] | None = None,
    **kwargs: Any,
) -> go.Figure:
    """Plot the receiver operating characteristic (ROC) curve using Plotly.

    Args:
-         targets: Ground truth binary labels
-         probs_positive: Either:
+         targets (ArrayLike | str): Ground truth binary labels
+         probs_positive (Predictions): Either:
            - Predicted probabilities for positive class, or
            - dict of form {"name": probabilities}, or
            - dict of form {"name": {"probs_positive": np.array, **trace_kwargs}}
-         df: Optional DataFrame containing targets and probs_positive columns
+         df (pd.DataFrame | None): Optional DataFrame containing targets and
+             probs_positive columns
+         no_skill (dict[str, Any] | Literal[False] | None): Options for the no-skill
+             baseline, or False to hide it. Commonly needed keys:
+             - showlegend: bool = False
+             - annotation: dict | None = {} (plotly annotation dict to label the line)
+             All other keys are passed to fig.add_scatter()
        **kwargs: Additional keywords passed to fig.add_scatter()

    Returns:
@@ -90,7 +158,7 @@ def roc_curve_plotly(
        curve_probs = np.asarray(trace_kwargs.pop("probs_positive"))

        no_nan = ~np.isnan(targets) & ~np.isnan(curve_probs)
-         fpr, tpr, _ = skm.roc_curve(targets[no_nan], curve_probs[no_nan])
+         fpr, tpr, thresholds = skm.roc_curve(targets[no_nan], curve_probs[no_nan])
        roc_auc = skm.roc_auc_score(targets[no_nan], curve_probs[no_nan])

        roc_str = f"AUC={roc_auc:.2f}"
@@ -106,8 +174,10 @@ def roc_curve_plotly(
                f"<b>{display_name}</b><br>"
                "FPR: %{x:.3f}<br>"
                "TPR: %{y:.3f}<br>"
+                 "Threshold: %{customdata.threshold:.3f}<br>"
                "<extra></extra>"
            ),
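+             # per-point dicts feed %{customdata.threshold} in the hovertemplate above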
+             "customdata": [dict(threshold=thr) for thr in thresholds],
            "meta": dict(roc_auc=roc_auc),
        }
        fig.add_scatter(**trace_defaults | kwargs | trace_kwargs)
@@ -116,18 +186,10 @@ def roc_curve_plotly(
    fig.data = sorted(fig.data, key=lambda tr: tr.meta.get("roc_auc"), reverse=True)

    # Random baseline (has 100 points so whole line is hoverable, not just end points)
-     rand_baseline = dict(color="gray", width=2, dash="dash")
-     fig.add_scatter(
-         x=np.linspace(0, 1, 100),
-         y=np.linspace(0, 1, 100),
-         name="Random",
-         line=rand_baseline,
-         hovertemplate=(
-             "<b>Random</b><br>"
-             "FPR: %{x:.3f}<br>"
-             "TPR: %{y:.3f}<br>"
-             "<extra></extra>"
-         ),
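+     # keep no_skill=False intact so the baseline can be hidden; otherwise merge user
+     # options over the default label angle (right-hand dict wins)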
+     _add_no_skill_line(
+         fig,
+         y_values=np.linspace(0, 1, 100),
+         scatter_kwargs=no_skill
+         if no_skill is False
+         else dict(annotation=dict(textangle=0)) | (no_skill or {}),
    )

    fig.layout.legend.update(yanchor="bottom", y=0, xanchor="right", x=0.99)
@@ -142,18 +204,26 @@ def roc_curve_plotly(
def precision_recall_curve_plotly(
    targets: ArrayLike | str,
    probs_positive: Predictions,
-     df: Any = None,
+     df: pd.DataFrame | None = None,
+     *,
+     no_skill: dict[str, Any] | Literal[False] | None = None,
    **kwargs: Any,
) -> go.Figure:
    """Plot the precision-recall curve using Plotly.

    Args:
-         targets: Ground truth binary labels
-         probs_positive: Either:
+         targets (ArrayLike | str): Ground truth binary labels
+         probs_positive (Predictions): Either:
            - Predicted probabilities for positive class, or
            - dict of form {"name": probabilities}, or
            - dict of form {"name": {"probs_positive": np.array, **trace_kwargs}}
-         df: Optional DataFrame containing targets and probs_positive columns
+         df (pd.DataFrame | None): Optional DataFrame containing targets and
+             probs_positive columns
+         no_skill (dict[str, Any] | Literal[False] | None): Options for the no-skill
+             baseline, or False to hide it. Commonly needed keys:
+             - showlegend: bool = False
+             - annotation: dict | None = {} (plotly annotation dict to label the line)
+             All other keys are passed to fig.add_scatter()
        **kwargs: Additional keywords passed to fig.add_scatter()

    Returns:
@@ -166,18 +236,23 @@ def precision_recall_curve_plotly(
    for idx, (name, trace_kwargs) in enumerate(curves_dict.items()):
        # Extract required data and optional trace kwargs
        curve_probs = np.asarray(trace_kwargs.pop("probs_positive"))
-
        no_nan = ~np.isnan(targets) & ~np.isnan(curve_probs)
-         precision, recall, _ = skm.precision_recall_curve(
+         prec_curve, recall_curve, thresholds = skm.precision_recall_curve(
            targets[no_nan], curve_probs[no_nan]
        )
+         # f1 scores for each threshold
+         f1_curve = 2 * (prec_curve * recall_curve) / (prec_curve + recall_curve)
+         f1_curve = np.nan_to_num(f1_curve)  # Handle division by zero
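+         # scalar F1 at the default 0.5 threshold (via np.round), used in the legend label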
        f1_score = skm.f1_score(targets[no_nan], np.round(curve_probs[no_nan]))

+         # append final value since thresholds has N-1 elements (one fewer than
+         # precision/recall)
+         thresholds = [*thresholds, 1.0]
+
        metrics_str = f"F1={f1_score:.2f}"
        display_name = f"{name} · {metrics_str}" if name else metrics_str
        trace_defaults = {
-             "x": recall,
-             "y": precision,
+             "x": recall_curve,
+             "y": prec_curve,
            "name": display_name,
            "line": dict(
                width=2, dash=PLOTLY_LINE_STYLES[idx % len(PLOTLY_LINE_STYLES)]
@@ -186,9 +261,14 @@ def precision_recall_curve_plotly(
                f"<b>{display_name}</b><br>"
                "Recall: %{x:.3f}<br>"
                "Prec: %{y:.3f}<br>"
-                 "F1: {f1_score:.3f}<br>"
+                 "F1: %{customdata.f1:.3f}<br>"
+                 "Threshold: %{customdata.threshold:.3f}<br>"
                "<extra></extra>"
            ),
+             "customdata": [
+                 dict(threshold=thr, f1=f1)
+                 for thr, f1 in zip(thresholds, f1_curve, strict=True)
+             ],
            "meta": dict(f1_score=f1_score),
        }
        fig.add_scatter(**trace_defaults | kwargs | trace_kwargs)
@@ -197,19 +277,7 @@ def precision_recall_curve_plotly(
    fig.data = sorted(fig.data, key=lambda tr: tr.meta.get("f1_score"), reverse=True)

    # No-skill baseline (has 100 points so whole line is hoverable, not just end points)
-     no_skill_line = dict(color="gray", width=2, dash="dash")
-     fig.add_scatter(
-         x=np.linspace(0, 1, 100),
-         y=np.full_like(np.linspace(0, 1, 100), 0.5),
-         name="No skill",
-         line=no_skill_line,
-         hovertemplate=(
-             "<b>No skill</b><br>"
-             "Recall: %{x:.3f}<br>"
-             "Prec: %{y:.3f}<br>"
-             "<extra></extra>"
-         ),
-     )
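+     # constant 0.5 assumes balanced classes; no-skill precision equals class prevalence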
+     _add_no_skill_line(fig, y_values=np.full(100, 0.5), scatter_kwargs=no_skill)

    fig.layout.legend.update(yanchor="bottom", y=0, xanchor="left", x=0)
    fig.layout.update(xaxis_title="Recall", yaxis_title="Precision")