@@ -296,6 +296,7 @@ def rolling_mae_vs_hull_dist(
296
296
y_label : str = "rolling MAE (eV/atom)" ,
297
297
just_plot_lines : bool = False ,
298
298
with_sem : bool = True ,
299
+ show_dft_acc : bool = False ,
299
300
** kwargs : Any ,
300
301
) -> plt .Axes | go .Figure :
301
302
"""Rolling mean absolute error as the energy to the convex hull is varied. A scale
@@ -325,6 +326,9 @@ def rolling_mae_vs_hull_dist(
325
326
to False.
326
327
with_sem (bool, optional): If True, plot the standard error of the mean as
327
328
shaded area around the rolling MAE. Defaults to True.
329
+ show_dft_acc (bool, optional): If True, change color of the cone of peril's tip
330
+ and annotate it with 'Corrected GGA Accuracy' at rolling MAE of 25 meV/atom.
331
+ Defaults to False.
328
332
329
333
Returns:
330
334
tuple[plt.Axes | go.Figure, pd.DataFrame, pd.DataFrame]: matplotlib Axes or
@@ -363,8 +367,8 @@ def rolling_mae_vs_hull_dist(
363
367
# previous call
364
368
return ax , df_rolling_err , df_err_std
365
369
366
- # DFT accuracy at 25 meV/atom for e_above_hull calculations of chemically similar
367
- # systems which is lower than formation energy error due to systematic error
370
+ # DFT accuracy at 25 meV/atom for relative difference of e_above_hull for chemically
371
+ # similar systems which is lower than formation energy error due to systematic error
368
372
# cancellation among similar chemistries, supporting ref:
369
373
href = "https://doi.org/10.1103/PhysRevB.85.155208"
370
374
dft_acc = 0.025
@@ -397,32 +401,33 @@ def rolling_mae_vs_hull_dist(
397
401
ax .add_artist (scale_bar )
398
402
399
403
ax .fill_between (
400
- (- 1 , - dft_acc , dft_acc , 1 ),
401
- (1 , 1 , 1 , 1 ),
402
- (1 , dft_acc , dft_acc , 1 ),
404
+ (- 1 , - dft_acc , dft_acc , 1 ) if show_dft_acc else ( - 1 , 0 , 1 ) ,
405
+ (1 , 1 , 1 , 1 ) if show_dft_acc else ( 1 , 1 , 1 ) ,
406
+ (1 , dft_acc , dft_acc , 1 ) if show_dft_acc else ( 1 , 0 , 1 ) ,
403
407
color = "tab:red" ,
404
408
alpha = 0.2 ,
405
409
)
406
410
407
- ax .fill_between (
408
- (- dft_acc , 0 , dft_acc ),
409
- (dft_acc , dft_acc , dft_acc ),
410
- (dft_acc , 0 , dft_acc ),
411
- color = "tab:orange" ,
412
- alpha = 0.2 ,
413
- )
414
- # shrink=0.1 means cut off 10% length from both sides of arrow line
415
- arrowprops = dict (
416
- facecolor = "black" , width = 0.5 , headwidth = 5 , headlength = 5 , shrink = 0.1
417
- )
418
- ax .annotate (
419
- xy = (- dft_acc , dft_acc ),
420
- xytext = (- 2 * dft_acc , dft_acc ),
421
- text = "Corrected\n GGA DFT\n Accuracy" ,
422
- arrowprops = arrowprops ,
423
- verticalalignment = "center" ,
424
- horizontalalignment = "right" ,
425
- )
411
+ if show_dft_acc :
412
+ ax .fill_between (
413
+ (- dft_acc , 0 , dft_acc ),
414
+ (dft_acc , dft_acc , dft_acc ),
415
+ (dft_acc , 0 , dft_acc ),
416
+ color = "tab:orange" ,
417
+ alpha = 0.2 ,
418
+ )
419
+ # shrink=0.1 means cut off 10% length from both sides of arrow line
420
+ arrowprops = dict (
421
+ facecolor = "black" , width = 0.5 , headwidth = 5 , headlength = 5 , shrink = 0.1
422
+ )
423
+ ax .annotate (
424
+ xy = (- dft_acc , dft_acc ),
425
+ xytext = (- 2 * dft_acc , dft_acc ),
426
+ text = "Corrected GGA\n Accuracy" ,
427
+ arrowprops = arrowprops ,
428
+ verticalalignment = "center" ,
429
+ horizontalalignment = "right" ,
430
+ )
426
431
427
432
ax .text (
428
433
0 , 0.13 , r"MAE > $|E_\mathrm{above\ hull}|$" , horizontalalignment = "center"
@@ -457,43 +462,49 @@ def rolling_mae_vs_hull_dist(
457
462
yanchor = "bottom" ,
458
463
title_font = dict (size = 13 ),
459
464
)
460
- ax .update_layout (
461
- dict (
462
- xaxis_title = "E<sub>above MP hull</sub> (eV/atom)" ,
463
- yaxis_title = "rolling MAE (eV/atom)" ,
464
- ),
465
- legend = legend ,
466
- )
465
+ ax .layout .legend .update (legend )
466
+ ax .layout .xaxis .title .text = "E<sub>above MP hull</sub> (eV/atom)"
467
+ ax .layout .yaxis .title .text = "rolling MAE (eV/atom)"
467
468
ax .update_xaxes (range = x_lim )
468
469
ax .update_yaxes (range = y_lim )
469
- scatter_kwds = dict (fill = "toself" , opacity = 0.4 )
470
- ax .add_scatter (
471
- x = (- 1 , - dft_acc , dft_acc , 1 ),
472
- y = (1 , dft_acc , dft_acc , 1 ),
473
- name = "MAE > |E<sub>above hull</sub>|" ,
474
- # fillcolor="yellow",
475
- ** scatter_kwds ,
476
- )
470
+ scatter_kwds = dict (fill = "toself" , opacity = 0.2 )
471
+ peril_cone_anno = "MAE > |E<sub>above hull</sub>|"
477
472
ax .add_scatter (
478
- x = (- dft_acc , dft_acc , 0 , - dft_acc ),
479
- y = (dft_acc , dft_acc , 0 , dft_acc ),
480
- name = "MAE < |DFT error|" ,
481
- # fillcolor="red",
473
+ x = (- 1 , - dft_acc , dft_acc , 1 ) if show_dft_acc else (- 1 , 0 , 1 ),
474
+ y = (1 , dft_acc , dft_acc , 1 ) if show_dft_acc else (1 , 0 , 1 ),
475
+ name = peril_cone_anno ,
476
+ fillcolor = "red" ,
477
+ showlegend = False ,
482
478
** scatter_kwds ,
483
479
)
484
480
ax .add_annotation (
485
- x = - dft_acc ,
486
- y = dft_acc ,
487
- text = f"<a { href = } >Corrected GGA Accuracy<br>for rel. Energy</a> "
488
- "[<a href='#hautier_accuracy_2012' target='_self'>ref</a>]" ,
489
- showarrow = True ,
490
- xshift = - 10 ,
491
- arrowhead = 2 ,
492
- ax = - 4 * dft_acc ,
493
- ay = 2 * dft_acc ,
494
- axref = "x" ,
495
- ayref = "y" ,
481
+ x = 0 ,
482
+ y = 0.8 ,
483
+ text = peril_cone_anno ,
484
+ showarrow = False ,
485
+ yref = "paper" ,
496
486
)
487
+ if show_dft_acc :
488
+ ax .add_scatter (
489
+ x = (- dft_acc , dft_acc , 0 , - dft_acc ),
490
+ y = (dft_acc , dft_acc , 0 , dft_acc ),
491
+ name = "MAE < |Corrected GGA error|" ,
492
+ fillcolor = "red" ,
493
+ ** scatter_kwds ,
494
+ )
495
+ ax .add_annotation (
496
+ x = - dft_acc ,
497
+ y = dft_acc ,
498
+ text = f"<a { href = } >Corrected GGA Accuracy<br>for rel. Energy</a> "
499
+ "[<a href='#hautier_accuracy_2012' target='_self'>ref</a>]" ,
500
+ showarrow = True ,
501
+ xshift = - 10 ,
502
+ arrowhead = 2 ,
503
+ ax = - 4 * dft_acc ,
504
+ ay = 2 * dft_acc ,
505
+ axref = "x" ,
506
+ ayref = "y" ,
507
+ )
497
508
498
509
ax .data = ax .data [::- 1 ] # bring px.line() to front
499
510
# plot rectangle to indicate MAE window size
0 commit comments