@@ -84,7 +84,7 @@ def hist_classified_stable_vs_hull_dist(
84
84
x_lim : tuple [float | None , float | None ] = (- 0.4 , 0.4 ),
85
85
rolling_accuracy : float | None = 0.02 ,
86
86
backend : Backend = "plotly" ,
87
- ylabel : str = "Number of materials" ,
87
+ y_label : str = "Number of materials" ,
88
88
** kwargs : Any ,
89
89
) -> tuple [plt .Axes | go .Figure , dict [str , float ]]:
90
90
"""
@@ -112,8 +112,9 @@ def hist_classified_stable_vs_hull_dist(
112
112
x_lim (tuple[float | None, float | None]): x-axis limits.
113
113
rolling_accuracy (float): Rolling accuracy window size in eV / atom. Set to None
114
114
or 0 to disable. Defaults to 0.02, meaning 20 meV / atom.
115
- backend ('matplotlib' | 'plotly'], optional): Which plotting backend to use.
116
- Changes the return type.
115
+ backend ('matplotlib' | 'plotly'], optional): Which plotting engine to use.
116
+ Changes the return type. Defaults to 'plotly'.
117
+ y_label (str, optional): y-axis label. Defaults to "Number of materials".
117
118
kwargs: Additional keyword arguments passed to the ax.hist() or px.histogram()
118
119
depending on backend.
119
120
@@ -162,7 +163,7 @@ def hist_classified_stable_vs_hull_dist(
162
163
stacked = True ,
163
164
** kwargs ,
164
165
)
165
- ax .set (xlabel = xlabel , ylabel = ylabel , xlim = x_lim )
166
+ ax .set (xlabel = xlabel , ylabel = y_label , xlim = x_lim )
166
167
167
168
if stability_threshold is not None :
168
169
ax .axvline (
@@ -221,7 +222,7 @@ def hist_classified_stable_vs_hull_dist(
221
222
** kwargs ,
222
223
)
223
224
ax .update_layout (
224
- dict (xaxis_title = xlabel , yaxis_title = ylabel ),
225
+ dict (xaxis_title = xlabel , yaxis_title = y_label ),
225
226
legend = dict (title = None , yanchor = "top" , y = 1 , xanchor = "right" , x = 1 ),
226
227
)
227
228
@@ -251,27 +252,46 @@ def hist_classified_stable_vs_hull_dist(
251
252
def rolling_mae_vs_hull_dist (
252
253
e_above_hull_true : pd .Series ,
253
254
e_above_hull_error : pd .Series ,
254
- window : float = 0.04 ,
255
- bin_width : float = 0.002 ,
255
+ window : float = 0.02 ,
256
+ bin_width : float = 0.001 ,
256
257
x_lim : tuple [float , float ] = (- 0.2 , 0.3 ),
258
+ y_lim : tuple [float , float ] = (0.0 , 0.14 ),
257
259
ax : plt .Axes = None ,
260
+ backend : Backend = "plotly" ,
261
+ y_label : str = "rolling MAE (eV/atom)" ,
258
262
** kwargs : Any ,
259
263
) -> plt .Axes :
260
264
"""Rolling mean absolute error as the energy to the convex hull is varied. A scale
261
- bar is shown for the windowing period of 40 meV per atom used when calculating
262
- the rolling MAE. The standard error in the mean is shaded
263
- around each curve. The highlighted V-shaped region shows the area in which the
264
- average absolute error is greater than the energy to the known convex hull. This is
265
- where models are most at risk of misclassifying structures.
266
- """
267
- ax = ax or plt .gca ()
265
+ bar is shown for the windowing period of 40 meV per atom used when calculating the
266
+ rolling MAE. The standard error in the mean is shaded around each curve. The
267
+ highlighted V-shaped region shows the area in which the average absolute error is
268
+ greater than the energy to the known convex hull. This is where models are most at
269
+ risk of misclassifying structures.
268
270
269
- is_fresh_ax = len (ax .lines ) == 0
271
+ Args:
272
+ e_above_hull_true (pd.Series): Distance to convex hull according to DFT
273
+ ground truth (in eV / atom).
274
+ e_above_hull_error (pd.Series): Error in model-predicted distance to convex
275
+ hull, i.e. actual hull distance minus predicted hull distance (in eV / atom).
276
+ window (float, optional): Rolling MAE averaging window. Defaults to 0.02 (20 meV/atom)
277
+ bin_width (float, optional): Density of line points (more points the smaller).
278
+ Defaults to 0.002.
279
+ x_lim (tuple[float, float], optional): x-axis range. Defaults to (-0.2, 0.3).
280
+ y_lim (tuple[float, float], optional): y-axis range. Defaults to (0.0, 0.14).
281
+ ax (plt.Axes, optional): matplotlib Axes object. Defaults to None.
282
+ backend ('matplotlib' | 'plotly'], optional): Which plotting engine to use.
283
+ Changes the return type. Defaults to 'plotly'.
284
+ y_label (str, optional): y-axis label. Defaults to "rolling MAE (eV/atom)".
285
+
286
+ Returns:
287
+ plt.Axes: _description_
288
+ """
270
289
271
290
bins = np .arange (* x_lim , bin_width )
272
291
273
292
rolling_maes = np .zeros_like (bins )
274
293
rolling_stds = np .zeros_like (bins )
294
+
275
295
for idx , bin_center in enumerate (bins ):
276
296
low = bin_center - window
277
297
high = bin_center + window
@@ -280,79 +300,152 @@ def rolling_mae_vs_hull_dist(
280
300
rolling_maes [idx ] = e_above_hull_error .loc [mask ].abs ().mean ()
281
301
rolling_stds [idx ] = scipy .stats .sem (e_above_hull_error .loc [mask ].abs ())
282
302
283
- kwargs = dict (linewidth = 3 ) | kwargs
284
- ax .plot (bins , rolling_maes , ** kwargs )
285
-
286
- ax .fill_between (
287
- bins , rolling_maes + rolling_stds , rolling_maes - rolling_stds , alpha = 0.3
288
- )
289
- # alternative implementation using pandas.rolling(). drawback: window size can only
290
- # be set as number of observations, not fixed-size energy above hull interval.
291
- # e_above_hull_error.index = e_above_hull_true # warning: in-place change
292
- # e_above_hull_error.sort_index().abs().rolling(window=8000).mean().plot(
293
- # ax=ax, **kwargs
294
- # )
295
-
296
- if not is_fresh_ax :
297
- # return earlier if all plot objects besides the line were already drawn by a
298
- # previous call
299
- return ax
300
-
301
- scale_bar = AnchoredSizeBar (
302
- ax .transData ,
303
- window ,
304
- "40 meV" ,
305
- "lower left" ,
306
- pad = 0.5 ,
307
- frameon = False ,
308
- size_vertical = 0.002 ,
309
- )
310
- # indicate size of MAE averaging window
311
- ax .add_artist (scale_bar )
312
-
313
- # DFT accuracy at 25 meV/atom for relative e_above_hull which is lower than
314
- # formation energy error due to systematic error cancellation among
315
- # similar chemistries, supporting ref:
303
+ # DFT accuracy at 25 meV/atom for e_above_hull calculations of chemically similar
304
+ # systems which is lower than formation energy error due to systematic error
305
+ # cancellation among similar chemistries, supporting ref:
316
306
# https://journals.aps.org/prb/abstract/10.1103/PhysRevB.85.155208
317
307
dft_acc = 0.025
318
- ax .plot ((dft_acc , 1 ), (dft_acc , 1 ), color = "grey" , linestyle = "--" , alpha = 0.3 )
319
- ax .plot ((- 1 , - dft_acc ), (1 , dft_acc ), color = "grey" , linestyle = "--" , alpha = 0.3 )
320
- ax .plot (
321
- (- dft_acc , dft_acc ), (dft_acc , dft_acc ), color = "grey" , linestyle = "--" , alpha = 0.3
322
- )
323
- ax .fill_between (
324
- (- 1 , - dft_acc , dft_acc , 1 ),
325
- (1 , 1 , 1 , 1 ),
326
- (1 , dft_acc , dft_acc , 1 ),
327
- color = "tab:red" ,
328
- alpha = 0.2 ,
329
- )
330
308
331
- ax .plot ((0 , dft_acc ), (0 , dft_acc ), color = "grey" , linestyle = "--" , alpha = 0.3 )
332
- ax .plot ((- dft_acc , 0 ), (dft_acc , 0 ), color = "grey" , linestyle = "--" , alpha = 0.3 )
333
- ax .fill_between (
334
- (- dft_acc , 0 , dft_acc ),
335
- (dft_acc , dft_acc , dft_acc ),
336
- (dft_acc , 0 , dft_acc ),
337
- color = "tab:orange" ,
338
- alpha = 0.2 ,
339
- )
340
- # shrink=0.1 means cut off 10% length from both sides of arrow line
341
- arrowprops = dict (
342
- facecolor = "black" , width = 0.5 , headwidth = 5 , headlength = 5 , shrink = 0.1
343
- )
344
- ax .annotate (
345
- xy = (- dft_acc , dft_acc ),
346
- xytext = (- 2 * dft_acc , dft_acc ),
347
- text = "Corrected\n GGA DFT\n Accuracy" ,
348
- arrowprops = arrowprops ,
349
- verticalalignment = "center" ,
350
- horizontalalignment = "right" ,
351
- )
309
+ if backend == "matplotlib" :
310
+ ax = ax or plt .gca ()
311
+ is_fresh_ax = len (ax .lines ) == 0
312
+ kwargs = dict (linewidth = 3 ) | kwargs
313
+ ax .plot (bins , rolling_maes , ** kwargs )
314
+
315
+ ax .fill_between (
316
+ bins , rolling_maes + rolling_stds , rolling_maes - rolling_stds , alpha = 0.3
317
+ )
318
+ # alternative implementation using pandas.rolling(). drawback: window size can only
319
+ # be set as number of observations, not fixed-size energy above hull interval.
320
+ # e_above_hull_error.index = e_above_hull_true # warning: in-place change
321
+ # e_above_hull_error.sort_index().abs().rolling(window=8000).mean().plot(
322
+ # ax=ax, **kwargs
323
+ # )
324
+ if not is_fresh_ax :
325
+ # return earlier if all plot objects besides the line were already drawn by a
326
+ # previous call
327
+ return ax
328
+
329
+ scale_bar = AnchoredSizeBar (
330
+ ax .transData ,
331
+ window ,
332
+ "40 meV" ,
333
+ "lower left" ,
334
+ pad = 0.5 ,
335
+ frameon = False ,
336
+ size_vertical = 0.002 ,
337
+ )
338
+ # indicate size of MAE averaging window
339
+ ax .add_artist (scale_bar )
340
+
341
+ ax .fill_between (
342
+ (- 1 , - dft_acc , dft_acc , 1 ),
343
+ (1 , 1 , 1 , 1 ),
344
+ (1 , dft_acc , dft_acc , 1 ),
345
+ color = "tab:red" ,
346
+ alpha = 0.2 ,
347
+ )
352
348
353
- ax .text (0 , 0.13 , r"$|E_\mathrm{above\ hull}| > $MAE" , horizontalalignment = "center" )
354
- ax .set (xlabel = r"$E_\mathrm{above\ hull}$ (eV / atom)" , ylabel = "MAE (eV / atom)" )
355
- ax .set (xlim = x_lim , ylim = (0.0 , 0.14 ))
349
+ ax .fill_between (
350
+ (- dft_acc , 0 , dft_acc ),
351
+ (dft_acc , dft_acc , dft_acc ),
352
+ (dft_acc , 0 , dft_acc ),
353
+ color = "tab:orange" ,
354
+ alpha = 0.2 ,
355
+ )
356
+ # shrink=0.1 means cut off 10% length from both sides of arrow line
357
+ arrowprops = dict (
358
+ facecolor = "black" , width = 0.5 , headwidth = 5 , headlength = 5 , shrink = 0.1
359
+ )
360
+ ax .annotate (
361
+ xy = (- dft_acc , dft_acc ),
362
+ xytext = (- 2 * dft_acc , dft_acc ),
363
+ text = "Corrected\n GGA DFT\n Accuracy" ,
364
+ arrowprops = arrowprops ,
365
+ verticalalignment = "center" ,
366
+ horizontalalignment = "right" ,
367
+ )
368
+
369
+ ax .text (
370
+ 0 , 0.13 , r"MAE > $|E_\mathrm{above\ hull}|$" , horizontalalignment = "center"
371
+ )
372
+ ax .set (xlabel = r"$E_\mathrm{above\ hull}$ (eV/atom)" , ylabel = y_label )
373
+ ax .set (xlim = x_lim , ylim = y_lim )
374
+ elif backend == "plotly" :
375
+ title = kwargs .pop ("label" , None )
376
+ ax = px .line (
377
+ x = bins ,
378
+ y = rolling_maes ,
379
+ # error_y=rolling_stds,
380
+ markers = False ,
381
+ title = title ,
382
+ ** kwargs ,
383
+ )
384
+ ax_std = go .Scatter (
385
+ x = list (bins ) + list (bins )[::- 1 ], # bins, then bins reversed
386
+ y = list (rolling_maes + 2 * rolling_stds )
387
+ + list (rolling_maes - 2 * rolling_stds )[::- 1 ], # upper, then lower reversed
388
+ fill = "toself" ,
389
+ line_color = "white" ,
390
+ fillcolor = ax .data [0 ].line .color ,
391
+ opacity = 0.3 ,
392
+ hoverinfo = "skip" ,
393
+ showlegend = False ,
394
+ )
395
+ ax .add_trace (ax_std )
396
+
397
+ ax .update_layout (
398
+ dict (
399
+ xaxis_title = "E<sub>above hull</sub> (eV/atom)" ,
400
+ yaxis_title = "rolling MAE (eV/atom)" ,
401
+ ),
402
+ legend = dict (title = None , xanchor = "right" , x = 1 , yanchor = "bottom" , y = 0 ),
403
+ )
404
+ ax .update_xaxes (range = x_lim )
405
+ ax .update_yaxes (range = y_lim )
406
+ scatter_kwds = dict (fill = "toself" , opacity = 0.5 )
407
+ err_gt_each_region = go .Scatter (
408
+ x = (- 1 , - dft_acc , dft_acc , 1 ),
409
+ y = (1 , dft_acc , dft_acc , 1 ),
410
+ name = "MAE > |E<sub>above hull</sub>|" ,
411
+ # fillcolor="yellow",
412
+ ** scatter_kwds ,
413
+ )
414
+ ml_err_lt_dft_err_region = go .Scatter (
415
+ x = (- dft_acc , dft_acc , 0 , - dft_acc ),
416
+ y = (dft_acc , dft_acc , 0 , dft_acc ),
417
+ name = "MAE < |DFT error|" ,
418
+ # fillcolor="red",
419
+ ** scatter_kwds ,
420
+ )
421
+ ax .add_traces ([err_gt_each_region , ml_err_lt_dft_err_region ])
422
+ ax .add_annotation (
423
+ x = 4 * dft_acc ,
424
+ y = dft_acc ,
425
+ text = "Corrected GGA DFT Accuracy" ,
426
+ showarrow = True ,
427
+ # arrowhead=1,
428
+ ax = - dft_acc ,
429
+ ay = dft_acc ,
430
+ )
431
+
432
+ ax .data = ax .data [::- 1 ] # bring px.line() to front
433
+ # show MAE window size
434
+ x0 , y0 = x_lim [0 ] + 0.01 , y_lim [0 ] + 0.01
435
+ ax .add_annotation (
436
+ x = x0 + 0.05 ,
437
+ y = y0 + 0.01 ,
438
+ text = f"rolling MAE window<br>{ window } eV/atom" ,
439
+ showarrow = False ,
440
+ )
441
+ ax .add_shape (
442
+ type = "rect" ,
443
+ x0 = x0 ,
444
+ y0 = y0 ,
445
+ x1 = x0 + window ,
446
+ y1 = y0 + window / 5 ,
447
+ fillcolor = "black" ,
448
+ )
356
449
357
450
return ax
358
451
@@ -388,8 +481,9 @@ def cumulative_precision_recall(
388
481
axis projection lines.
389
482
show_optimal (bool, optional): Whether to plot the optimal recall line. Defaults
390
483
to False.
391
- backend ('plotly' | 'matplotlib', optional): Defaults to 'plotly'. **kwargs:
392
- Keyword arguments passed to df.plot().
484
+ backend ('matplotlib' | 'plotly'], optional): Which plotting engine to use.
485
+ Changes the return type. Defaults to 'plotly'.
486
+ **kwargs: Keyword arguments passed to df.plot().
393
487
394
488
Returns:
395
489
tuple[plt.Figure | go.Figure, pd.DataFrame]: The matplotlib/plotly figure and
0 commit comments