12
12
import wandb
13
13
from mpl_toolkits .axes_grid1 .anchored_artists import AnchoredSizeBar
14
14
15
+ from matbench_discovery .energy import classify_stable
16
+
15
17
__author__ = "Janosh Riebesell"
16
18
__date__ = "2022-08-05"
17
19
69
71
70
72
71
73
def hist_classified_stable_vs_hull_dist (
72
- e_above_hull_pred : pd .Series ,
73
74
e_above_hull_true : pd .Series ,
75
+ e_above_hull_pred : pd .Series ,
74
76
ax : plt .Axes = None ,
75
77
which_energy : WhichEnergy = "true" ,
76
78
stability_threshold : float = 0 ,
@@ -90,14 +92,14 @@ def hist_classified_stable_vs_hull_dist(
90
92
See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
91
93
92
94
Args:
93
- e_above_hull_pred (pd.Series): energy difference to convex hull predicted by
94
- model, i.e. difference between the model's predicted and true formation
95
- energy.
96
- e_above_hull_true (pd.Series): energy diff to convex hull according to DFT
97
- ground truth .
95
+ e_above_hull_true (pd.Series): Distance to convex hull according to DFT
96
+ ground truth (in eV / atom).
97
+ e_above_hull_pred (pd.Series): Distance to convex hull predicted by model
98
+ (in eV / atom). Same as true energy to convex hull plus predicted minus true
99
+ formation energy .
98
100
ax (plt.Axes, optional): matplotlib axes to plot on.
99
- which_energy (WhichEnergy, optional): Whether to use the true formation energy
100
- or the model's predicted formation energy for the histogram.
101
+ which_energy (WhichEnergy, optional): Whether to use the true (DFT) hull
102
+ distance or the model's predicted hull distance for the histogram.
101
103
stability_threshold (float, optional): set stability threshold as distance to
102
104
convex hull in eV/atom, usually 0 or 0.1 eV.
103
105
show_threshold (bool, optional): Whether to plot stability threshold as dashed
@@ -114,36 +116,28 @@ def hist_classified_stable_vs_hull_dist(
114
116
"""
115
117
ax = ax or plt .gca ()
116
118
117
- test = e_above_hull_pred + e_above_hull_true
118
- # --- histogram of DFT-computed distance to convex hull
119
- if which_energy == "true" :
120
- actual_pos = e_above_hull_true <= stability_threshold
121
- actual_neg = e_above_hull_true > stability_threshold
122
- model_pos = test <= stability_threshold
123
- model_neg = test > stability_threshold
124
-
125
- n_true_pos = len (e_above_hull_true [actual_pos & model_pos ])
126
- n_false_neg = len (e_above_hull_true [actual_pos & model_neg ])
127
-
128
- n_total_pos = n_true_pos + n_false_neg
129
- null = n_total_pos / len (e_above_hull_true )
130
-
131
- true_pos = e_above_hull_true [actual_pos & model_pos ]
132
- false_neg = e_above_hull_true [actual_pos & model_neg ]
133
- false_pos = e_above_hull_true [actual_neg & model_pos ]
134
- true_neg = e_above_hull_true [actual_neg & model_neg ]
135
- xlabel = r"$E_\mathrm{above\ hull}$ (eV / atom)"
136
-
137
- # --- histogram of model-predicted distance to convex hull
138
- if which_energy == "pred" :
139
- true_pos = e_above_hull_pred [actual_pos & model_pos ]
140
- false_neg = e_above_hull_pred [actual_pos & model_neg ]
141
- false_pos = e_above_hull_pred [actual_neg & model_pos ]
142
- true_neg = e_above_hull_pred [actual_neg & model_neg ]
143
- xlabel = r"$\Delta E_{Hull-Pred}$ (eV / atom)"
119
+ true_pos , false_neg , false_pos , true_neg = classify_stable (
120
+ e_above_hull_true , e_above_hull_pred , stability_threshold
121
+ )
122
+ n_true_pos = sum (true_pos )
123
+ n_false_neg = sum (false_neg )
124
+
125
+ n_total_pos = n_true_pos + n_false_neg
126
+ null = n_total_pos / len (e_above_hull_true )
127
+
128
+ # toggle between histogram of DFT-computed/model-predicted distance to convex hull
129
+ e_above_hull = e_above_hull_true if which_energy == "true" else e_above_hull_pred
130
+ eah_true_pos = e_above_hull [true_pos ]
131
+ eah_false_neg = e_above_hull [false_neg ]
132
+ eah_false_pos = e_above_hull [false_pos ]
133
+ eah_true_neg = e_above_hull [true_neg ]
134
+ xlabel = dict (
135
+ true = "$E_\\ mathrm{above\\ hull}$ (eV / atom)" ,
136
+ pred = "$E_\\ mathrm{above\\ hull\\ pred}$ (eV / atom)" ,
137
+ )[which_energy ]
144
138
145
139
ax .hist (
146
- [true_pos , false_neg , false_pos , true_neg ],
140
+ [eah_true_pos , eah_false_neg , eah_false_pos , eah_true_neg ],
147
141
bins = 200 ,
148
142
range = x_lim ,
149
143
alpha = 0.5 ,
@@ -158,7 +152,7 @@ def hist_classified_stable_vs_hull_dist(
158
152
)
159
153
160
154
n_true_pos , n_false_pos , n_true_neg , n_false_neg = map (
161
- len , (true_pos , false_pos , true_neg , false_neg )
155
+ len , (eah_true_pos , eah_false_pos , eah_true_neg , eah_false_neg )
162
156
)
163
157
# null = (tp + fn) / (tp + tn + fp + fn)
164
158
precision = n_true_pos / (n_true_pos + n_false_pos )
@@ -181,8 +175,8 @@ def hist_classified_stable_vs_hull_dist(
181
175
# compute accuracy within 20 meV/atom intervals
182
176
bins = np .arange (x_lim [0 ], x_lim [1 ], rolling_accuracy )
183
177
bin_counts = np .histogram (e_above_hull_true , bins )[0 ]
184
- bin_true_pos = np .histogram (true_pos , bins )[0 ]
185
- bin_true_neg = np .histogram (true_neg , bins )[0 ]
178
+ bin_true_pos = np .histogram (eah_true_pos , bins )[0 ]
179
+ bin_true_neg = np .histogram (eah_true_neg , bins )[0 ]
186
180
187
181
# compute accuracy
188
182
bin_accuracies = (bin_true_pos + bin_true_neg ) / bin_counts
@@ -327,8 +321,8 @@ def rolling_mae_vs_hull_dist(
327
321
328
322
329
323
def cumulative_clf_metric (
330
- e_above_hull_error : pd .Series ,
331
324
e_above_hull_true : pd .Series ,
325
+ e_above_hull_pred : pd .Series ,
332
326
metric : Literal ["precision" , "recall" ],
333
327
stability_threshold : float = 0 , # set stability threshold as distance to convex
334
328
# hull in eV / atom, usually 0 or 0.1 eV
@@ -344,11 +338,11 @@ def cumulative_clf_metric(
344
338
predicted stable are included.
345
339
346
340
Args:
347
- df (pd.DataFrame ): Model predictions and target energy values.
348
- e_above_hull_error (str, optional): Column name with residuals of model
349
- predictions, i.e. residual = pred - target. Defaults to "residual".
350
- e_above_hull_true (str, optional): Column name with convex hull distance values.
351
- Defaults to "e_above_hull" .
341
+ e_above_hull_true (pd.Series ): Distance to convex hull according to DFT
342
+ ground truth (in eV / atom).
343
+ e_above_hull_pred (pd.Series): Distance to convex hull predicted by model
344
+ (in eV / atom). Same as true energy to convex hull plus predicted minus true
345
+ formation energy .
352
346
metric ('precision' | 'recall', optional): Metric to plot.
353
347
stability_threshold (float, optional): Max distance from convex hull before
354
348
material is considered unstable. Defaults to 0.
@@ -365,25 +359,19 @@ def cumulative_clf_metric(
365
359
"""
366
360
ax = ax or plt .gca ()
367
361
368
- e_above_hull_error = e_above_hull_error .sort_values ()
369
- e_above_hull_true = e_above_hull_true .loc [e_above_hull_error .index ]
362
+ e_above_hull_pred = e_above_hull_pred .sort_values ()
363
+ e_above_hull_true = e_above_hull_true .loc [e_above_hull_pred .index ]
370
364
371
- true_pos_mask = (e_above_hull_true <= stability_threshold ) & (
372
- e_above_hull_error <= stability_threshold
373
- )
374
- false_neg_mask = (e_above_hull_true <= stability_threshold ) & (
375
- e_above_hull_error > stability_threshold
376
- )
377
- false_pos_mask = (e_above_hull_true > stability_threshold ) & (
378
- e_above_hull_error <= stability_threshold
365
+ true_pos , false_neg , false_pos , _true_neg = classify_stable (
366
+ e_above_hull_true , e_above_hull_pred , stability_threshold
379
367
)
380
368
381
- true_pos_cumsum = true_pos_mask .cumsum ()
369
+ true_pos_cumsum = true_pos .cumsum ()
382
370
383
371
# precision aka positive predictive value (PPV)
384
- precision = true_pos_cumsum / (true_pos_cumsum + false_pos_mask .cumsum ()) * 100
385
- n_true_pos = sum (true_pos_mask )
386
- n_false_neg = sum (false_neg_mask )
372
+ precision = true_pos_cumsum / (true_pos_cumsum + false_pos .cumsum ()) * 100
373
+ n_true_pos = sum (true_pos )
374
+ n_false_neg = sum (false_neg )
387
375
n_total_pos = n_true_pos + n_false_neg
388
376
true_pos_rate = true_pos_cumsum / n_total_pos * 100
389
377
@@ -443,9 +431,7 @@ def cumulative_clf_metric(
443
431
return ax
444
432
445
433
446
- def wandb_log_scatter (
447
- table : wandb .Table , fields : dict [str , str ], ** kwargs : Any
448
- ) -> None :
434
+ def wandb_scatter (table : wandb .Table , fields : dict [str , str ], ** kwargs : Any ) -> None :
449
435
"""Log a parity scatter plot using custom vega spec to WandB.
450
436
451
437
Args:
0 commit comments