19
19
AxLine = Literal ["x" , "y" , "xy" , "" ]
20
20
21
21
22
- # --- define global plot settings
22
+ # --- start global plot settings
23
23
quantity_labels = dict (
24
24
n_atoms = "Atom Count" ,
25
25
n_elems = "Element Count" ,
55
55
56
56
57
57
plt .rc ("font" , size = 14 )
58
- plt .rc ("legend" , fontsize = 16 )
58
+ plt .rc ("legend" , fontsize = 16 , title_fontsize = 16 )
59
+ plt .rc ("axes" , titlesize = 16 , labelsize = 16 )
59
60
plt .rc ("savefig" , bbox = "tight" , dpi = 200 )
60
61
plt .rc ("figure" , dpi = 200 , titlesize = 16 )
61
62
plt .rcParams ["figure.constrained_layout.use" ] = True
@@ -69,11 +70,11 @@ def hist_classified_stable_as_func_of_hull_dist(
69
70
ax : plt .Axes = None ,
70
71
which_energy : WhichEnergy = "true" ,
71
72
stability_crit : StabilityCriterion = "energy" ,
72
- show_mae : bool = False ,
73
- stability_threshold : float = 0 , # set stability threshold as distance to convex
74
- # hull in eV / atom, usually 0 or 0.1 eV
75
- x_lim : tuple [ float , float ] = ( - 0.4 , 0.4 ) ,
76
- ) -> plt .Axes :
73
+ stability_threshold : float = 0 ,
74
+ show_threshold : bool = True ,
75
+ x_lim : tuple [ float | None , float | None ] = ( - 0.4 , 0.4 ),
76
+ rolling_accuracy : float = 0.02 ,
77
+ ) -> tuple [ plt .Axes , dict [ str , float ]] :
77
78
"""
78
79
Histogram of the energy difference (either according to DFT ground truth [default]
79
80
or model predicted energy) to the convex hull for materials in the WBM data set. The
@@ -85,8 +86,33 @@ def hist_classified_stable_as_func_of_hull_dist(
85
86
86
87
See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
87
88
88
- NOTE this figure plots hist bars separately which causes aliasing in pdf
89
- to resolve this take into Inkscape and merge regions by color
89
+ Args:
90
+ e_above_hull_pred (pd.Series): energy difference to convex hull predicted by
91
+ model, i.e. difference between the model's predicted and true formation
92
+ energy.
93
+ e_above_hull_true (pd.Series): energy diff to convex hull according to DFT
94
+ ground truth.
95
+ std_pred (pd.Series, optional): standard deviation of the model's predicted
96
+ formation energy.
97
+ ax (plt.Axes, optional): matplotlib axes to plot on.
98
+ which_energy (WhichEnergy, optional): Whether to use the true formation energy
99
+ or the model's predicted formation energy for the histogram.
100
+ stability_crit (StabilityCriterion, optional): Whether to add/subtract the
101
+ model's predicted uncertainty from its energy prediction when measuring
102
+ predicted stability.
103
+ stability_threshold (float, optional): set stability threshold as distance to
104
+ convex hull in eV/atom, usually 0 or 0.1 eV.
105
+ show_threshold (bool, optional): Whether to plot stability threshold as dashed
106
+ vertical line.
107
+ x_lim (tuple[float | None, float | None]): x-axis limits.
108
+ rolling_accuracy (float): Rolling accuracy window size in eV / atom. Set to 0 to
109
+ disable. Defaults to 0.02.
110
+
111
+ Returns:
112
+ tuple[plt.Axes, dict[str, float]]: plot axes and classification metrics
113
+
114
+ NOTE this figure plots hist bars separately which causes aliasing in pdf. Can be
115
+ fixed in Inkscape or similar by merging regions by color.
90
116
"""
91
117
ax = ax or plt .gca ()
92
118
@@ -153,26 +179,60 @@ def hist_classified_stable_as_func_of_hull_dist(
153
179
# e_above_hull_true
154
180
# ), f"{n_all} != {len(e_above_hull_true)}"
155
181
156
- # recall = n_true_pos / n_total_pos
157
- # f"Prevalence = {null:.2f}\n{precision = :.2f}\n{recall = :.2f}",
158
- text = f"Enrichment\n Factor = { precision / null :.3} "
159
- if show_mae :
160
- MAE = e_above_hull_pred .abs ().mean ()
161
- text += f"\n { MAE = :.3} "
162
-
163
- ax .text (
164
- 0.98 ,
165
- 0.98 ,
166
- text ,
167
- fontsize = 18 ,
168
- verticalalignment = "top" ,
169
- horizontalalignment = "right" ,
170
- transform = ax .transAxes ,
171
- )
172
-
173
- ax .set (xlabel = xlabel , ylabel = "Number of compounds" )
182
+ ax .set (xlabel = xlabel , ylabel = "Number of compounds" , xlim = x_lim )
183
+
184
+ if rolling_accuracy :
185
+ # plot the accuracy, computed in bins of width rolling_accuracy (default
185
+ # 20 meV/atom), as a function of ΔHd,MP as a blue line on the right axis
187
+ ax_acc = ax .twinx ()
188
+ ax_acc .set_ylabel ("Accuracy" , color = "darkblue" )
189
+ ax_acc .tick_params (labelcolor = "darkblue" )
190
+ ax_acc .set (ylim = (0 , 1 ))
191
+
192
+ # --- moving average of the accuracy
193
+ # compute accuracy within bins of width rolling_accuracy (default 20 meV/atom)
194
+ bins = np .arange (x_lim [0 ], x_lim [1 ], rolling_accuracy )
195
+ bin_counts = np .histogram (e_above_hull_true , bins )[0 ]
196
+ bin_true_pos = np .histogram (true_pos , bins )[0 ]
197
+ bin_true_neg = np .histogram (true_neg , bins )[0 ]
198
+
199
+ # compute accuracy
200
+ bin_accuracies = (bin_true_pos + bin_true_neg ) / bin_counts
201
+ # plot accuracy
202
+ ax_acc .plot (
203
+ bins [:- 1 ],
204
+ bin_accuracies ,
205
+ color = "tab:blue" ,
206
+ label = "Accuracy" ,
207
+ linewidth = 3 ,
208
+ )
209
+ # ax2.fill_between(
210
+ # bin_centers,
211
+ # bin_accuracy - bin_accuracy_std,
212
+ # bin_accuracy + bin_accuracy_std,
213
+ # color="tab:blue",
214
+ # alpha=0.2,
215
+ # )
216
+
217
+ if show_threshold :
218
+ ax .axvline (
219
+ stability_threshold ,
220
+ color = "k" ,
221
+ linestyle = "--" ,
222
+ label = "Stability Threshold" ,
223
+ )
174
224
175
- return ax
225
+ recall = n_true_pos / n_total_pos
226
+
227
+ return ax , {
228
+ "enrichment" : precision / null ,
229
+ "precision" : precision ,
230
+ "recall" : recall ,
231
+ "prevalence" : null ,
232
+ "accuracy" : (n_true_pos + n_true_neg )
233
+ / (n_true_pos + n_true_neg + n_false_pos + n_false_neg ),
234
+ "f1" : 2 * (precision * recall ) / (precision + recall ),
235
+ }
176
236
177
237
178
238
def rolling_mae_vs_hull_dist (
0 commit comments