1
- # %%
2
- from typing import Literal
1
+ from __future__ import annotations
2
+
3
+ from typing import Literal , Sequence
3
4
4
5
import matplotlib .pyplot as plt
5
6
import pandas as pd
6
- from matplotlib .offsetbox import AnchoredText
7
7
8
8
9
9
__author__ = "Janosh Riebesell"
10
10
__date__ = "2022-08-05"
11
11
12
- """
13
- Histogram of the energy difference (either according to DFT ground truth [default] or
14
- model predicted energy) to the convex hull for materials in the WBM data set. The
15
- histogram is broken down into true positives, false negatives, false positives, and true
16
- negatives based on whether the model predicts candidates to be below the known convex
17
- hull. Ideally, in discovery setting a model should exhibit high recall, i.e. the
18
- majority of materials below the convex hull being correctly identified by the model.
19
-
20
- See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
21
- """
22
-
23
12
24
13
plt .rc ("savefig" , bbox = "tight" , dpi = 200 )
25
14
plt .rcParams ["figure.constrained_layout.use" ] = True
26
15
plt .rc ("figure" , dpi = 150 )
27
16
plt .rc ("font" , size = 16 )
28
17
29
18
30
- def hist_classify_stable_as_func_of_hull_dist (
31
- # df: pd.DataFrame,
32
- formation_energy_targets : pd .Series ,
33
- formation_energy_preds : pd .Series ,
34
- e_above_hull_vals : pd .Series ,
35
- rare : str = "all" ,
36
- std_vals : pd .Series = None ,
37
- criterion : Literal ["energy" , "std" , "neg" ] = "energy" ,
19
+ def hist_classified_stable_as_func_of_hull_dist (
20
+ df : pd .DataFrame ,
21
+ target_col : str ,
22
+ pred_cols : Sequence [str ],
23
+ e_above_hull_col : str ,
24
+ ax : plt .Axes = None ,
38
25
energy_type : Literal ["true" , "pred" ] = "true" ,
39
- annotate_all_stats : bool = False ,
26
+ criterion : Literal ["energy" , "std" , "neg_std" ] = "energy" ,
27
+ show_mae : bool = False ,
28
+ stability_thresh : float = 0 , # set stability threshold as distance to convex hull
29
+ # in eV / atom, usually 0 or 0.1 eV
30
+ x_lim : tuple [float , float ] = (- 0.4 , 0.4 ),
40
31
) -> plt .Axes :
41
32
"""
42
33
Histogram of the energy difference (either according to DFT ground truth [default]
@@ -49,40 +40,43 @@ def hist_classify_stable_as_func_of_hull_dist(
49
40
50
41
See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
51
42
52
-
53
- # NOTE this figure plots hist bars separately which causes aliasing in pdf
54
- # to resolve this take into Inkscape and merge regions by color
43
+ NOTE this figure plots hist bars separately which causes aliasing in pdf
44
+ to resolve this take into Inkscape and merge regions by color
55
45
"""
56
- assert e_above_hull_vals .isna ().sum () == 0
46
+ if ax is None :
47
+ ax = plt .gca ()
57
48
58
- error = formation_energy_preds - formation_energy_targets
49
+ error = df [pred_cols ].mean (axis = 1 ) - df [target_col ]
50
+ e_above_hull_vals = df [e_above_hull_col ]
59
51
mean = error + e_above_hull_vals
60
52
61
- test = mean
53
+ if criterion == "energy" :
54
+ test = mean
55
+ elif "std" in criterion :
56
+ # TODO column names to compute standard deviation from are currently hardcoded
57
+ # needs to be updated when adding non-aviary models with uncertainty estimation
58
+ var_aleatoric = (df .filter (like = "_ale_" ) ** 2 ).mean (axis = 1 )
59
+ var_epistemic = df .filter (regex = r"_pred_\d" ).var (axis = 1 , ddof = 0 )
60
+ std_total = (var_epistemic + var_aleatoric ) ** 0.5
62
61
63
- if std_vals is not None :
64
62
if criterion == "std" :
65
- test += std_vals
66
- elif criterion == "neg " :
67
- test -= std_vals
63
+ test += std_total
64
+ elif criterion == "neg_std " :
65
+ test -= std_total
68
66
69
- xlim = (- 0.4 , 0.4 )
70
-
71
- # set stability threshold at on or 0.1 eV / atom above the hull
72
- stability_thresh = (0 , 0.1 )[0 ]
73
-
74
- actual_pos = e_above_hull_vals <= stability_thresh
75
- actual_neg = e_above_hull_vals > stability_thresh
76
- model_pos = test <= stability_thresh
77
- model_neg = test > stability_thresh
67
+ # --- histogram by DFT-computed distance to convex hull
68
+ if energy_type == "true" :
69
+ actual_pos = e_above_hull_vals <= stability_thresh
70
+ actual_neg = e_above_hull_vals > stability_thresh
71
+ model_pos = test <= stability_thresh
72
+ model_neg = test > stability_thresh
78
73
79
- n_true_pos = len (e_above_hull_vals [actual_pos & model_pos ])
80
- n_false_neg = len (e_above_hull_vals [actual_pos & model_neg ])
74
+ n_true_pos = len (e_above_hull_vals [actual_pos & model_pos ])
75
+ n_false_neg = len (e_above_hull_vals [actual_pos & model_neg ])
81
76
82
- n_total_pos = n_true_pos + n_false_neg
77
+ n_total_pos = n_true_pos + n_false_neg
78
+ null = n_total_pos / len (e_above_hull_vals )
83
79
84
- # --- histogram by DFT-computed distance to convex hull
85
- if energy_type == "true" :
86
80
true_pos = e_above_hull_vals [actual_pos & model_pos ]
87
81
false_neg = e_above_hull_vals [actual_pos & model_neg ]
88
82
false_pos = e_above_hull_vals [actual_neg & model_pos ]
@@ -97,12 +91,10 @@ def hist_classify_stable_as_func_of_hull_dist(
97
91
true_neg = mean [actual_neg & model_neg ]
98
92
xlabel = r"$\Delta E_{Hull-Pred}$ / eV per atom"
99
93
100
- fig , ax = plt .subplots (1 , 1 , figsize = (10 , 9 ))
101
-
102
94
ax .hist (
103
95
[true_pos , false_neg , false_pos , true_neg ],
104
96
bins = 200 ,
105
- range = xlim ,
97
+ range = x_lim ,
106
98
alpha = 0.5 ,
107
99
color = ["tab:green" , "tab:orange" , "tab:red" , "tab:blue" ],
108
100
label = [
@@ -114,49 +106,34 @@ def hist_classify_stable_as_func_of_hull_dist(
114
106
stacked = True ,
115
107
)
116
108
117
- ax .legend (frameon = False , loc = "upper left" )
118
-
119
109
n_true_pos , n_false_pos , n_true_neg , n_false_neg = (
120
110
len (true_pos ),
121
111
len (false_pos ),
122
112
len (true_neg ),
123
113
len (false_neg ),
124
114
)
125
115
# null = (tp + fn) / (tp + tn + fp + fn)
126
- Null = n_total_pos / len (e_above_hull_vals )
127
- PPV = n_true_pos / (n_true_pos + n_false_pos )
128
- TPR = n_true_pos / n_total_pos
129
- F1 = 2 * PPV * TPR / (PPV + TPR )
130
-
131
- assert n_true_pos + n_false_pos + n_true_neg + n_false_neg == len (
132
- formation_energy_targets
133
- )
134
-
135
- RMSE = (error ** 2 ).mean () ** 0.5
136
- MAE = error .abs ().mean ()
137
-
138
- # anno_text = f"Prevalence = {null:.2f}\nPrecision = {ppv:.2f}\nRecall = {tpr:.2f}",
139
- anno_text = f"Enrichment Factor = { PPV / Null :.3} "
140
- if annotate_all_stats :
141
- anno_text += f"\n { MAE = :.3} \n { RMSE = :.3} \n { Null = :.3} \n { TPR = :.3} "
142
- else :
143
- print (f"{ PPV = :.3} " )
144
- print (f"{ TPR = :.3} " )
145
- print (f"{ F1 = :.3} " )
146
- print (f"Enrich: { PPV / Null :.2f} " )
147
- print (f"{ Null = :.3} " )
148
- print (f"{ MAE = :.3} " )
149
- print (f"{ RMSE = :.3} " )
150
-
151
- text_box = AnchoredText (
152
- anno_text , loc = "upper right" , frameon = False , prop = dict (fontsize = 16 )
116
+ precision = n_true_pos / (n_true_pos + n_false_pos )
117
+
118
+ assert n_true_pos + n_false_pos + n_true_neg + n_false_neg == len (df )
119
+
120
+ # recall = n_true_pos / n_total_pos
121
+ # f"Prevalence = {null:.2f}\n{precision = :.2f}\n{recall = :.2f}",
122
+ text = f"Enrichment\n Factor = { precision / null :.3} "
123
+ if show_mae :
124
+ MAE = error .abs ().mean ()
125
+ text += f"\n { MAE = :.3} "
126
+
127
+ ax .text (
128
+ 0.98 ,
129
+ 0.98 ,
130
+ text ,
131
+ fontsize = 18 ,
132
+ verticalalignment = "top" ,
133
+ horizontalalignment = "right" ,
134
+ transform = ax .transAxes ,
153
135
)
154
- ax .add_artist (text_box )
155
136
156
- ax .set (
157
- xlabel = xlabel ,
158
- ylabel = "Number of Compounds" ,
159
- title = f"data size = { len (e_above_hull_vals ):,} " ,
160
- )
137
+ ax .set (xlabel = xlabel , ylabel = "Number of compounds" )
161
138
162
139
return ax
0 commit comments