|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 |
| -from typing import Literal, Sequence |
| 3 | +from typing import Any, Literal, Sequence |
4 | 4 |
|
5 | 5 | import matplotlib.pyplot as plt
|
| 6 | +import numpy as np |
6 | 7 | import pandas as pd
|
| 8 | +from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar |
| 9 | +from scipy.stats import sem as std_err_of_mean |
7 | 10 |
|
8 | 11 |
|
9 | 12 | __author__ = "Janosh Riebesell"
|
|
12 | 15 |
|
13 | 16 | plt.rc("savefig", bbox="tight", dpi=200)
|
14 | 17 | plt.rcParams["figure.constrained_layout.use"] = True
|
15 |
| -plt.rc("figure", dpi=150) |
| 18 | +plt.rc("figure", dpi=200) |
16 | 19 | plt.rc("font", size=16)
|
17 | 20 |
|
18 | 21 |
|
@@ -137,3 +140,112 @@ def hist_classified_stable_as_func_of_hull_dist(
|
137 | 140 | ax.set(xlabel=xlabel, ylabel="Number of compounds")
|
138 | 141 |
|
139 | 142 | return ax
|
| 143 | + |
| 144 | + |
| 145 | +def rolling_mae_vs_hull_dist( |
| 146 | + df: pd.DataFrame, |
| 147 | + e_above_hull_col: str, |
| 148 | + residual_col: str = "residual", |
| 149 | + half_window: float = 0.02, |
| 150 | + increment: float = 0.002, |
| 151 | + x_lim: tuple[float, float] = (-0.2, 0.3), |
| 152 | + ax: plt.Axes = None, |
| 153 | + **kwargs: Any, |
| 154 | +) -> plt.Axes: |
| 155 | + """Rolling mean absolute error as the energy to the convex hull is varied. A scale |
| 156 | + bar is shown for the windowing period of 40 meV per atom used when calculating |
| 157 | + the rolling MAE. The standard error in the mean is shaded |
| 158 | + around each curve. The highlighted V-shaped region shows the area in which the |
| 159 | + average absolute error is greater than the energy to the known convex hull. This is |
| 160 | + where models are most at risk of misclassifying structures. |
| 161 | + """ |
| 162 | + if ax is None: |
| 163 | + ax = plt.gca() |
| 164 | + |
| 165 | + ax_is_fresh = len(ax.lines) == 0 |
| 166 | + |
| 167 | + bins = np.arange(*x_lim, increment) |
| 168 | + |
| 169 | + rolling_maes = np.zeros_like(bins) |
| 170 | + rolling_stds = np.zeros_like(bins) |
| 171 | + df = df.sort_values(by=e_above_hull_col) |
| 172 | + for idx, bin_center in enumerate(bins): |
| 173 | + low = bin_center - half_window |
| 174 | + high = bin_center + half_window |
| 175 | + |
| 176 | + mask = (df[e_above_hull_col] <= high) & (df[e_above_hull_col] > low) |
| 177 | + rolling_maes[idx] = df[residual_col].loc[mask].abs().mean() |
| 178 | + rolling_stds[idx] = std_err_of_mean(df[residual_col].loc[mask].abs()) |
| 179 | + |
| 180 | + ax.plot(bins, rolling_maes, **kwargs) |
| 181 | + |
| 182 | + ax.fill_between( |
| 183 | + bins, rolling_maes + rolling_stds, rolling_maes - rolling_stds, alpha=0.3 |
| 184 | + ) |
| 185 | + |
| 186 | + if not ax_is_fresh: |
| 187 | + # return earlier if all plot objects besides the line were already drawn by a |
| 188 | + # previous call |
| 189 | + return ax |
| 190 | + |
| 191 | + scale_bar = AnchoredSizeBar( |
| 192 | + ax.transData, |
| 193 | + 2 * half_window, |
| 194 | + "40 meV", |
| 195 | + "lower left", |
| 196 | + pad=0.5, |
| 197 | + frameon=False, |
| 198 | + size_vertical=0.002, |
| 199 | + ) |
| 200 | + |
| 201 | + ax.add_artist(scale_bar) |
| 202 | + |
| 203 | + ax.plot((0.05, 0.5), (0.05, 0.5), color="grey", linestyle="--", alpha=0.3) |
| 204 | + ax.plot((-0.5, -0.05), (0.5, 0.05), color="grey", linestyle="--", alpha=0.3) |
| 205 | + ax.plot((-0.05, 0.05), (0.05, 0.05), color="grey", linestyle="--", alpha=0.3) |
| 206 | + ax.plot((-0.1, 0.1), (0.1, 0.1), color="grey", linestyle="--", alpha=0.3) |
| 207 | + |
| 208 | + ax.fill_between( |
| 209 | + (-0.5, -0.05, 0.05, 0.5), |
| 210 | + (0.5, 0.5, 0.5, 0.5), |
| 211 | + (0.5, 0.05, 0.05, 0.5), |
| 212 | + color="tab:red", |
| 213 | + alpha=0.2, |
| 214 | + ) |
| 215 | + |
| 216 | + ax.plot((0, 0.05), (0, 0.05), color="grey", linestyle="--", alpha=0.3) |
| 217 | + ax.plot((-0.05, 0), (0.05, 0), color="grey", linestyle="--", alpha=0.3) |
| 218 | + |
| 219 | + ax.fill_between( |
| 220 | + (-0.05, 0, 0.05), |
| 221 | + (0.05, 0.05, 0.05), |
| 222 | + (0.05, 0, 0.05), |
| 223 | + color="tab:orange", |
| 224 | + alpha=0.2, |
| 225 | + ) |
| 226 | + |
| 227 | + arrowprops = dict(facecolor="black", width=0.5, headwidth=5, headlength=5) |
| 228 | + ax.annotate( |
| 229 | + xy=(0.055, 0.05), |
| 230 | + xytext=(0.12, 0.05), |
| 231 | + arrowprops=arrowprops, |
| 232 | + text="Corrected\nGGA DFT\nAccuracy", |
| 233 | + verticalalignment="center", |
| 234 | + horizontalalignment="left", |
| 235 | + ) |
| 236 | + ax.annotate( |
| 237 | + xy=(0.105, 0.1), |
| 238 | + xytext=(0.16, 0.1), |
| 239 | + arrowprops=arrowprops, |
| 240 | + text="GGA DFT\nAccuracy", |
| 241 | + verticalalignment="center", |
| 242 | + horizontalalignment="left", |
| 243 | + ) |
| 244 | + |
| 245 | + ax.text(0, 0.13, r"$|\Delta E_{Hull-MP}| > $MAE", horizontalalignment="center") |
| 246 | + |
| 247 | + ax.set(xlabel=r"$\Delta E_{Hull-MP}$ / eV per atom", ylabel="MAE / eV per atom") |
| 248 | + |
| 249 | + ax.set(xlim=x_lim, ylim=(0.0, 0.14)) |
| 250 | + |
| 251 | + return ax |
0 commit comments