@@ -254,14 +254,14 @@ def hist_classified_stable_vs_hull_dist(
254
254
255
255
def rolling_mae_vs_hull_dist (
256
256
e_above_hull_true : pd .Series ,
257
- e_above_hull_error : pd .Series ,
257
+ e_above_hull_errors : pd .DataFrame | dict [ str , pd . Series ] ,
258
258
window : float = 0.02 ,
259
259
bin_width : float = 0.001 ,
260
260
x_lim : tuple [float , float ] = (- 0.2 , 0.2 ),
261
- y_lim : tuple [float , float ] = (0 , 0.15 ),
262
- ax : plt .Axes = None ,
261
+ y_lim : tuple [float , float ] = (0 , 0.2 ),
263
262
backend : Backend = "plotly" ,
264
263
y_label : str = "rolling MAE (eV/atom)" ,
264
+ just_plot_lines : bool = False ,
265
265
** kwargs : Any ,
266
266
) -> plt .Axes | go .Figure :
267
267
"""Rolling mean absolute error as the energy to the convex hull is varied. A scale
@@ -274,61 +274,75 @@ def rolling_mae_vs_hull_dist(
274
274
Args:
275
275
e_above_hull_true (pd.Series): Distance to convex hull according to DFT
276
276
ground truth (in eV / atom).
277
- e_above_hull_error (pd.Series): Error in model-predicted distance to convex
278
- hull, i.e. actual hull distance minus predicted hull distance (in eV / atom).
279
- window (float, optional): Rolling MAE averaging window. Defaults to 0.02 (20 meV/atom)
280
- bin_width (float, optional): Density of line points (more points the smaller).
277
+ e_above_hull_errors (pd.DataFrame | dict[str, pd.Series]): Error in
278
+ model-predicted distance to convex hull, i.e. actual hull distance minus
279
+ predicted hull distance (in eV / atom).
280
+ window (float, optional): Rolling MAE averaging window. Defaults to 0.02 (20
281
+ meV/atom) bin_width (float, optional): Density of line points (more points the
282
+ smaller).
281
283
Defaults to 0.002.
282
284
x_lim (tuple[float, float], optional): x-axis range. Defaults to (-0.2, 0.3).
283
285
y_lim (tuple[float, float], optional): y-axis range. Defaults to (0.0, 0.14).
284
- ax (plt.Axes, optional): matplotlib Axes object. Defaults to None.
285
286
backend ('matplotlib' | 'plotly'], optional): Which plotting engine to use.
286
287
Changes the return type. Defaults to 'plotly'.
287
288
y_label (str, optional): y-axis label. Defaults to "rolling MAE (eV/atom)".
289
+ just_plot_line (bool, optional): If True, plot only the rolling MAE, no shapes
290
+ and annotations. Also won't plot the standard error in the mean. Defaults
291
+ to False.
288
292
289
293
Returns:
290
- plt.Axes | go.Figure: matplotlib Axes or plotly Figure depending on backend.
294
+ tuple[plt.Axes | go.Figure, pd.DataFrame, pd.DataFrame]: matplotlib Axes or
295
+ plotly
296
+ Figure depending on backend, followed by two dataframes containing the
297
+ rolling error for each column in e_above_hull_errors and the rolling
298
+ standard error in the mean.
291
299
"""
292
300
bins = np .arange (* x_lim , bin_width )
301
+ models = list (e_above_hull_errors )
302
+
303
+ df_rolling_err = pd .DataFrame (columns = models , index = bins )
304
+ df_err_std = df_rolling_err .copy ()
305
+
306
+ for model in models :
307
+ for idx , bin_center in enumerate (bins ):
308
+ low = bin_center - window
309
+ high = bin_center + window
293
310
294
- rolling_maes = np .zeros_like (bins )
295
- rolling_stds = np .zeros_like (bins )
311
+ mask = (e_above_hull_true <= high ) & (e_above_hull_true > low )
296
312
297
- for idx , bin_center in enumerate (bins ):
298
- low = bin_center - window
299
- high = bin_center + window
313
+ each_mae = e_above_hull_errors [model ].loc [mask ].abs ().mean ()
314
+ df_rolling_err [model ].iloc [idx ] = each_mae
300
315
301
- mask = (e_above_hull_true <= high ) & (e_above_hull_true > low )
302
- rolling_maes [idx ] = e_above_hull_error .loc [mask ].abs ().mean ()
303
- rolling_stds [idx ] = scipy .stats .sem (e_above_hull_error .loc [mask ].abs ())
316
+ # drop NaNs to avoid error, scipy doesn't ignore NaNs
317
+ each_std = scipy .stats .sem (
318
+ e_above_hull_errors [model ].loc [mask ].dropna ().abs ()
319
+ )
320
+ df_err_std [model ].iloc [idx ] = each_std
321
+
322
+ # increase line width
323
+ ax = df_rolling_err .plot (backend = backend , ** kwargs )
324
+
325
+ if just_plot_lines :
326
+ # return earlier if all plot objects besides the line were already drawn by a
327
+ # previous call
328
+ return ax , df_rolling_err , df_err_std
304
329
305
330
# DFT accuracy at 25 meV/atom for e_above_hull calculations of chemically similar
306
331
# systems which is lower than formation energy error due to systematic error
307
332
# cancellation among similar chemistries, supporting ref:
308
- # https://journals.aps. org/prb/abstract/ 10.1103/PhysRevB.85.155208
333
+ href = " https://doi. org/10.1103/PhysRevB.85.155208"
309
334
dft_acc = 0.025
310
- # used by plotly branch of this function, unrecognized by matplotlib
311
- fig = kwargs .pop ("fig" , None )
312
335
313
336
if backend == "matplotlib" :
314
- ax = ax or plt .gca ()
315
- is_fresh_ax = len (ax .lines ) == 0
316
- kwargs = dict (linewidth = 3 ) | kwargs
317
- ax .plot (bins , rolling_maes , ** kwargs )
318
-
319
- ax .fill_between (
320
- bins , rolling_maes + rolling_stds , rolling_maes - rolling_stds , alpha = 0.3
321
- )
322
- # alternative implementation using pandas.rolling(). drawback: window size can only
323
- # be set as number of observations, not fixed-size energy above hull interval.
324
- # e_above_hull_error.index = e_above_hull_true # warning: in-place change
325
- # e_above_hull_error.sort_index().abs().rolling(window=8000).mean().plot(
326
- # ax=ax, **kwargs
327
- # )
328
- if not is_fresh_ax :
329
- # return earlier if all plot objects besides the line were already drawn by a
330
- # previous call
331
- return ax
337
+ # assert df_rolling_err.isna().sum().sum() == 0, "NaNs in df_rolling_err"
338
+ # assert df_err_std.isna().sum().sum() == 0, "NaNs in df_err_std"
339
+ # for model in df_rolling_err:
340
+ # ax.fill_between(
341
+ # bins,
342
+ # df_rolling_err[model] + df_err_std[model],
343
+ # df_rolling_err[model] - df_err_std[model],
344
+ # alpha=0.3,
345
+ # )
332
346
333
347
scale_bar = AnchoredSizeBar (
334
348
ax .transData ,
@@ -376,34 +390,22 @@ def rolling_mae_vs_hull_dist(
376
390
ax .set (xlabel = r"$E_\mathrm{above\ hull}$ (eV/atom)" , ylabel = y_label )
377
391
ax .set (xlim = x_lim , ylim = y_lim )
378
392
elif backend == "plotly" :
379
- title = kwargs .pop ("label" , None )
380
- ax = px .line (
381
- x = bins ,
382
- y = rolling_maes ,
383
- # error_y=rolling_stds,
384
- markers = False ,
385
- title = title ,
386
- ** kwargs ,
387
- )
388
- line_color = ax .data [0 ].line .color
389
- ax_std = go .Scatter (
390
- x = list (bins ) + list (bins )[::- 1 ], # bins, then bins reversed
391
- y = list (rolling_maes + 2 * rolling_stds )
392
- + list (rolling_maes - 2 * rolling_stds )[::- 1 ], # upper, then lower reversed
393
- fill = "toself" ,
394
- line_color = "white" ,
395
- fillcolor = line_color ,
396
- opacity = 0.3 ,
397
- hoverinfo = "skip" ,
398
- showlegend = False ,
399
- )
400
- ax .add_trace (ax_std )
401
-
402
- if isinstance (fig , go .Figure ):
403
- # if passed existing plotly figure, add traces to it
404
- # return without changing layout and adding annotations
405
- fig .add_traces (ax .data )
406
- return fig
393
+ for idx , model in enumerate (df_rolling_err ):
394
+ ax .data [idx ].legendgroup = model
395
+ ax .add_scatter (
396
+ x = list (bins ) + list (bins )[::- 1 ], # bins, then bins reversed
397
+ y = list (df_rolling_err [model ] + 3 * df_err_std [model ])
398
+ + list (df_rolling_err [model ] - 3 * df_err_std [model ])[
399
+ ::- 1
400
+ ], # upper, then lower reversed
401
+ mode = "lines" ,
402
+ line = dict (color = "white" , width = 0 ),
403
+ fill = "toself" ,
404
+ legendgroup = model ,
405
+ fillcolor = ax .data [0 ].line .color ,
406
+ opacity = 0.3 ,
407
+ showlegend = False ,
408
+ )
407
409
408
410
legend = dict (title = None , xanchor = "right" , x = 1 , yanchor = "bottom" , y = 0 )
409
411
ax .update_layout (
@@ -415,32 +417,30 @@ def rolling_mae_vs_hull_dist(
415
417
)
416
418
ax .update_xaxes (range = x_lim )
417
419
ax .update_yaxes (range = y_lim )
418
- scatter_kwds = dict (fill = "toself" , opacity = 0.5 )
419
- err_gt_each_region = go . Scatter (
420
+ scatter_kwds = dict (fill = "toself" , opacity = 0.3 )
421
+ ax . add_scatter (
420
422
x = (- 1 , - dft_acc , dft_acc , 1 ),
421
423
y = (1 , dft_acc , dft_acc , 1 ),
422
424
name = "MAE > |E<sub>above hull</sub>|" ,
423
425
# fillcolor="yellow",
424
426
** scatter_kwds ,
425
427
)
426
- ml_err_lt_dft_err_region = go . Scatter (
428
+ ax . add_scatter (
427
429
x = (- dft_acc , dft_acc , 0 , - dft_acc ),
428
430
y = (dft_acc , dft_acc , 0 , dft_acc ),
429
431
name = "MAE < |DFT error|" ,
430
432
# fillcolor="red",
431
433
** scatter_kwds ,
432
434
)
433
- ax .add_traces ([err_gt_each_region , ml_err_lt_dft_err_region ])
434
435
ax .add_annotation (
435
- x = dft_acc ,
436
+ x = - dft_acc ,
436
437
y = dft_acc ,
437
- text = "<a href='https://doi.org/10.1103/PhysRevB.85.155208'>Corrected GGA DFT "
438
- "Accuracy</a>" ,
438
+ text = f"<a { href = } >Corrected GGA Accuracy</a>" ,
439
439
showarrow = True ,
440
- xshift = 10 ,
441
- arrowhead = 1 ,
442
- ax = 4 * dft_acc ,
443
- ay = dft_acc ,
440
+ xshift = - 10 ,
441
+ arrowhead = 2 ,
442
+ ax = - 4 * dft_acc ,
443
+ ay = 2 * dft_acc ,
444
444
axref = "x" ,
445
445
ayref = "y" ,
446
446
)
@@ -464,10 +464,9 @@ def rolling_mae_vs_hull_dist(
464
464
y0 = y0 ,
465
465
x1 = x0 + window ,
466
466
y1 = y0 + window / 5 ,
467
- fillcolor = line_color ,
468
467
)
469
468
470
- return ax
469
+ return ax , df_rolling_err , df_err_std
471
470
472
471
473
472
def cumulative_precision_recall (
0 commit comments