Commit a5b3211 (parent 6d30216)

change SI plot of largest model errors: Predicted vs. DFT hull distance colored by model disagreement

change y-axis from average model error to all-model mean

9 files changed (+69 -48 lines)

matbench_discovery/preds.py (+13 -3)

```diff
@@ -19,12 +19,12 @@
 e_form_col = "e_form_per_atom_mp2020_corrected"
 each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
 each_pred_col = "e_above_hull_pred"
+model_mean_each_col = "Mean prediction all models"
 model_mean_err_col = "Mean error all models"
 model_std_col = "Std. dev. over models"

-
-quantity_labels[model_mean_err_col] = f"{model_mean_err_col} {ev_per_atom}"
-quantity_labels[model_std_col] = f"{model_std_col} {ev_per_atom}"
+for col in (model_mean_each_col, model_mean_err_col, model_std_col):
+    quantity_labels[col] = f"{col} {ev_per_atom}"


 class PredFiles(Files):
@@ -157,8 +157,18 @@ def load_df_wbm_with_preds(
         df_preds[each_true_col] + df_preds[model] - df_preds[e_form_col]
     )

+# important: do df_each_pred.std(axis=1) before inserting
+# df_each_pred[model_mean_each_col]
+df_preds[model_std_col] = df_each_pred.std(axis=1)
+df_each_pred[model_mean_each_col] = df_preds[model_mean_each_col] = df_each_pred.mean(
+    axis=1
+)

 # dataframe of all models' errors in their EACH predictions (eV/atom)
 df_each_err = pd.DataFrame()
 for model in df_metrics.T.MAE.sort_values().index:
     df_each_err[model] = df_preds[model] - df_preds[e_form_col]
+
+df_each_err[model_mean_err_col] = df_preds[model_mean_err_col] = df_each_err.abs().mean(
+    axis=1
+)
```
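The inline comment about ordering in this hunk is load-bearing: pandas row-wise reductions include every column present, so inserting the all-model mean before calling `.std(axis=1)` would let the mean column shrink the reported disagreement. A minimal sketch of the pitfall with toy data (column names hypothetical):

```python
import pandas as pd

# toy per-model hull-distance predictions: two models, two structures
df_each_pred = pd.DataFrame({"model A": [0.1, 0.5], "model B": [0.3, 0.9]})

std_before = df_each_pred.std(axis=1)  # disagreement over models only

# now insert the all-model mean as its own column, as preds.py does
df_each_pred["Mean prediction all models"] = df_each_pred.mean(axis=1)

std_after = df_each_pred.std(axis=1)  # wrongly includes the mean column

# appending a value equal to the row mean leaves the sum of squared deviations
# unchanged but increases n, so the sample std always shrinks
assert (std_after < std_before).all()
```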

pyproject.toml (+10 -9)

```diff
@@ -10,23 +10,23 @@ authors = [{ name = "Janosh Riebesell", email = "[email protected]" }]
 readme = "readme.md"
 license = { file = "license" }
 keywords = [
-    "materials discovery",
-    "inorganic crystal stability",
-    "machine learning",
-    "interatomic potential",
     "Bayesian optimization",
-    "high-throughput search",
     "convex hull",
+    "high-throughput search",
+    "inorganic crystal stability",
+    "interatomic potential",
+    "machine learning",
+    "materials discovery",
 ]
 classifiers = [
     "Intended Audience :: Science/Research",
     "License :: OSI Approved :: MIT License",
     "Operating System :: OS Independent",
-    "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.9",
-    "Topic :: Scientific/Engineering :: Chemistry",
     "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    "Topic :: Scientific/Engineering :: Chemistry",
     "Topic :: Scientific/Engineering :: Physics",
 ]

@@ -58,6 +58,7 @@ running-models = [
     "chgnet",
     # torch needs to install before aviary
     "torch",
+
     "aviary@git+https://github.com/CompRhys/aviary",
     "m3gnet",
     "maml",
@@ -82,7 +83,7 @@ select = [
     "B", # flake8-bugbear
     "C40", # flake8-comprehensions
     "D", # pydocstyle
-    "E", # pycodestyle
+    "E", # pycodestyle error
     "F", # pyflakes
     "I", # isort
     "N", # pep8-naming
@@ -97,7 +98,7 @@ select = [
     "SIM", # flake8-simplify
     "TID", # tidy imports
     "UP", # pyupgrade
-    "W", # pycodestyle
+    "W", # pycodestyle warning
     "YTT", # flake8-2020
 ]
 ignore = [
```

scripts/analyze_model_failure_cases.py (+13 -13)

```diff
@@ -25,6 +25,7 @@
     df_metrics,
     df_preds,
     each_true_col,
+    model_mean_each_col,
     model_mean_err_col,
     model_std_col,
 )
@@ -33,10 +34,6 @@
 __date__ = "2023-02-15"

 models = list(df_each_pred)
-df_preds[model_std_col] = df_preds[models].std(axis=1)
-df_each_err[model_mean_err_col] = df_preds[model_mean_err_col] = df_each_err.abs().mean(
-    axis=1
-)
 fp_diff_col = "site_stats_fingerprint_init_final_norm_diff"


@@ -181,14 +178,15 @@
 # on MP which is highly low-energy enriched.
 # also possible models failed to learn whatever physics makes these materials highly
 # unstable
+n_structs = 200
 fig = (
-    df_preds.nlargest(200, model_mean_err_col)
+    df_preds.nlargest(n_structs, model_mean_err_col)
     .round(2)
     .plot.scatter(
         x=each_true_col,
-        y=model_mean_err_col,
+        y=model_mean_each_col,
         color=model_std_col,
-        size=n_examp_for_rarest_elem_col,
+        size="n_sites",
         backend="plotly",
         hover_name="material_id",
         hover_data=["formula"],
@@ -197,17 +195,19 @@
 )
 # yanchor="bottom", y=1, xanchor="center", x=0.5, orientation="h", thickness=12
 fig.layout.coloraxis.colorbar.update(title_side="right", thickness=14)
+fig.layout.margin.update(l=0, r=30, b=0, t=30)
 add_identity_line(fig)
-fig.layout.title = (
-    "Largest model errors vs. DFT hull distance colored by model disagreement"
+fig.layout.title.update(
+    text=f"{n_structs} largest model errors: Predicted vs. DFT hull distance<br>"
+    "colored by model disagreement",
+    x=0.5,
 )
 # tried setting error_y=model_std_col but looks bad
 # fig.update_traces(error_y=dict(color="rgba(255,255,255,0.2)", width=3, thickness=2))
 fig.show()
-# save_fig(fig, f"{FIGS}/scatter-largest-errors-models-mean-vs-each-true.svelte")
-# save_fig(
-#     fig, f"{ROOT}/tmp/figs/scatter-largest-errors-models-mean-vs-each-true.pdf"
-# )
+img_name = "scatter-largest-errors-models-mean-vs-true-hull-dist"
+save_fig(fig, f"{FIGS}/{img_name}.svelte")
+# save_fig(fig, f"{ROOT}/tmp/figs/{img_name}.pdf")


 # %% find materials that were misclassified by all models
```
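The plotting pattern in this hunk (select the n rows with the largest mean error, then scatter predicted vs. DFT values, colored by inter-model std and sized by structure size) reduces to the following self-contained sketch with random toy data and hypothetical column names:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(seed=0)
n = 1_000
df = pd.DataFrame(
    {
        "DFT hull dist": rng.normal(0, 0.2, n),
        "Mean prediction all models": rng.normal(0, 0.2, n),
        "Std. dev. over models": rng.uniform(0, 0.3, n),
        "n_sites": rng.integers(1, 50, n),
    }
)
df["Mean error all models"] = (
    df["Mean prediction all models"] - df["DFT hull dist"]
).abs()

# keep only the 200 largest-error rows, then scatter predicted vs. DFT value
fig = df.nlargest(200, "Mean error all models").plot.scatter(
    x="DFT hull dist",
    y="Mean prediction all models",
    color="Std. dev. over models",
    size="n_sites",
    backend="plotly",  # pandas delegates to plotly express
)
fig.show()
```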

scripts/compile_metrics.py (+2 -2)

```diff
@@ -223,8 +223,8 @@
         "margin-bottom": "0",
         "margin-left": "0",
         # fit page size to content
-        "page-width": f"{(len(styler.columns) + 1) * 10}",
-        "page-height": f"{(len(styler.index) + 1) * 6}",
+        "page-width": f"{(len(styler.columns) + 1) * 8.3}",
+        "page-height": f"{(len(styler.index) + 1) * 5.5}",
     },
 )
```
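The sizing tweak scales the export page to the table's shape: per-column width drops from 10 to 8.3 and per-row height from 6 to 5.5 (units are not stated in the diff; they read like millimeters for wkhtmltopdf-style page-width/page-height options, but that is an assumption). A toy calculation:

```python
# hypothetical table shape: 10 metric columns, 20 model rows
n_cols, n_rows = 10, 20

page_width = (n_cols + 1) * 8.3  # was (n_cols + 1) * 10 before this commit
page_height = (n_rows + 1) * 5.5  # was (n_rows + 1) * 6 before this commit

print(f"{page_width=:.1f} {page_height=:.1f}")  # page_width=91.3 page_height=115.5
```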

scripts/hist_classified_stable_vs_hull_dist.py (+12 -11)

```diff
@@ -11,7 +11,7 @@

 from pymatviz.utils import save_fig

-from matbench_discovery import FIGS
+from matbench_discovery import ROOT
 from matbench_discovery.data import df_wbm
 from matbench_discovery.plots import hist_classified_stable_vs_hull_dist
 from matbench_discovery.preds import df_each_pred, each_true_col
@@ -21,11 +21,10 @@


 # %%
-model_name = "Wrenformer"
-model_name = "CHGNet"
-# model_name = "M3GNet"
-# model_name = "Voronoi RF"
-which_energy: Final = "true"
+# model_name = "Wrenformer"
+model_name = "CGCNN"
+# model_name = "CGCNN+P"
+which_energy: Final = "pred"
 df_each_pred[each_true_col] = df_wbm[each_true_col]
 backend: Final = "plotly"

@@ -35,16 +34,18 @@
     each_pred_col=model_name,
     which_energy=which_energy,
     # stability_threshold=-0.05,
-    # rolling_acc=None,
+    rolling_acc=None,
     backend=backend,
 )

 if backend == "plotly":
-    fig.layout.title = model_name
+    # fig.layout.title.update(text=model_name, x=0.5)
+    fig.layout.margin.update(l=0, r=0, b=0, t=30)
+    # fig.update_yaxes(range=[0, 12000])
 fig.show()


 # %%
-img_path = f"{FIGS}/hist-clf-{which_energy}-hull-dist-{model_name}"
-# save_fig(fig, f"{img_path}.svelte")
-save_fig(fig, f"{img_path}.webp")
+img_name = f"hist-clf-{which_energy}-hull-dist-{model_name}"
+# save_fig(fig, f"{FIGS}/{img_name}.svelte")
+save_fig(fig, f"{ROOT}/tmp/figs/{img_name}.pdf")
```
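The save-path switch works because `pymatviz.utils.save_fig` picks the output format from the file extension, so the scripts flip between interactive `.svelte` output for the site and static `.pdf` output for the preprint just by changing the suffix. A usage sketch (hypothetical paths; static export assumed to require a kaleido install, as is standard for plotly):

```python
import plotly.express as px
from pymatviz.utils import save_fig

fig = px.scatter(x=[1, 2, 3], y=[2, 4, 8])

save_fig(fig, "tmp/figs/example.pdf")  # static export for the preprint (needs kaleido)
save_fig(fig, "site/src/figs/example.svelte")  # interactive, embeddable in the site
```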

site/src/figs/scatter-largest-errors-models-mean-vs-true-hull-dist.svelte (+1)

(generated file, diff not shown)

site/src/routes/preprint/+page.md (+1 -1)

```diff
@@ -5,7 +5,7 @@
   import CumulativeClfMetrics from '$figs/cumulative-clf-metrics.svelte'
   import RollingMaeVsHullDistModels from '$figs/rolling-mae-vs-hull-dist-models.svelte'
   import ElementErrorsPtableHeatmap from '$models/element-errors-ptable-heatmap.svelte'
-  import HistClfTrueHullDistModels from '$figs/hist-clf-true-hull-dist-models.svelte'
+  import HistClfTrueHullDistModels from '$figs/hist-clf-true-hull-dist-models-4x2.svelte'
   import { onMount } from 'svelte'

   let mounted = false
```

site/src/routes/si/+page.md (+7 -7)

```diff
@@ -1,5 +1,5 @@
 <script lang="ts">
-  import MetricsTableMegnetCombos from '$figs/metrics-table-megnet-combos.svelte'
+  import MetricsTableMegnetUipCombos from '$figs/metrics-table-megnet-uip-combos.svelte'
   import MetricsTableFirst10k from '$figs/metrics-table-first-10k.svelte'
   import RunTimeBars from '$figs/model-run-times-bar.svelte'
   import RocModels from '$figs/roc-models.svelte'
@@ -14,7 +14,7 @@
   import HistClfPredHullDistModels from '$figs/hist-clf-pred-hull-dist-models-4x2.svelte'
   import SpacegroupSunburstWbm from '$figs/spacegroup-sunburst-wbm.svelte'
   import SpacegroupSunburstWrenformerFailures from '$figs/spacegroup-sunburst-wrenformer-failures.svelte'
-  import ScatterLargestErrorsModelsMeanVsEachTrue from '$figs/scatter-largest-errors-models-mean-vs-each-true.svelte'
+  import ScatterLargestErrorsModelsMeanVsTrueHullDist from '$figs/scatter-largest-errors-models-mean-vs-true-hull-dist.svelte'
   import EAboveHullScatterWrenformerFailures from '$figs/e-above-hull-scatter-wrenformer-failures.svelte'
   import ProtoCountsWrenformerFailures from '$figs/proto-counts-wrenformer-failures.svelte'
   import ElementPrevalenceVsError from '$figs/element-prevalence-vs-error.svelte'
@@ -99,19 +99,19 @@ Given its strong performance on batch 1, it is possible that given sufficiently
 ## Largest Errors vs DFT Hull Distance

 {#if mounted}
-<ScatterLargestErrorsModelsMeanVsEachTrue />
+<ScatterLargestErrorsModelsMeanVsTrueHullDist />
 {/if}

-> @label:fig:scatter-largest-errors-models-mean-vs-each-true The 200 structures with largest error averaged over all models vs their DFT hull distance colored by model disagreement (as measured by standard deviation in hull distance predictions from different models) and sized by number of training structures containing the least prevalent element (e.g. if a scatter point had composition FeO, MP has 6.6k structures containing Fe and 82k containing O so its size would be set to 6.6k). Thus smaller points have less training support. This plot suggests all models are biased to predict low energy and perhaps fail to capture certain physics resulting in highly unstable structures. This is unsurprising considering MP training data mainly consists of low energy structures.<br>
-> It is also possible that some of the blue points with large error yet good agreement among models are in fact accurate ML predictions for a DFT relaxation gone wrong.
+> @label:fig:scatter-largest-errors-models-mean-vs-true-hull-dist DFT vs predicted hull distance (averaged over all models) for the 200 structures with the largest errors, colored by model disagreement (as measured by the standard deviation in hull distance predictions across models) and sized by number of atoms in the structure. This plot shows that high-error predictions are biased towards predicting hull distances that are too small, which is unsurprising considering MP training data mainly consists of low-energy structures.<br>
+> However, note the clear color separation between the mostly blue low-energy-biased predictions and the yellow/red high-error predictions. Blue means the models are in good agreement, i.e. all models are "wrong" together. Red/yellow marks large-error predictions with little model agreement, i.e. models that are wrong in different ways. It is possible that some of the blue points with large error yet good inter-model agreement are in fact accurate ML predictions for a DFT relaxation gone wrong. Zooming in on the blue points reveals that many of them are large markers, corresponding to large structures where DFT failures are less surprising. This suggests ML model committees could be used to cheaply screen large databases for DFT errors in a high-throughput manner.

 ## MEGNet formation energies from UIP-relaxed structures

 {#if mounted}
-<MetricsTableMegnetCombos select={[`model`, `MEGNet`, `CHGNet`, `M3GNet`, `CHGNet + MEGNet`, `M3GNet + MEGNet`]} />
+<MetricsTableMegnetUipCombos select={[`model`, `MEGNet`, `CHGNet`, `M3GNet`, `CHGNet + MEGNet`, `M3GNet + MEGNet`]} />
 {/if}

-> @label:fig:metrics-table-megnet-combos This table shows metrics obtained by combining MEGNet with both UIPs. The metrics in rows labeled M3GNet + MEGNet and CHGNet + MEGNet are the result of passing M3GNet/CHGNet-relaxed structures into MEGNet for formation energy prediction. Both combos perform worse than using the respective UIPs on their own with a more pronounced performance drop from CHGNet to CHGNet + MEGNet than M3GNet to M3GNet + MEGnet. This suggests MEGNet has learned no additional knowledge of the PES that is not already present in the UIPs. However, both combos perform better than MEGNet on its own, demonstrating that UIP relaxation provides real utility at very low cost for any downstream structure-dependent analysis.
+> @label:fig:metrics-table-megnet-uip-combos This table shows metrics obtained by combining MEGNet with both UIPs. The metrics in rows labeled M3GNet + MEGNet and CHGNet + MEGNet are the result of passing M3GNet/CHGNet-relaxed structures into MEGNet for formation energy prediction. Both combos perform worse than using the respective UIPs on their own, with a more pronounced performance drop from CHGNet to CHGNet + MEGNet than from M3GNet to M3GNet + MEGNet. This suggests MEGNet has learned no additional knowledge of the PES that is not already present in the UIPs. However, both combos perform better than MEGNet on its own, demonstrating that UIP relaxation provides real utility at very low cost for any downstream structure-dependent analysis.

 The UIPs M3GNet and CHGNet are both trained to predict DFT energies (including/excluding MP2020 energy corrections for CHGNet/M3GNet) while MEGNet is trained to predict formation energies.
```

tests/test_preds.py (+10 -2)

```diff
@@ -11,6 +11,8 @@
     e_form_col,
     each_true_col,
     load_df_wbm_with_preds,
+    model_mean_each_col,
+    model_mean_err_col,
 )


@@ -29,13 +31,19 @@ def test_df_metrics() -> None:

 def test_df_each_pred() -> None:
     assert len(df_each_pred) == len(df_wbm)
-    assert {*df_each_pred} == {*df_metrics}, "df_each_pred has wrong columns"
+    assert {*df_each_pred} == {
+        *df_metrics,
+        model_mean_each_col,
+    }, "df_each_pred has wrong columns"
     assert all(df_each_pred.isna().mean() < 0.05), "too many NaNs in df_each_pred"


 def test_df_each_err() -> None:
     assert len(df_each_err) == len(df_wbm)
-    assert {*df_each_err} == {*df_metrics}, "df_each_err has wrong columns"
+    assert {*df_each_err} == {
+        *df_metrics,
+        model_mean_err_col,
+    }, "df_each_err has wrong columns"
     assert all(df_each_err.isna().mean() < 0.05), "too many NaNs in df_each_err"
```
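The `{*df}` idiom in these asserts works because iterating a DataFrame yields its column labels, so `{*df_each_pred}` is the set of column names and the comparison is order-independent. A toy demonstration (hypothetical frames):

```python
import pandas as pd

df_each_pred = pd.DataFrame(
    {"model A": [0.1], "model B": [0.2], "Mean prediction all models": [0.15]}
)
df_metrics = pd.DataFrame({"model A": [0.05], "model B": [0.07]})

# iterating a DataFrame yields its column labels, so {*df} == set(df.columns)
assert {*df_each_pred} == {*df_metrics, "Mean prediction all models"}
```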
