
Commit 457abf0

add scatter plot of largest errors averaged over models vs DFT hull distance
- improve rolling-mae-vs-hull-dist-wbm-batches-models caption
- update deps
- move metric table rows showing MEGNet combos with M3GNet and CHGNet to SI
- add prop hide: string[] to metrics table to hide rows with matching headers
1 parent 8a0c3bb commit 457abf0

25 files changed: +336 -281 lines

.pre-commit-config.yaml (+2 -2)
@@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]
 
 repos:
   - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: v0.0.260
+    rev: v0.0.261
     hooks:
       - id: ruff
         args: [--fix]
@@ -57,7 +57,7 @@ repos:
          - prettier
          - prettier-plugin-svelte
          - svelte
-        exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/routes/.+\.(yml|yaml|json))$
+        exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/routes/.+\.(yaml|json))$
 
   - repo: https://github.com/pre-commit/mirrors-eslint
     rev: v8.37.0
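The tightened prettier `exclude` pattern means `.yml` files under `site/src/routes/` are no longer skipped by the formatter. A quick sanity check of the new pattern (all but the first path are hypothetical; pre-commit is assumed to apply exclude regexes to relative paths via `re.search`):

```python
import re

# updated exclude pattern from the prettier hook above
exclude = r"^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/routes/.+\.(yaml|json))$"

paths = [
    "site/src/figs/each-error-vs-least-prevalent-element-in-struct.svelte",  # excluded
    "data/wbm/2022-10-19-wbm-summary.csv",  # hypothetical name, excluded by date prefix
    "site/src/routes/models/index.json",  # hypothetical name, excluded
    "site/src/routes/paper/references.yml",  # hypothetical name, now formatted by prettier
]
for path in paths:
    print(path, "->", "excluded" if re.search(exclude, path) else "formatted by prettier")
```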

data/wbm/eda.py (+1 -1)
@@ -181,7 +181,7 @@
 # weight in each WBM composition. TLDR: no obvious structure in the data
 # was hoping to find certain clusters to have higher or lower errors after seeing
 # many models struggle on the halogens in per-element error periodic table heatmaps
-# https://matbench-discovery.janosh.dev/models
+# https://janosh.github.io/matbench-discovery/models
 df_2d_tsne = pd.read_csv(f"{module_dir}/tsne/one-hot-112-composition-2d.csv.gz")
 df_2d_tsne = df_2d_tsne.set_index("material_id")

matbench_discovery/preds.py (+11 -6)
@@ -9,7 +9,7 @@
 from matbench_discovery import ROOT
 from matbench_discovery.data import Files, glob_to_df
 from matbench_discovery.metrics import stable_metrics
-from matbench_discovery.plots import model_labels
+from matbench_discovery.plots import eVpa, model_labels, quantity_labels
 
 """Centralize data-loading and computing metrics for plotting scripts"""
 
@@ -19,7 +19,12 @@
 e_form_col = "e_form_per_atom_mp2020_corrected"
 each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
 each_pred_col = "e_above_hull_pred"
-model_mean_err_col = "Mean over models"
+model_mean_err_col = "Mean error all models"
+model_std_col = "Std. dev. over models"
+
+
+quantity_labels[model_mean_err_col] = f"{model_mean_err_col} {eVpa}"
+quantity_labels[model_std_col] = f"{model_std_col} {eVpa}"
 
 
 class PredFiles(Files):
@@ -34,8 +39,6 @@ class PredFiles(Files):
     bowsr_megnet = "bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv"
     # default CHGNet model from publication with 400,438 params
     chgnet = "chgnet/2023-03-06-chgnet-wbm-IS2RE.csv"
-    # CHGNet-relaxed structures fed into MEGNet for formation energy prediction
-    # chgnet_megnet = "chgnet/2023-03-04-chgnet-wbm-IS2RE.csv"
 
     # CGCnn 10-member ensemble
     cgcnn = "cgcnn/2023-01-26-test-cgcnn-wbm-IS2RE/cgcnn-ensemble-preds.csv"
@@ -44,11 +47,13 @@ class PredFiles(Files):
 
     # original M3GNet straight from publication, not re-trained
     m3gnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
-    # M3GNet-relaxed structures fed into MEGNet for formation energy prediction
-    # m3gnet_megnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
 
     # original MEGNet straight from publication, not re-trained
     megnet = "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv"
+    # CHGNet-relaxed structures fed into MEGNet for formation energy prediction
+    # chgnet_megnet = "chgnet/2023-03-04-chgnet-wbm-IS2RE.csv"
+    # M3GNet-relaxed structures fed into MEGNet for formation energy prediction
+    # m3gnet_megnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
 
     # Magpie composition+Voronoi tessellation structure features + sklearn random forest
     voronoi_rf = "voronoi/2022-11-27-train-test/e-form-preds-IS2RE.csv"
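The new `quantity_labels` entries presumably get picked up wherever plotting code maps DataFrame columns to axis titles. A hedged sketch of that idea (assuming the mapping is passed to plotly express as `labels=`; `eVpa` here is only a stand-in for the unit string imported from `matbench_discovery.plots`):

```python
import pandas as pd
import plotly.express as px

eVpa = "(eV/atom)"  # stand-in for the eVpa unit suffix imported above
model_mean_err_col = "Mean error all models"
model_std_col = "Std. dev. over models"
quantity_labels = {col: f"{col} {eVpa}" for col in (model_mean_err_col, model_std_col)}

# toy data; axis titles pick up the eV/atom unit without renaming the columns themselves
df = pd.DataFrame({model_mean_err_col: [0.1, 0.4, 0.8], model_std_col: [0.05, 0.2, 0.3]})
fig = px.scatter(df, x=model_mean_err_col, y=model_std_col, labels=quantity_labels)
fig.show()
```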

readme.md (+2 -2)
@@ -17,12 +17,12 @@ Matbench Discovery
 
 Matbench Discovery is an [interactive leaderboard](https://janosh.github.io/matbench-discovery) and associated [PyPI package](https://pypi.org/project/matbench-discovery) which together make it easy to benchmark ML energy models on a task designed to closely simulate a high-throughput discovery campaign for new stable inorganic crystals.
 
-In version 1 of this benchmark, we explore 8 models covering multiple methodologies ranging from random forests to graph neural networks, from one-shot predictors to iterative Bayesian optimizers and interatomic potential-based relaxers. We find [M3GNet](https://github.com/materialsvirtuallab/m3gnet) ([paper](https://doi.org/10.1038/s43588-022-00349-3)) to achieve the highest F1 score of 0.58 and $R^2$ of 0.59 while [MEGNet](https://github.com/materialsvirtuallab/megnet) ([paper](https://doi.org/10.1021/acs.chemmater.9b01294)) wins on discovery acceleration factor (DAF) with 2.94. See the [**full results**](https://matbench-discovery.janosh.dev/paper#results) in our interactive dashboard which provides valuable insights for maintainers of large-scale materials databases. We show these models have become powerful enough to warrant deploying them as triaging steps to more effectively allocate compute in high-throughput DFT relaxations.
+In version 1 of this benchmark, we explore 8 models covering multiple methodologies ranging from random forests to graph neural networks, from one-shot predictors to iterative Bayesian optimizers and interatomic potential-based relaxers. We find [CHGNet](https://github.com/CederGroupHub/chgnet) ([paper](https://doi.org/10.48550/arXiv.2302.14231)) to achieve the highest F1 score of 0.59, $R^2$ of 0.61 and a discovery acceleration factor (DAF) of 3.06 (meaning a 3x higher rate of stable structures compared to dummy selection in our already enriched search space). See the [**full results**](https://janosh.github.io/matbench-discovery/paper#results) in our interactive dashboard which provides valuable insights for maintainers of large-scale materials databases. We show these models have become powerful enough to warrant deploying them as triaging steps to more effectively allocate compute in high-throughput DFT relaxations.
 
 <slot name="metrics-table" />
 
 We welcome contributions that add new models to the leaderboard through [GitHub PRs](https://github.com/janosh/matbench-discovery/pulls). See the [usage and contributing guide](https://janosh.github.io/matbench-discovery/contribute) for details.
 
 For a version 2 release of this benchmark, we plan to merge the current training and test sets into the new training set and acquire a much larger test set (potentially at meta-GGA level of theory) compared to the v1 test set of 257k structures. Anyone interested in joining this effort please [open a GitHub discussion](https://github.com/janosh/matbench-discovery/discussions) or [reach out privately](mailto:[email protected]?subject=Matbench%20Discovery).
 
-For detailed results and analysis, check out the [paper](https://matbench-discovery.janosh.dev/paper) and [supplementary material](https://matbench-discovery.janosh.dev/si).
+For detailed results and analysis, check out the [paper](https://janosh.github.io/matbench-discovery/paper) and [supplementary material](https://janosh.github.io/matbench-discovery/si).
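The DAF of 3.06 quoted in the updated paragraph can be made concrete with a back-of-the-envelope calculation. Assumption (consistent with the "dummy selection" wording above but not spelled out in this diff): DAF is the model's precision for flagging stable crystals divided by the prevalence of stable crystals in the test set. The counts below are made up purely for illustration:

```python
# hypothetical counts, not the benchmark's actual numbers
n_test = 257_000  # size of the WBM test set mentioned in the readme
n_truly_stable = 42_000  # how many of those are actually on/below the convex hull
dummy_hit_rate = n_truly_stable / n_test  # hit rate of random ("dummy") selection

n_flagged_stable = 50_000  # materials a model predicts to be stable
n_true_positives = 25_000  # flagged materials that really are stable
precision = n_true_positives / n_flagged_stable

daf = precision / dummy_hit_rate
print(f"DAF = {daf:.2f}")  # ~3.06 with these made-up numbers, i.e. ~3x dummy selection
```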

scripts/analyze_element_errors.py (+1 -3)
@@ -181,10 +181,8 @@
 # %% plot EACH errors against least prevalent element in structure (by occurrence in
 # MP training set). this seems to correlate more with model error
 n_examp_for_rarest_elem_col = "Examples for rarest element in structure"
-df_wbm["composition"] = df_wbm.get("composition", df_wbm.formula.map(Composition))
-df_elem_err.loc[list(map(str, df_wbm.composition[0]))][train_count_col].min()
 df_wbm[n_examp_for_rarest_elem_col] = [
-    df_elem_err.loc[list(map(str, Composition(formula)))][train_count_col].min()
+    df_elem_err[train_count_col].loc[list(map(str, Composition(formula)))].min()
     for formula in tqdm(df_wbm.formula)
 ]
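For context on the refactored lookup: iterating a pymatgen `Composition` yields its elements, so `list(map(str, Composition(formula)))` is the list of element symbols, and `.loc[...].min()` then picks the training-set count of the rarest one. A toy version with made-up counts:

```python
import pandas as pd
from pymatgen.core import Composition

# hypothetical per-element occurrence counts in the MP training set
train_counts = pd.Series({"Fe": 40_000, "O": 120_000, "Yb": 800}, name="MP Occurrences")

for formula in ("Fe2O3", "YbFeO3"):
    elems = list(map(str, Composition(formula)))  # e.g. ["Fe", "O"]
    print(formula, elems, "-> rarest-element count:", train_counts.loc[elems].min())
```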

scripts/analyze_model_failure_cases.py (+59 -24)
@@ -26,11 +26,14 @@
     df_preds,
     each_true_col,
     model_mean_err_col,
+    model_std_col,
 )
 
 __author__ = "Janosh Riebesell"
 __date__ = "2023-02-15"
 
+models = list(df_each_pred)
+df_preds[model_std_col] = df_preds[models].std(axis=1)
 df_each_err[model_mean_err_col] = df_preds[model_mean_err_col] = df_each_err.abs().mean(
     axis=1
 )
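The two added lines boil down to row-wise reductions over the per-model columns: `std(axis=1)` measures how much the models disagree on each material, while `abs().mean(axis=1)` gives the mean absolute error over models. A toy frame (made-up values) shows the pattern:

```python
import pandas as pd

# toy per-model hull-distance errors in eV/atom, one row per WBM material
df_each_err = pd.DataFrame(
    {"CHGNet": [0.02, -0.15], "M3GNet": [0.05, -0.05], "MEGNet": [0.10, 0.20]},
    index=["wbm-1-1", "wbm-1-2"],
)
print(df_each_err.std(axis=1))  # model disagreement per material
print(df_each_err.abs().mean(axis=1))  # mean absolute error over models per material
```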
@@ -158,29 +161,53 @@
 # save_fig(fig, f"{FIGS}/scatter-largest-each-errors-fp-diff-models.svelte")
 
 
-# %% plotly scatter plot of largest model errors with points sized by mean error and
-# colored by true stability.
-# while some points lie on a horizontal line of constant error, more follow the identity
-# line suggesting the models failed to learn the true physics in these materials
-fig = df_preds.nlargest(200, model_mean_err_col).plot.scatter(
-    x=each_true_col,
-    y=model_mean_err_col,
-    color=each_true_col,
-    size=model_mean_err_col,
-    backend="plotly",
+# %%
+df_mp = pd.read_csv(DATA_FILES.mp_energies, na_filter=False).set_index("material_id")
+train_count_col = "MP Occurrences"
+df_elem_counts = count_elements(df_mp.formula_pretty, count_mode="occurrence").to_frame(
+    name=train_count_col
 )
-fig.layout.coloraxis.colorbar.update(
-    title="DFT distance to convex hull (eV/atom)",
-    title_side="top",
-    yanchor="bottom",
-    y=1,
-    xanchor="center",
-    x=0.5,
-    orientation="h",
-    thickness=12,
+n_examp_for_rarest_elem_col = "Examples for rarest element in structure"
+df_wbm[n_examp_for_rarest_elem_col] = [
+    df_elem_counts[train_count_col].loc[list(map(str, Composition(formula)))].min()
+    for formula in tqdm(df_wbm.formula)
+]
+df_preds[n_examp_for_rarest_elem_col] = df_wbm[n_examp_for_rarest_elem_col]
+
+
+# %% scatter plot of largest model errors vs. DFT hull distance
+# while some points lie on a horizontal line of constant error, more follow the identity
+# line showing models are biased to predict low energies likely as a result of training
+# on MP which is highly low-energy enriched.
+# also possible models failed to learn whatever physics makes these materials highly
+# unstable
+fig = (
+    df_preds.nlargest(200, model_mean_err_col)
+    .round(2)
+    .plot.scatter(
+        x=each_true_col,
+        y=model_mean_err_col,
+        color=model_std_col,
+        size=n_examp_for_rarest_elem_col,
+        backend="plotly",
+        hover_name="material_id",
+        hover_data=["formula"],
+        color_continuous_scale="Turbo",
+    )
 )
+# yanchor="bottom", y=1, xanchor="center", x=0.5, orientation="h", thickness=12
+fig.layout.coloraxis.colorbar.update(title_side="right", thickness=14)
 add_identity_line(fig)
+fig.layout.title = (
+    "Largest model errors vs. DFT hull distance colored by model disagreement"
+)
+# tried setting error_y=model_std_col but looks bad
+# fig.update_traces(error_y=dict(color="rgba(255,255,255,0.2)", width=3, thickness=2))
 fig.show()
+# save_fig(fig, f"{FIGS}/scatter-largest-errors-models-mean-vs-each-true.svelte")
+# save_fig(
+#     fig, f"{ROOT}/tmp/figures/scatter-largest-errors-models-mean-vs-each-true.pdf"
+# )
 
 
 # %% find materials that were misclassified by all models
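Regarding the identity line referenced in the comment: points on y = x are materials whose mean model error equals their DFT hull distance, i.e. the models predicted them to sit right at the hull. `add_identity_line` is a plotting helper already used by the script; a plain-plotly equivalent (not the repo's implementation) would be:

```python
import plotly.express as px

fig = px.scatter(x=[0.1, 0.5, 1.2], y=[0.2, 0.45, 1.3])  # toy data
# y = x reference line spanning the joint data range
lo = min(*fig.data[0].x, *fig.data[0].y)
hi = max(*fig.data[0].x, *fig.data[0].y)
fig.add_shape(type="line", x0=lo, y0=lo, x1=hi, y1=hi, line=dict(dash="dash", width=1))
fig.show()
```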
@@ -203,16 +230,24 @@
 
 
 # %%
+normalized = True
 elem_counts: dict[str, pd.Series] = {}
 for col in ("All models false neg", "All models false pos"):
     elem_counts[col] = elem_counts.get(
         col, count_elements(df_preds[df_preds[col]].formula)
     )
-    fig = ptable_heatmap_plotly(elem_counts[col], font_size=10)
-    fig.layout.title = col
-    fig.layout.margin.update(l=0, r=0, t=50, b=0)
+    fig = ptable_heatmap_plotly(
+        elem_counts[col] / df_elem_counts[train_count_col]
+        if normalized
+        else elem_counts[col],
+        color_bar=dict(title=col),
+        precision=".3f",
+        cscale_range=[0, 0.1],
+    )
     fig.show()
 
+# TODO plot these for each model individually
+
 
 # %% map abs EACH model errors onto elements in structure weighted by composition
# fraction and average over all test set structures
@@ -234,8 +269,8 @@
 # df_frac_comp = df_frac_comp.dropna(axis=1, thresh=100) # remove Xe with only 1 entry
 
 
-# %% TODO investigate if structures with largest mean over models error can be
-# attributed to DFT gone wrong. would be cool if models can be run across large
+# %% TODO investigate if structures with largest mean error across all models error can
+# be attributed to DFT gone wrong. would be cool if models can be run across large
 # databases as correctness checkers
 df_each_err.abs().mean().sort_values()
 df_each_err.abs().mean(axis=1).nlargest(25)

scripts/compile_metrics.py (+10 -6)
@@ -111,13 +111,15 @@
 
 # %%
 time_cols = list(df_stats.filter(like=time_col))
-for col in time_cols:  # uncomment to include run times in metrics table
-    df_metrics.loc[col] = df_stats[col]
+# for col in time_cols:  # uncomment to include run times in metrics table
+#     df_metrics.loc[col] = df_stats[col]
 higher_is_better = {"DAF", "R²", "Precision", "F1", "Accuracy", "TPR", "TNR"}
-lower_is_better = {"MAE", "RMSE", "FNR", "FPR", *time_cols}
+lower_is_better = {"MAE", "RMSE", "FNR", "FPR"}
+df_metrics = df_metrics.rename(index={"R2": "R²"})
 idx_set = set(df_metrics.index)
+
 styler = (
-    df_metrics.T.rename(columns={"R2": "R²"})
+    df_metrics.T
     # append arrow up/down to table headers to indicate higher/lower metric is better
     # .rename(columns=lambda x: x + " ↑" if x in higher_is_better else x + " ↓")
     .style.format(precision=2)
@@ -141,10 +143,12 @@
 styler.hide(["Recall", "FPR", "FNR"], axis=1)
 
 
-# %% export model metrics as styled HTML table
+# %% export model metrics as styled HTML table and Svelte component
+styler.to_html(f"{ROOT}/tmp/figures/model-metrics.html")
+
 # insert svelte {...props} forwarding to the table element
 insert = """
-<script>
+<script lang="ts">
   import { sortable } from 'svelte-zoo/actions'
 </script>
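The new export writes the styled table to HTML before splicing a Svelte `<script lang="ts">` block into the output. A minimal sketch of that flow with a toy metrics frame (pandas >= 1.3 for `Styler.to_html`; file name and values are illustrative, not the repo's exact pipeline):

```python
import pandas as pd

df_metrics = pd.DataFrame(
    {"F1": [0.59, 0.58], "DAF": [3.06, 2.94]}, index=["CHGNet", "M3GNet"]  # illustrative
)
styler = df_metrics.T.style.format(precision=2)  # metrics as rows, models as columns
html_table = styler.to_html()

svelte_script = """<script lang="ts">
  import { sortable } from 'svelte-zoo/actions'
</script>

"""
with open("metrics-table.svelte", "w") as file:
    file.write(svelte_script + html_table)
```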

scripts/cumulative_clf_metrics.py (+1 -1)
@@ -35,7 +35,7 @@
 # fig.suptitle(title)
 fig.text(0.5, -0.08, x_label, ha="center", fontdict={"size": 16})
 if backend == "plotly":
-    fig.layout.legend.update(x=0.01, y=0.01, bgcolor="rgba(0,0,0,0)")
+    fig.layout.legend.update(x=0, y=0, bgcolor="rgba(0,0,0,0)")
     fig.layout.margin.update(l=0, r=5, t=30, b=50)
     fig.add_annotation(
         x=0.5,

scripts/rolling_mae_vs_hull_dist_wbm_batches.py (+3 -4)
@@ -6,7 +6,7 @@
 # %%
 from pymatviz.utils import save_fig
 
-from matbench_discovery import FIGS, today
+from matbench_discovery import FIGS, ROOT, today
 from matbench_discovery.plots import plt, rolling_mae_vs_hull_dist
 from matbench_discovery.preds import df_each_pred, df_preds, e_form_col, each_true_col
 
@@ -16,10 +16,10 @@
 batch_col = "batch_idx"
 df_each_pred[batch_col] = "Batch " + df_each_pred.index.str.split("-").str[1]
 df_err, df_std = None, None  # variables to cache rolling MAE and std
+model = "MEGNet"
 
 
 # %% matplotlib
-model = "Wrenformer"
 fig, ax = plt.subplots(1, figsize=(10, 9))
 markers = ("o", "v", "^", "H", "D")
 assert len(markers) == 5  # number of iterations of element substitution in WBM data set
@@ -54,7 +54,6 @@
 
 
 # %% plotly
-model = "CHGNet"
 df_pivot = df_each_pred.pivot(columns=batch_col, values=model)
 
 # unstack two-level column index into new model column
@@ -81,4 +80,4 @@
 file_model = model.lower().replace(" + ", "-").replace(" ", "-")
 img_path = f"{file_model}-rolling-mae-vs-hull-dist-wbm-batches"
 save_fig(fig, f"{FIGS}/{img_path}.svelte")
-# save_fig(f"{ROOT}/tmp/figures/{img_path}.pdf")
+save_fig(fig, f"{ROOT}/tmp/figures/{img_path}.pdf")
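On the `df_each_pred.index.str.split("-").str[1]` line kept as context above: the batch column is derived from the material ID, assuming WBM IDs follow a `wbm-<batch>-<n>` pattern (five batches of element substitution). A toy index shows the split:

```python
import pandas as pd

# assumed WBM ID format "wbm-<batch>-<n>"; toy index for illustration
material_ids = pd.Index(["wbm-1-7", "wbm-3-12345", "wbm-5-999"], name="material_id")
print("Batch " + material_ids.str.split("-").str[1])  # Batch 1, Batch 3, Batch 5
```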

site/package.json (+8 -8)
@@ -19,10 +19,10 @@
     "@iconify/svelte": "^3.1.0",
     "@rollup/plugin-yaml": "^4.0.1",
     "@sveltejs/adapter-static": "^2.0.1",
-    "@sveltejs/kit": "^1.14.0",
-    "@sveltejs/vite-plugin-svelte": "^2.0.3",
-    "@typescript-eslint/eslint-plugin": "^5.57.0",
-    "@typescript-eslint/parser": "^5.57.0",
+    "@sveltejs/kit": "^1.15.1",
+    "@sveltejs/vite-plugin-svelte": "^2.0.4",
+    "@typescript-eslint/eslint-plugin": "^5.57.1",
+    "@typescript-eslint/parser": "^5.57.1",
     "elementari": "^0.1.5",
     "eslint": "^8.37.0",
     "eslint-plugin-svelte3": "^4.0.0",
@@ -36,15 +36,15 @@
     "rehype-katex-svelte": "^1.1.2",
     "rehype-slug": "^5.1.0",
     "remark-math": "3.0.0",
-    "svelte": "^3.57.0",
-    "svelte-check": "^3.1.4",
+    "svelte": "^3.58.0",
+    "svelte-check": "^3.2.0",
     "svelte-multiselect": "^8.6.0",
     "svelte-preprocess": "^5.0.3",
     "svelte-toc": "^0.5.4",
     "svelte-zoo": "^0.4.3",
-    "svelte2tsx": "^0.6.10",
+    "svelte2tsx": "^0.6.11",
     "tslib": "^2.5.0",
-    "typescript": "5.0.2",
+    "typescript": "5.0.3",
     "vite": "^4.2.1"
   },
   "prettier": {

site/src/app.html (+1 -1)
@@ -31,7 +31,7 @@
     ></script>
     <link rel="stylesheet" href="/prism-vsc-dark-plus.css" />
     <!-- interactive plots -->
-    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
+    <script src="https://cdn.plot.ly/plotly-2.20.0.min.js"></script>
     <!-- math display -->
     <link
       rel="stylesheet"

site/src/figs/each-error-vs-least-prevalent-element-in-struct.svelte (+1 -1)

(generated file, diff not rendered)
