1
1
from __future__ import annotations
2
2
3
+ import math
3
4
from typing import Any , Literal
4
5
5
6
import matplotlib .pyplot as plt
@@ -80,11 +81,11 @@ def hist_classified_stable_vs_hull_dist(
80
81
ax : plt .Axes = None ,
81
82
which_energy : WhichEnergy = "true" ,
82
83
stability_threshold : float = 0 ,
83
- show_threshold : bool = True ,
84
84
x_lim : tuple [float | None , float | None ] = (- 0.4 , 0.4 ),
85
85
rolling_accuracy : float | None = 0.02 ,
86
86
backend : Backend = "plotly" ,
87
87
ylabel : str = "Number of materials" ,
88
+ ** kwargs : Any ,
88
89
) -> tuple [plt .Axes | go .Figure , dict [str , float ]]:
89
90
"""
90
91
Histogram of the energy difference (either according to DFT ground truth [default]
@@ -108,13 +109,13 @@ def hist_classified_stable_vs_hull_dist(
108
109
distance or the model's predicted hull distance for the histogram.
109
110
stability_threshold (float, optional): set stability threshold as distance to
110
111
convex hull in eV/atom, usually 0 or 0.1 eV.
111
- show_threshold (bool, optional): Whether to plot stability threshold as dashed
112
- vertical line.
113
112
x_lim (tuple[float | None, float | None]): x-axis limits.
114
113
rolling_accuracy (float): Rolling accuracy window size in eV / atom. Set to None
115
114
or 0 to disable. Defaults to 0.02, meaning 20 meV / atom.
116
115
backend ('matplotlib' | 'plotly'], optional): Which plotting backend to use.
117
116
Changes the return type.
117
+ kwargs: Additional keyword arguments passed to the ax.hist() or px.histogram()
118
+ depending on backend.
118
119
119
120
Returns:
120
121
tuple[plt.Axes, dict[str, float]]: plot axes and classification metrics
@@ -159,15 +160,17 @@ def hist_classified_stable_vs_hull_dist(
159
160
color = ["tab:green" , "tab:orange" , "tab:red" , "tab:blue" ],
160
161
label = labels ,
161
162
stacked = True ,
163
+ ** kwargs ,
162
164
)
163
165
ax .set (xlabel = xlabel , ylabel = ylabel , xlim = x_lim )
164
166
165
- ax .axvline (
166
- stability_threshold ,
167
- color = "black" ,
168
- linestyle = "--" ,
169
- label = "Stability Threshold" ,
170
- )
167
+ if stability_threshold is not None :
168
+ ax .axvline (
169
+ stability_threshold ,
170
+ color = "black" ,
171
+ linestyle = "--" ,
172
+ label = "Stability Threshold" ,
173
+ )
171
174
172
175
if rolling_accuracy :
173
176
# add moving average of the accuracy computed within given window
@@ -203,28 +206,35 @@ def hist_classified_stable_vs_hull_dist(
203
206
# )
204
207
205
208
if backend == "plotly" :
206
- clf = (true_pos * 1 + false_neg * 2 + false_pos * 3 + true_neg * 4 ).map (
209
+ clf = (true_pos + false_neg * 2 + false_pos * 3 + true_neg * 4 ).map (
207
210
dict (zip (range (1 , 5 ), labels ))
208
211
)
209
212
df = pd .DataFrame (dict (e_above_hull = e_above_hull , clf = clf ))
210
213
211
214
ax = px .histogram (
212
- df , x = "e_above_hull" , color = "clf" , nbins = 20000 , range_x = x_lim , opacity = 0.9
215
+ df ,
216
+ x = "e_above_hull" ,
217
+ color = "clf" ,
218
+ nbins = 20000 ,
219
+ range_x = x_lim ,
220
+ opacity = 0.9 ,
221
+ ** kwargs ,
213
222
)
214
223
ax .update_layout (
215
224
dict (xaxis_title = xlabel , yaxis_title = ylabel ),
216
225
legend = dict (title = None , yanchor = "top" , y = 1 , xanchor = "right" , x = 1 ),
217
226
)
218
227
219
- ax .add_vline (stability_threshold , line = dict (dash = "dash" , width = 1 ))
220
- ax .add_annotation (
221
- text = "Stability threshold" ,
222
- x = stability_threshold ,
223
- y = 1.1 ,
224
- yref = "paper" ,
225
- font = dict (size = 14 , color = "gray" ),
226
- showarrow = False ,
227
- )
228
+ if stability_threshold is not None :
229
+ ax .add_vline (stability_threshold , line = dict (dash = "dash" , width = 1 ))
230
+ ax .add_annotation (
231
+ text = "Stability threshold" ,
232
+ x = stability_threshold ,
233
+ y = 1.1 ,
234
+ yref = "paper" ,
235
+ font = dict (size = 14 , color = "gray" ),
236
+ showarrow = False ,
237
+ )
228
238
229
239
recall = n_true_pos / n_total_pos
230
240
return ax , dict (
@@ -341,115 +351,141 @@ def rolling_mae_vs_hull_dist(
341
351
return ax
342
352
343
353
344
- def cumulative_clf_metric (
354
+ def cumulative_precision_recall (
345
355
e_above_hull_true : pd .Series ,
346
- e_above_hull_pred : pd .Series ,
347
- metric : Literal ["precision" , "recall" ],
356
+ df_preds : pd .DataFrame ,
348
357
stability_threshold : float = 0 , # set stability threshold as distance to convex
349
358
# hull in eV / atom, usually 0 or 0.1 eV
350
- ax : plt .Axes = None ,
351
- label : str = None ,
352
359
project_end_point : AxLine = "xy" ,
353
360
show_optimal : bool = False ,
361
+ backend : Backend = "plotly" ,
354
362
** kwargs : Any ,
355
- ) -> plt .Axes :
356
- """Precision and recall as a function of the number of included materials sorted
357
- by model-predicted distance to the convex hull, i.e. materials predicted most stable
358
- enter the precision and recall calculation first. The curves end when all materials
359
- predicted stable are included.
363
+ ) -> tuple [plt .Figure | go .Figure , pd .DataFrame ]:
364
+ """Create 2 subplots side-by-side with cumulative precision and recall curves for
365
+ all models starting with materials predicted most stable, adding the next material,
366
+ recomputing the cumulative metrics, adding the next most stable material and so on
367
+ until each model no longer predicts the material to be stable. Again, materials
368
+ predicted more stable enter the precision and recall calculation sooner. Different
369
+ models predict different number of materials to be stable. Hence the curves end at
370
+ different points.
360
371
361
372
Args:
362
373
e_above_hull_true (pd.Series): Distance to convex hull according to DFT
363
374
ground truth (in eV / atom).
364
- e_above_hull_pred (pd.Series): Distance to convex hull predicted by model
365
- (in eV / atom). Same as true energy to convex hull plus predicted minus true
366
- formation energy.
367
- metric ('precision' | 'recall', optional): Metric to plot.
368
- stability_threshold (float, optional): Max distance from convex hull before
375
+ df_preds (pd.DataFrame): Distance to convex hull predicted by models, one column
376
+ per model (in eV / atom). Same as true energy to convex hull plus predicted
377
+ minus true formation energy.
378
+ stability_threshold (float, optional): Max distance above convex hull before
369
379
material is considered unstable. Defaults to 0.
370
- label (str, optional): Model name used to identify its liens in the legend.
371
- Defaults to None.
372
- project_end_point ('x' | 'y' | 'xy' | '', optional): Defaults to '', i.e. no
380
+ project_end_point ('x' | 'y' | 'xy' | '', optional): Whether to project end
381
+ points of precision and recall curves to the x/y axis. Defaults to '', i.e. no
373
382
axis projection lines.
374
- show_optimal (bool, optional): Whether to plot the optimal precision/recall
375
- line. Defaults to False.
376
- **kwargs: Keyword arguments passed to ax.plot().
383
+ show_optimal (bool, optional): Whether to plot the optimal recall line. Defaults
384
+ to False.
385
+ backend ('plotly' | 'matplotlib', optional): Defaults to 'plotly'. **kwargs:
386
+ Keyword arguments passed to df.plot().
377
387
378
388
Returns:
379
- plt.Axes: The matplotlib axes object.
389
+ tuple[plt.Figure | go.Figure, pd.DataFrame]: The matplotlib/plotly figure and
390
+ dataframe of cumulative metrics for each model.
380
391
"""
381
- ax = ax or plt .gca ()
392
+ fact = lambda : pd .DataFrame (index = range (len (e_above_hull_true )))
393
+ dfs = dict (Precision = fact (), Recall = fact ())
382
394
383
- e_above_hull_pred = e_above_hull_pred .sort_values ()
384
- e_above_hull_true = e_above_hull_true .loc [e_above_hull_pred .index ]
395
+ for model_name in df_preds :
396
+ model_preds = df_preds [model_name ].sort_values ()
397
+ e_above_hull_true = e_above_hull_true .loc [model_preds .index ]
385
398
386
- true_pos , false_neg , false_pos , _true_neg = classify_stable (
387
- e_above_hull_true , e_above_hull_pred , stability_threshold
388
- )
399
+ true_pos , false_neg , false_pos , _true_neg = classify_stable (
400
+ e_above_hull_true , model_preds , stability_threshold
401
+ )
389
402
390
- true_pos_cumsum = true_pos .cumsum ()
403
+ true_pos_cumsum = true_pos .cumsum ()
404
+ # precision aka positive predictive value (PPV)
405
+ precision = true_pos_cumsum / (true_pos_cumsum + false_pos .cumsum ())
406
+ n_total_pos = sum (true_pos ) + sum (false_neg )
407
+ recall = true_pos_cumsum / n_total_pos # aka true_pos_rate aka sensitivity
391
408
392
- # precision aka positive predictive value (PPV)
393
- precision = true_pos_cumsum / (true_pos_cumsum + false_pos .cumsum ()) * 100
394
- n_true_pos = sum (true_pos )
395
- n_false_neg = sum (false_neg )
396
- n_total_pos = n_true_pos + n_false_neg
397
- true_pos_rate = true_pos_cumsum / n_total_pos * 100
409
+ end = int (np .argmax (recall ))
410
+ xs = np .arange (end )
398
411
399
- end = int (np .argmax (true_pos_rate ))
400
- xs = np .arange (end )
412
+ prec_interp = scipy .interpolate .interp1d (xs , precision [:end ], kind = "cubic" )
413
+ recall_interp = scipy .interpolate .interp1d (xs , recall [:end ], kind = "cubic" )
414
+ dfs ["Precision" ][model_name ] = pd .Series (prec_interp (xs ))
415
+ dfs ["Recall" ][model_name ] = pd .Series (recall_interp (xs ))
401
416
402
- ys_raw = dict (precision = precision , recall = true_pos_rate )[metric ]
403
- y_interp = scipy .interpolate .interp1d (xs , ys_raw [:end ], kind = "cubic" )
404
- ys = y_interp (xs )
417
+ for key , df in dfs .items ():
418
+ # drop all-NaN rows so plotly plot x-axis only extends to largest number of
419
+ # predicted materials by any model
420
+ df .dropna (how = "all" , inplace = True )
421
+ df ["metric" ] = key
405
422
406
- line_kwargs = dict (
407
- linewidth = 2 , markevery = [- 1 ], marker = "x" , markersize = 14 , markeredgewidth = 2.5
408
- )
409
- ax .plot (xs , ys , ** line_kwargs | kwargs )
410
- ax .text (
411
- xs [- 1 ],
412
- ys [- 1 ],
413
- label ,
414
- color = kwargs .get ("color" ),
415
- verticalalignment = "bottom" ,
416
- rotation = 30 ,
417
- bbox = dict (facecolor = "white" , alpha = 0.5 , edgecolor = "none" ),
418
- )
423
+ df_cum = pd .concat (dfs .values ())
419
424
420
- # add some visual guidelines
421
- intersect_kwargs = dict (linestyle = ":" , alpha = 0.4 , color = kwargs .get ("color" ))
422
- if "x" in project_end_point :
423
- ax .plot ((0 , xs [- 1 ]), (ys [- 1 ], ys [- 1 ]), ** intersect_kwargs )
424
- if "y" in project_end_point :
425
- ax .plot ((xs [- 1 ], xs [- 1 ]), (0 , ys [- 1 ]), ** intersect_kwargs )
426
-
427
- ax .set (ylim = (0 , 100 ), ylabel = f"{ metric .title ()} (%)" )
428
-
429
- # optimal recall line finds all stable materials without any false positives
430
- # can be included to confirm all models start out of with near optimal recall
431
- # and to see how much each model overshoots total n_stable
432
- n_below_hull = sum (e_above_hull_true < 0 )
433
- if show_optimal :
434
- ax .plot (
435
- [0 , n_below_hull ],
436
- [0 , 100 ],
437
- color = "green" ,
438
- linestyle = "dashed" ,
439
- linewidth = 1 ,
440
- label = f"Optimal { metric .title ()} " ,
425
+ if backend == "matplotlib" :
426
+ fig , axs = plt .subplots (1 , 2 , figsize = (15 , 7 ), sharey = True )
427
+ line_kwargs = dict (
428
+ linewidth = 3 , markevery = [- 1 ], marker = "x" , markersize = 14 , markeredgewidth = 2.5
441
429
)
442
- ax .text (
443
- n_below_hull ,
444
- 100 ,
445
- label ,
446
- color = kwargs .get ("color" ),
447
- verticalalignment = "top" ,
448
- rotation = - 30 ,
449
- bbox = dict (facecolor = "white" , alpha = 0.5 , edgecolor = "none" ),
430
+ for (key , df ), ax in zip (dfs .items (), axs ):
431
+ # select every n-th row of df so that 1000 rows are left for increased
432
+ # plotting speed and reduced file size
433
+ # falls back on every row if df has less than 1000 rows
434
+
435
+ df .iloc [:: len (df ) // 1000 or 1 ].plot (
436
+ ax = ax , legend = False , backend = backend , ** line_kwargs | kwargs , ylabel = key
437
+ )
438
+
439
+ # add some visual guidelines
440
+ intersect_kwargs = dict (linestyle = ":" , alpha = 0.4 , linewidth = 2 )
441
+ bbox = dict (facecolor = "white" , alpha = 0.5 , edgecolor = "none" )
442
+ assert len (axs ) == len (dfs ), f"{ len (axs )} != { len (dfs )} "
443
+
444
+ for ax , df in zip (axs , dfs .values ()):
445
+ ax .set (ylim = (0 , 1 ), xlim = (0 , None ), ylabel = key )
446
+ for model in df_preds :
447
+ x_end = df [model ].dropna ().index [- 1 ]
448
+ y_end = df [model ].dropna ().iloc [- 1 ]
449
+ # place model name at the end of every line
450
+ ax .text (x_end , y_end , model , va = "bottom" , rotation = 30 , bbox = bbox )
451
+ if "x" in project_end_point :
452
+ ax .plot ((x_end , x_end ), (0 , y_end ), ** intersect_kwargs )
453
+ if "y" in project_end_point :
454
+ ax .plot ((0 , x_end ), (y_end , y_end ), ** intersect_kwargs )
455
+
456
+ # optimal recall line finds all stable materials without any false positives
457
+ # can be included to confirm all models start out of with near optimal recall
458
+ # and to see how much each model overshoots total n_stable
459
+ n_below_hull = sum (e_above_hull_true < 0 )
460
+ if show_optimal :
461
+ opt_label = "Optimal Recall"
462
+ axs [1 ].plot ([0 , n_below_hull ], [0 , 1 ], color = "green" , linestyle = "--" )
463
+ axs [1 ].text (
464
+ * [n_below_hull , 0.81 ],
465
+ opt_label ,
466
+ color = "green" ,
467
+ va = "bottom" ,
468
+ ha = "right" ,
469
+ rotation = math .degrees (math .cos (math .atan (1 / n_below_hull ))),
470
+ bbox = bbox ,
471
+ )
472
+
473
+ elif backend == "plotly" :
474
+ fig = df_cum .iloc [:: len (df_cum ) // 1000 or 1 ].plot (
475
+ backend = backend , facet_col = "metric" , ** kwargs
450
476
)
477
+ fig .update_traces (line = dict (width = 4 ))
478
+ for idx in range (1 , 3 ):
479
+ fig .update_xaxes (
480
+ title_text = "Number of materials predicted stable" , row = 1 , col = idx
481
+ )
482
+ fig .update_yaxes (title = "Precision" , col = 1 )
483
+ fig .update_yaxes (title = "Recall" , col = 2 )
484
+ fig .for_each_annotation (lambda a : a .update (text = "" ))
485
+ fig .update_layout (legend = dict (title = "" ))
486
+ fig .update_layout (showlegend = False )
451
487
452
- return ax
488
+ return fig , df_cum
453
489
454
490
455
491
def wandb_scatter (table : wandb .Table , fields : dict [str , str ], ** kwargs : Any ) -> None :
0 commit comments