1
1
from __future__ import annotations
2
2
3
- from typing import Any , Literal , Sequence
3
+ from typing import Any , Literal , Sequence , get_args
4
4
5
5
import matplotlib .pyplot as plt
6
6
import numpy as np
13
13
__author__ = "Janosh Riebesell"
14
14
__date__ = "2022-08-05"
15
15
16
+ StabilityCriterion = Literal ["energy" , "energy+std" , "energy-std" ]
16
17
17
18
plt .rc ("savefig" , bbox = "tight" , dpi = 200 )
18
19
plt .rcParams ["figure.constrained_layout.use" ] = True
@@ -27,10 +28,10 @@ def hist_classified_stable_as_func_of_hull_dist(
27
28
e_above_hull_col : str ,
28
29
ax : plt .Axes = None ,
29
30
energy_type : Literal ["true" , "pred" ] = "true" ,
30
- criterion : Literal [ "energy" , "std" , "neg_std" ] = "energy" ,
31
+ stability_crit : StabilityCriterion = "energy" ,
31
32
show_mae : bool = False ,
32
- stability_thresh : float = 0 , # set stability threshold as distance to convex hull
33
- # in eV / atom, usually 0 or 0.1 eV
33
+ stability_threshold : float = 0 , # set stability threshold as distance to convex
34
+ # hull in eV / atom, usually 0 or 0.1 eV
34
35
x_lim : tuple [float , float ] = (- 0.4 , 0.4 ),
35
36
) -> plt .Axes :
36
37
"""
@@ -52,28 +53,28 @@ def hist_classified_stable_as_func_of_hull_dist(
52
53
53
54
error = df [pred_cols ].mean (axis = 1 ) - df [target_col ]
54
55
e_above_hull_vals = df [e_above_hull_col ]
55
- mean = error + e_above_hull_vals
56
+ residuals = error + e_above_hull_vals
56
57
57
- if criterion == "energy" :
58
- test = mean
59
- elif "std" in criterion :
58
+ if stability_crit == "energy" :
59
+ test = residuals
60
+ elif "std" in stability_crit :
60
61
# TODO column names to compute standard deviation from are currently hardcoded
61
62
# needs to be updated when adding non-aviary models with uncertainty estimation
62
63
var_aleatoric = (df .filter (like = "_ale_" ) ** 2 ).mean (axis = 1 )
63
64
var_epistemic = df .filter (regex = r"_pred_\d" ).var (axis = 1 , ddof = 0 )
64
65
std_total = (var_epistemic + var_aleatoric ) ** 0.5
65
66
66
- if criterion == "std" :
67
+ if stability_crit == "energy+ std" :
67
68
test += std_total
68
- elif criterion == "neg_std " :
69
+ elif stability_crit == "energy-std " :
69
70
test -= std_total
70
71
71
72
# --- histogram by DFT-computed distance to convex hull
72
73
if energy_type == "true" :
73
- actual_pos = e_above_hull_vals <= stability_thresh
74
- actual_neg = e_above_hull_vals > stability_thresh
75
- model_pos = test <= stability_thresh
76
- model_neg = test > stability_thresh
74
+ actual_pos = e_above_hull_vals <= stability_threshold
75
+ actual_neg = e_above_hull_vals > stability_threshold
76
+ model_pos = test <= stability_threshold
77
+ model_neg = test > stability_threshold
77
78
78
79
n_true_pos = len (e_above_hull_vals [actual_pos & model_pos ])
79
80
n_false_neg = len (e_above_hull_vals [actual_pos & model_neg ])
@@ -89,10 +90,10 @@ def hist_classified_stable_as_func_of_hull_dist(
89
90
90
91
# --- histogram by model-predicted distance to convex hull
91
92
if energy_type == "pred" :
92
- true_pos = mean [actual_pos & model_pos ]
93
- false_neg = mean [actual_pos & model_neg ]
94
- false_pos = mean [actual_neg & model_pos ]
95
- true_neg = mean [actual_neg & model_neg ]
93
+ true_pos = residuals [actual_pos & model_pos ]
94
+ false_neg = residuals [actual_pos & model_neg ]
95
+ false_pos = residuals [actual_neg & model_pos ]
96
+ true_neg = residuals [actual_neg & model_neg ]
96
97
xlabel = r"$\Delta E_{Hull-Pred}$ / eV per atom"
97
98
98
99
ax .hist (
@@ -163,6 +164,10 @@ def rolling_mae_vs_hull_dist(
163
164
if ax is None :
164
165
ax = plt .gca ()
165
166
167
+ for col in (residual_col , e_above_hull_col ):
168
+ n_nans = df [col ].isna ().sum ()
169
+ assert n_nans == 0 , f"{ n_nans } NaNs in { col } column"
170
+
166
171
is_fresh_ax = len (ax .lines ) == 0
167
172
168
173
bins = np .arange (* x_lim , increment )
@@ -256,9 +261,9 @@ def precision_recall_vs_calc_count(
256
261
df : pd .DataFrame ,
257
262
residual_col : str = "residual" ,
258
263
e_above_hull_col : str = "e_above_hull" ,
259
- criterion : Literal [ "energy" , "std" , "neg_std" ] = "energy" ,
260
- stability_thresh : float = 0 , # set stability threshold as distance to convex hull
261
- # in eV / atom, usually 0 or 0.1 eV
264
+ stability_crit : StabilityCriterion = "energy" ,
265
+ stability_threshold : float = 0 , # set stability threshold as distance to convex
266
+ # hull in eV / atom, usually 0 or 0.1 eV
262
267
ax : plt .Axes = None ,
263
268
label : str = None ,
264
269
intersect_lines : str | Sequence [str ] = (),
@@ -272,10 +277,10 @@ def precision_recall_vs_calc_count(
272
277
i.e. residual = pred - target. Defaults to "residual".
273
278
e_above_hull_col (str, optional): Column name with convex hull distance values.
274
279
Defaults to "e_above_hull".
275
- criterion (Literal[ 'energy', ' std', 'neg_std'] , optional): Whether to use
276
- energy, energy+model_std, or energy-model_std as stability criterion.
277
- Defaults to "energy".
278
- stability_thresh (float, optional): Max distance from convex hull before
280
+ stability_crit ( 'energy' | 'energy+ std' | 'energy-std' , optional): Whether to
281
+ use energy+/-std as stability stability_crit where std is the model
282
+ predicted uncertainty for the energy it stipulated. Defaults to "energy".
283
+ stability_threshold (float, optional): Max distance from convex hull before
279
284
material is considered unstable. Defaults to 0.
280
285
label (str, optional): Model name used to identify its liens in the legend.
281
286
Defaults to None.
@@ -288,36 +293,43 @@ def precision_recall_vs_calc_count(
288
293
if ax is None :
289
294
ax = plt .gca ()
290
295
296
+ for col in (residual_col , e_above_hull_col ):
297
+ n_nans = df [col ].isna ().sum ()
298
+ assert n_nans == 0 , f"{ n_nans } NaNs in { col } column"
299
+
291
300
is_fresh_ax = len (ax .lines ) == 0
292
301
293
302
df = df .sort_values (by = "residual" )
303
+ residuals = df [residual_col ]
294
304
295
- if criterion == "energy" :
296
- test = df [residual_col ]
297
- elif "std" in criterion :
305
+ if stability_crit not in get_args (StabilityCriterion ):
306
+ raise ValueError (
307
+ f"Invalid { stability_crit = } must be one of { get_args (StabilityCriterion )} "
308
+ )
309
+ if "std" in stability_crit :
298
310
# TODO column names to compute standard deviation from are currently hardcoded
299
311
# needs to be updated when adding non-aviary models with uncertainty estimation
300
312
var_aleatoric = (df .filter (like = "_ale_" ) ** 2 ).mean (axis = 1 )
301
313
var_epistemic = df .filter (regex = r"_pred_\d" ).var (axis = 1 , ddof = 0 )
302
314
std_total = (var_epistemic + var_aleatoric ) ** 0.5
303
315
304
- if criterion == "std" :
305
- test += std_total
306
- elif criterion == "neg_std " :
307
- test -= std_total
316
+ if stability_crit == "energy+ std" :
317
+ residuals += std_total
318
+ elif stability_crit == "energy-std " :
319
+ residuals -= std_total
308
320
309
- # stability_thresh = 0.02
310
- stability_thresh = 0
311
- # stability_thresh = 0.10
321
+ # stability_threshold = 0.02
322
+ stability_threshold = 0
323
+ # stability_threshold = 0.10
312
324
313
- true_pos_mask = (df [e_above_hull_col ] <= stability_thresh ) & (
314
- df .residual <= stability_thresh
325
+ true_pos_mask = (df [e_above_hull_col ] <= stability_threshold ) & (
326
+ df .residual <= stability_threshold
315
327
)
316
- false_neg_mask = (df [e_above_hull_col ] <= stability_thresh ) & (
317
- df .residual > stability_thresh
328
+ false_neg_mask = (df [e_above_hull_col ] <= stability_threshold ) & (
329
+ df .residual > stability_threshold
318
330
)
319
- false_pos_mask = (df [e_above_hull_col ] > stability_thresh ) & (
320
- df .residual <= stability_thresh
331
+ false_pos_mask = (df [e_above_hull_col ] > stability_threshold ) & (
332
+ df .residual <= stability_threshold
321
333
)
322
334
323
335
true_pos_cumsum = true_pos_mask .cumsum ()
@@ -349,11 +361,16 @@ def precision_recall_vs_calc_count(
349
361
350
362
if intersect_lines == "all" :
351
363
intersect_lines = ("precision_xy" , "recall_xy" )
364
+ if isinstance (intersect_lines , str ):
365
+ intersect_lines = [intersect_lines ]
352
366
for line_name in intersect_lines :
353
- y_func = dict (
354
- precision = precision_curve ,
355
- recall = rolling_recall_curve ,
356
- )[line_name .split ("_" )[0 ]]
367
+ try :
368
+ line_name_map = dict (precision = precision_curve , recall = rolling_recall_curve )
369
+ y_func = line_name_map [line_name .split ("_" )[0 ]]
370
+ except KeyError :
371
+ raise ValueError (
372
+ f"Invalid { intersect_lines = } , must be one of { list (line_name_map )} "
373
+ )
357
374
intersect_kwargs = dict (
358
375
linestyle = ":" , alpha = 0.4 , color = kwargs .get ("color" , "gray" )
359
376
)
@@ -370,7 +387,7 @@ def precision_recall_vs_calc_count(
370
387
371
388
ax .set (xlabel = "Number of Calculations" , ylabel = "Percentage" )
372
389
373
- ax .set (xlim = ( 0 , 8e4 ), ylim = (0 , 100 ))
390
+ ax .set (ylim = (0 , 100 ))
374
391
375
392
[precision ] = ax .plot ((0 , 0 ), (0 , 0 ), "black" , linestyle = "-" )
376
393
[recall ] = ax .plot ((0 , 0 ), (0 , 0 ), "black" , linestyle = ":" )
0 commit comments