6
6
import numpy as np
7
7
import pandas as pd
8
8
import plotly .express as px
9
+ import plotly .graph_objs as go
9
10
import plotly .io as pio
10
11
import scipy .interpolate
11
12
import scipy .stats
19
20
20
21
WhichEnergy = Literal ["true" , "pred" ]
21
22
AxLine = Literal ["x" , "y" , "xy" , "" ]
22
-
23
+ Backend = Literal [ "matplotlib" , "plotly" ]
23
24
24
25
# --- start global plot settings
25
26
quantity_labels = dict (
53
54
dft = "DFT" ,
54
55
)
55
56
px .defaults .labels = quantity_labels | model_labels
56
-
57
- pio .templates .default = "plotly_white"
57
+ pastel_layout = dict (
58
+ colorway = px .colors .qualitative .Pastel , margin = dict (l = 40 , r = 30 , t = 60 , b = 30 )
59
+ )
60
+ pio .templates ["pastel" ] = dict (layout = pastel_layout )
61
+ pio .templates .default = "plotly_white+pastel"
58
62
59
63
# https://github.com/plotly/Kaleido/issues/122#issuecomment-994906924
60
64
# when seeing MathJax "loading" message in exported PDFs, try:
@@ -79,7 +83,9 @@ def hist_classified_stable_vs_hull_dist(
79
83
show_threshold : bool = True ,
80
84
x_lim : tuple [float | None , float | None ] = (- 0.4 , 0.4 ),
81
85
rolling_accuracy : float | None = 0.02 ,
82
- ) -> tuple [plt .Axes , dict [str , float ]]:
86
+ backend : Backend = "plotly" ,
87
+ ylabel : str = "Number of materials" ,
88
+ ) -> tuple [plt .Axes | go .Figure , dict [str , float ]]:
83
89
"""
84
90
Histogram of the energy difference (either according to DFT ground truth [default]
85
91
or model predicted energy) to the convex hull for materials in the WBM data set. The
@@ -106,16 +112,16 @@ def hist_classified_stable_vs_hull_dist(
106
112
vertical line.
107
113
x_lim (tuple[float | None, float | None]): x-axis limits.
108
114
rolling_accuracy (float): Rolling accuracy window size in eV / atom. Set to None
109
- or 0 to disable. Defaults to 0.01.
115
+ or 0 to disable. Defaults to 0.02, meaning 20 meV / atom.
116
+ backend ('matplotlib' | 'plotly', optional): Which plotting backend to use.
117
+ Changes the return type.
110
118
111
119
Returns:
112
120
tuple[plt.Axes | go.Figure, dict[str, float]]: plot axes (matplotlib backend) or figure (plotly backend) and classification metrics
113
121
114
122
NOTE this figure plots hist bars separately which causes aliasing in pdf. Can be
115
123
fixed in Inkscape or similar by merging regions by color.
116
124
"""
117
- ax = ax or plt .gca ()
118
-
119
125
true_pos , false_neg , false_pos , true_neg = classify_stable (
120
126
e_above_hull_true , e_above_hull_pred , stability_threshold
121
127
)
@@ -131,90 +137,105 @@ def hist_classified_stable_vs_hull_dist(
131
137
eah_false_neg = e_above_hull [false_neg ]
132
138
eah_false_pos = e_above_hull [false_pos ]
133
139
eah_true_neg = e_above_hull [true_neg ]
134
- xlabel = dict (
135
- true = "$E_\\ mathrm{above\\ hull}$ (eV / atom)" ,
136
- pred = "$E_\\ mathrm{above\\ hull\\ pred}$ (eV / atom)" ,
137
- )[which_energy ]
138
-
139
- ax .hist (
140
- [eah_true_pos , eah_false_neg , eah_false_pos , eah_true_neg ],
141
- bins = 200 ,
142
- range = x_lim ,
143
- alpha = 0.5 ,
144
- color = ["tab:green" , "tab:orange" , "tab:red" , "tab:blue" ],
145
- label = [
146
- "True Positives" ,
147
- "False Negatives" ,
148
- "False Positives" ,
149
- "True Negatives" ,
150
- ],
151
- stacked = True ,
152
- )
153
-
154
140
n_true_pos , n_false_pos , n_true_neg , n_false_neg = map (
155
- len , (eah_true_pos , eah_false_pos , eah_true_neg , eah_false_neg )
141
+ sum , (true_pos , false_pos , true_neg , false_neg )
156
142
)
157
143
# null = (tp + fn) / (tp + tn + fp + fn)
158
144
precision = n_true_pos / (n_true_pos + n_false_pos )
159
145
160
- # assert (n_all := n_true_pos + n_false_pos + n_true_neg + n_false_neg) == len(
161
- # e_above_hull_true
162
- # ), f"{n_all} != {len(e_above_hull_true)}"
163
-
164
- ax .set (xlabel = xlabel , ylabel = "Number of compounds" , xlim = x_lim )
165
-
166
- if rolling_accuracy :
167
- # add moving average of the accuracy (computed within 20 meV/atom intervals) as
168
- # a function of ΔHd,MP is shown as a blue line (right axis)
169
- ax_acc = ax .twinx ()
170
- ax_acc .set_ylabel ("Accuracy" , color = "darkblue" )
171
- ax_acc .tick_params (labelcolor = "darkblue" )
172
- ax_acc .set (ylim = (0 , 1 ))
173
-
174
- # --- moving average of the accuracy
175
- # compute accuracy within 20 meV/atom intervals
176
- bins = np .arange (x_lim [0 ], x_lim [1 ], rolling_accuracy )
177
- bin_counts = np .histogram (e_above_hull_true , bins )[0 ]
178
- bin_true_pos = np .histogram (eah_true_pos , bins )[0 ]
179
- bin_true_neg = np .histogram (eah_true_neg , bins )[0 ]
180
-
181
- # compute accuracy
182
- bin_accuracies = (bin_true_pos + bin_true_neg ) / bin_counts
183
- # plot accuracy
184
- ax_acc .plot (
185
- bins [:- 1 ],
186
- bin_accuracies ,
187
- color = "tab:blue" ,
188
- label = "Accuracy" ,
189
- linewidth = 3 ,
146
+ xlabel = dict (
147
+ true = r"$E_\mathrm{above\ hull}\;\mathrm{(eV / atom)}$" ,
148
+ pred = r"$E_\mathrm{above\ hull\ pred}\;\mathrm{(eV / atom)}$" ,
149
+ )[which_energy ]
150
+ labels = ["True Positives" , "False Negatives" , "False Positives" , "True Negatives" ]
151
+
152
+ if backend == "matplotlib" :
153
+ ax = ax or plt .gca ()
154
+ ax .hist (
155
+ [eah_true_pos , eah_false_neg , eah_false_pos , eah_true_neg ],
156
+ bins = 200 ,
157
+ range = x_lim ,
158
+ alpha = 0.5 ,
159
+ color = ["tab:green" , "tab:orange" , "tab:red" , "tab:blue" ],
160
+ label = labels ,
161
+ stacked = True ,
190
162
)
191
- # ax2.fill_between(
192
- # bin_centers,
193
- # bin_accuracy - bin_accuracy_std,
194
- # bin_accuracy + bin_accuracy_std,
195
- # color="tab:blue",
196
- # alpha=0.2,
197
- # )
198
-
199
- if show_threshold :
163
+ ax .set (xlabel = xlabel , ylabel = ylabel , xlim = x_lim )
164
+
200
165
ax .axvline (
201
166
stability_threshold ,
202
- color = "k " ,
167
+ color = "black " ,
203
168
linestyle = "--" ,
204
169
label = "Stability Threshold" ,
205
170
)
206
171
207
- recall = n_true_pos / n_total_pos
172
+ if rolling_accuracy :
173
+ # add moving average of the accuracy computed within given window
174
+ # as a function of e_above_hull shown as blue line (right axis)
175
+ ax_acc = ax .twinx ()
176
+ ax_acc .set_ylabel ("Accuracy" , color = "darkblue" )
177
+ ax_acc .tick_params (labelcolor = "darkblue" )
178
+ ax_acc .set (ylim = (0 , 1 ))
179
+
180
+ # --- moving average of the accuracy
181
+ # compute accuracy within bins of width rolling_accuracy (20 meV/atom by default)
182
+ bins = np .arange (x_lim [0 ], x_lim [1 ], rolling_accuracy )
183
+ bin_counts = np .histogram (e_above_hull_true , bins )[0 ]
184
+ bin_true_pos = np .histogram (eah_true_pos , bins )[0 ]
185
+ bin_true_neg = np .histogram (eah_true_neg , bins )[0 ]
186
+
187
+ # compute accuracy
188
+ bin_accuracies = (bin_true_pos + bin_true_neg ) / bin_counts
189
+ # plot accuracy
190
+ ax_acc .plot (
191
+ bins [:- 1 ],
192
+ bin_accuracies ,
193
+ color = "tab:blue" ,
194
+ label = "Accuracy" ,
195
+ linewidth = 3 ,
196
+ )
197
+ # ax2.fill_between(
198
+ # bin_centers,
199
+ # bin_accuracy - bin_accuracy_std,
200
+ # bin_accuracy + bin_accuracy_std,
201
+ # color="tab:blue",
202
+ # alpha=0.2,
203
+ # )
204
+
205
+ if backend == "plotly" :
206
+ clf = (true_pos * 1 + false_neg * 2 + false_pos * 3 + true_neg * 4 ).map (
207
+ dict (zip (range (1 , 5 ), labels ))
208
+ )
209
+ df = pd .DataFrame (dict (e_above_hull = e_above_hull , clf = clf ))
208
210
209
- return ax , {
210
- "enrichment" : precision / null ,
211
- "precision" : precision ,
212
- "recall" : recall ,
213
- "prevalence" : null ,
214
- "accuracy" : (n_true_pos + n_true_neg )
211
+ ax = px .histogram (
212
+ df , x = "e_above_hull" , color = "clf" , nbins = 20000 , range_x = x_lim , opacity = 0.9
213
+ )
214
+ ax .update_layout (
215
+ dict (xaxis_title = xlabel , yaxis_title = ylabel ),
216
+ legend = dict (title = None , yanchor = "top" , y = 1 , xanchor = "right" , x = 1 ),
217
+ )
218
+
219
+ ax .add_vline (stability_threshold , line = dict (dash = "dash" , width = 1 ))
220
+ ax .add_annotation (
221
+ text = "Stability threshold" ,
222
+ x = stability_threshold ,
223
+ y = 1.1 ,
224
+ yref = "paper" ,
225
+ font = dict (size = 14 , color = "gray" ),
226
+ showarrow = False ,
227
+ )
228
+
229
+ recall = n_true_pos / n_total_pos
230
+ return ax , dict (
231
+ enrichment = precision / null ,
232
+ precision = precision ,
233
+ recall = recall ,
234
+ prevalence = null ,
235
+ accuracy = (n_true_pos + n_true_neg )
215
236
/ (n_true_pos + n_true_neg + n_false_pos + n_false_neg ),
216
- "f1" : 2 * (precision * recall ) / (precision + recall ),
217
- }
237
+ f1 = 2 * (precision * recall ) / (precision + recall ),
238
+ )
218
239
219
240
220
241
def rolling_mae_vs_hull_dist (
@@ -432,7 +453,7 @@ def cumulative_clf_metric(
432
453
433
454
434
455
def wandb_scatter (table : wandb .Table , fields : dict [str , str ], ** kwargs : Any ) -> None :
435
- """Log a parity scatter plot using custom vega spec to WandB.
456
+ """Log a parity scatter plot using custom Vega spec to WandB.
436
457
437
458
Args:
438
459
table (wandb.Table): WandB data table.
0 commit comments