1
1
from __future__ import annotations
2
2
3
- from collections .abc import Sequence
4
3
from typing import Any , Literal , get_args
5
4
6
5
import matplotlib .pyplot as plt
17
16
18
17
StabilityCriterion = Literal ["energy" , "energy+std" , "energy-std" ]
19
18
WhichEnergy = Literal ["true" , "pred" ]
19
+ AxLine = Literal ["x" , "y" , "xy" , "" ]
20
20
21
21
22
22
# --- define global plot settings
53
53
54
54
55
55
plt .rc ("font" , size = 14 )
56
+ plt .rc ("legend" , fontsize = 16 )
56
57
plt .rc ("savefig" , bbox = "tight" , dpi = 200 )
57
58
plt .rc ("figure" , dpi = 200 , titlesize = 16 )
58
59
plt .rcParams ["figure.constrained_layout.use" ] = True
@@ -282,16 +283,18 @@ def rolling_mae_vs_hull_dist(
282
283
return ax
283
284
284
285
285
- def precision_recall_vs_calc_count (
286
+ def cumulative_clf_metric (
286
287
e_above_hull_error : pd .Series ,
287
288
e_above_hull_true : pd .Series ,
289
+ metric : Literal ["precision" , "recall" ],
288
290
std_pred : pd .Series = None ,
289
291
stability_crit : StabilityCriterion = "energy" ,
290
292
stability_threshold : float = 0 , # set stability threshold as distance to convex
291
293
# hull in eV / atom, usually 0 or 0.1 eV
292
294
ax : plt .Axes = None ,
293
295
label : str = None ,
294
- intersect_lines : str | Sequence [str ] = (),
296
+ project_end_point : AxLine = "xy" ,
297
+ show_optimal : bool = False ,
295
298
** kwargs : Any ,
296
299
) -> plt .Axes :
297
300
"""Precision and recall as a function of the number of included materials sorted
@@ -305,26 +308,27 @@ def precision_recall_vs_calc_count(
305
308
predictions, i.e. residual = pred - target. Defaults to "residual".
306
309
e_above_hull_true (str, optional): Column name with convex hull distance values.
307
310
Defaults to "e_above_hull".
311
+ metric ('precision' | 'recall', optional): Metric to plot.
308
312
stability_crit ('energy' | 'energy+std' | 'energy-std', optional): Whether to
309
313
use energy+/-std as stability criterion where std is the model
310
314
predicted uncertainty for the energy it stipulated. Defaults to "energy".
311
315
stability_threshold (float, optional): Max distance from convex hull before
312
316
material is considered unstable. Defaults to 0.
313
317
label (str, optional): Model name used to identify its lines in the legend.
314
318
Defaults to None.
315
- intersect_lines (Sequence[str], optional): precision_{x,y,xy} and/or
316
- recall_{x,y,xy}. Defaults to (), i.e. no intersect lines.
319
+ project_end_point ('x' | 'y' | 'xy' | '', optional): Defaults to 'xy'. Pass '' for no
320
+ axis projection lines.
321
+ show_optimal (bool, optional): Whether to plot the optimal precision/recall
322
+ line. Defaults to False.
317
323
318
324
Returns:
319
325
plt.Axes: The matplotlib axes object.
320
326
"""
321
327
ax = ax or plt .gca ()
322
328
323
- # for series in (e_above_hull_error, e_above_hull_true):
324
- # n_nans = series.isna().sum()
325
- # assert n_nans == 0, f"{n_nans:,} NaNs in {series.name}"
326
-
327
- is_fresh_ax = len (ax .lines ) == 0
329
+ for series in (e_above_hull_error , e_above_hull_true ):
330
+ n_nans = series .isna ().sum ()
331
+ assert n_nans == 0 , f"{ n_nans :,} NaNs in { series .name } "
328
332
329
333
e_above_hull_error = e_above_hull_error .sort_values ()
330
334
e_above_hull_true = e_above_hull_true .loc [e_above_hull_error .index ]
@@ -338,10 +342,6 @@ def precision_recall_vs_calc_count(
338
342
elif stability_crit == "energy-std" :
339
343
e_above_hull_error -= std_pred
340
344
341
- # stability_threshold = 0.02
342
- stability_threshold = 0
343
- # stability_threshold = 0.10
344
-
345
345
true_pos_mask = (e_above_hull_true <= stability_threshold ) & (
346
346
e_above_hull_error <= stability_threshold
347
347
)
@@ -362,68 +362,56 @@ def precision_recall_vs_calc_count(
362
362
true_pos_rate = true_pos_cumsum / n_total_pos * 100
363
363
364
364
end = int (np .argmax (true_pos_rate ))
365
-
366
365
xs = np .arange (end )
367
366
368
- precision_curve = scipy .interpolate .interp1d (xs , precision [:end ], kind = "cubic" )
369
- rolling_recall_curve = scipy .interpolate .interp1d (
370
- xs , true_pos_rate [:end ], kind = "cubic"
371
- )
367
+ ys_raw = dict (precision = precision , recall = true_pos_rate )[metric ]
368
+ y_interp = scipy .interpolate .interp1d (xs , ys_raw [:end ], kind = "cubic" )
369
+ ys = y_interp (xs )
372
370
373
371
line_kwargs = dict (
374
- linewidth = 4 ,
375
- markevery = [- 1 ],
376
- marker = "x" ,
377
- markersize = 14 ,
378
- markeredgewidth = 2.5 ,
379
- ** kwargs ,
380
- )
381
- ax .plot (xs , precision_curve (xs ), linestyle = "-" , ** line_kwargs )
382
- ax .plot (xs , rolling_recall_curve (xs ), linestyle = ":" , ** line_kwargs )
383
- ax .plot ((0 , 0 ), (0 , 0 ), label = label , ** line_kwargs )
384
-
385
- if intersect_lines == "all" :
386
- intersect_lines = ("precision_xy" , "recall_xy" )
387
- if isinstance (intersect_lines , str ):
388
- intersect_lines = [intersect_lines ]
389
- for line_name in intersect_lines :
390
- try :
391
- line_name_map = dict (precision = precision_curve , recall = rolling_recall_curve )
392
- y_func = line_name_map [line_name .split ("_" )[0 ]]
393
- except KeyError :
394
- raise ValueError (
395
- f"Invalid { intersect_lines = } , must be one of { list (line_name_map )} "
396
- )
397
- intersect_kwargs = dict (
398
- linestyle = ":" , alpha = 0.4 , color = kwargs .get ("color" , "gray" )
399
- )
400
- # Add some visual guidelines
401
- if "x" in line_name :
402
- ax .plot ((0 , xs [- 1 ]), (y_func (xs [- 1 ]), y_func (xs [- 1 ])), ** intersect_kwargs )
403
- if "y" in line_name :
404
- ax .plot ((xs [- 1 ], xs [- 1 ]), (0 , y_func (xs [- 1 ])), ** intersect_kwargs )
405
-
406
- if not is_fresh_ax :
407
- # return earlier if all plot objects besides the line were already drawn by a
408
- # previous call
409
- return ax
410
-
411
- xlabel = "Number of compounds sorted by model-predicted hull distance"
412
- ylabel = "Precision and Recall (%)"
413
- ax .set (ylim = (0 , 100 ), xlabel = xlabel , ylabel = ylabel )
414
-
415
- [precision ] = ax .plot (
416
- (0 , 0 ), (0 , 0 ), "black" , linestyle = "-" , linewidth = line_kwargs ["linewidth" ]
417
- )
418
- [recall ] = ax .plot (
419
- (0 , 0 ), (0 , 0 ), "black" , linestyle = ":" , linewidth = line_kwargs ["linewidth" ]
372
+ linewidth = 2 , markevery = [- 1 ], marker = "x" , markersize = 14 , markeredgewidth = 2.5
420
373
)
421
- legend = ax .legend (
422
- [precision , recall ],
423
- ("Precision" , "Recall" ),
424
- frameon = False ,
425
- loc = "upper right" ,
374
+ ax .plot (xs , ys , ** line_kwargs | kwargs )
375
+ ax .text (
376
+ xs [- 1 ],
377
+ ys [- 1 ],
378
+ label ,
379
+ color = kwargs .get ("color" ),
380
+ verticalalignment = "bottom" ,
381
+ rotation = 30 ,
382
+ bbox = dict (facecolor = "white" , alpha = 0.5 , edgecolor = "none" ),
426
383
)
427
- ax .add_artist (legend )
384
+
385
+ # add some visual guidelines
386
+ intersect_kwargs = dict (linestyle = ":" , alpha = 0.4 , color = kwargs .get ("color" ))
387
+ if "x" in project_end_point :
388
+ ax .plot ((0 , xs [- 1 ]), (ys [- 1 ], ys [- 1 ]), ** intersect_kwargs )
389
+ if "y" in project_end_point :
390
+ ax .plot ((xs [- 1 ], xs [- 1 ]), (0 , ys [- 1 ]), ** intersect_kwargs )
391
+
392
+ ax .set (ylim = (0 , 100 ), ylabel = f"{ metric .title ()} (%)" )
393
+
394
+ # optimal recall line finds all stable materials without any false positives
395
+ # can be included to confirm all models start out with near-optimal recall
396
+ # and to see how much each model overshoots total n_stable
397
+ n_below_hull = sum (e_above_hull_true < 0 )
398
+ if show_optimal :
399
+ ax .plot (
400
+ [0 , n_below_hull ],
401
+ [0 , 100 ],
402
+ color = "green" ,
403
+ linestyle = "dashed" ,
404
+ linewidth = 1 ,
405
+ label = f"Optimal { metric .title ()} " ,
406
+ )
407
+ ax .text (
408
+ n_below_hull ,
409
+ 100 ,
410
+ label ,
411
+ color = kwargs .get ("color" ),
412
+ verticalalignment = "top" ,
413
+ rotation = - 30 ,
414
+ bbox = dict (facecolor = "white" , alpha = 0.5 , edgecolor = "none" ),
415
+ )
428
416
429
417
return ax
0 commit comments