Skip to content

Commit e7d57ee

Browse files
committed
add color bar labels to MP/WBM/MPtrj ptable element occurrence heatmaps
tweak plot scripts update site deps, esp. elementari to fix black text on ptable element tiles missing data
1 parent 12e8477 commit e7d57ee

36 files changed

+139
-86
lines changed

.pre-commit-config.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]
77

88
repos:
99
- repo: https://github.com/astral-sh/ruff-pre-commit
10-
rev: v0.1.9
10+
rev: v0.1.13
1111
hooks:
1212
- id: ruff
1313
args: [--fix]
@@ -56,7 +56,7 @@ repos:
5656
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/routes/.+\.(yaml|json)|changelog.md)$
5757

5858
- repo: https://github.com/pre-commit/mirrors-eslint
59-
rev: v8.56.0
59+
rev: v9.0.0-alpha.0
6060
hooks:
6161
- id: eslint
6262
types: [file]

data/wbm/eda_wbm.py

+36-11
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,21 @@
1414
spacegroup_sunburst,
1515
)
1616
from pymatviz.io import save_fig
17-
from pymatviz.utils import si_fmt
17+
from pymatviz.utils import si_fmt, si_fmt_int
1818

1919
from matbench_discovery import (
2020
PDF_FIGS,
2121
ROOT,
2222
SITE_FIGS,
2323
STABILITY_THRESHOLD,
24+
e_form_raw_col,
2425
formula_col,
2526
id_col,
2627
)
2728
from matbench_discovery import plots as plots
2829
from matbench_discovery.data import DATA_FILES, df_wbm
2930
from matbench_discovery.energy import mp_elem_reference_entries
30-
from matbench_discovery.preds import df_each_err, each_true_col
31+
from matbench_discovery.preds import df_each_err, e_form_col, each_true_col
3132

3233
__author__ = "Janosh Riebesell"
3334
__date__ = "2023-03-30"
@@ -141,8 +142,8 @@
141142

142143
# %% histogram of energy distance to MP convex hull for WBM
143144
e_col = each_true_col # or e_form_col
144-
e_col = "e_form_per_atom_uncorrected"
145-
e_col = "e_form_per_atom_mp2020_corrected"
145+
# e_col = e_form_raw_col
146+
# e_col = e_form_col
146147
mean, std = df_wbm[e_col].mean(), df_wbm[e_col].std()
147148

148149
range_x = (mean - 2 * std, mean + 2 * std)
@@ -170,9 +171,28 @@
170171
dummy_mae = (df_wbm[e_col] - df_wbm[e_col].mean()).abs().mean()
171172

172173
title = (
173-
f"{len(df_wbm.dropna()):,} structures with {n_stable:,} stable + {n_unstable:,}"
174+
f"{si_fmt_int(len(df_wbm.dropna()))} structures with {si_fmt_int(n_stable)} "
175+
f"stable + {si_fmt_int(n_unstable)} unstable (stable rate="
176+
f"{n_stable / len(df_wbm):.1%})"
174177
)
175-
fig.layout.title = dict(text=title, x=0.5)
178+
fig.layout.title = dict(text=title, x=0.5, font_size=16, y=0.95)
179+
180+
# add blue/red annotations to the left and right of the mean labeling the stable/unstable sides
181+
for idx, (label, x_pos) in enumerate(
182+
(("stable", mean - std), ("unstable", mean + std))
183+
):
184+
fig.add_annotation(
185+
x=x_pos,
186+
y=0.5,
187+
text=label,
188+
showarrow=False,
189+
font_size=18,
190+
font_color=px.colors.qualitative.Plotly[idx],
191+
yref="paper",
192+
xanchor="right",
193+
xshift=-40,
194+
)
195+
176196

177197
fig.layout.margin = dict(l=0, r=0, b=0, t=40)
178198
fig.update_layout(showlegend=False)
@@ -183,14 +203,19 @@
183203
(mean + std, f"{mean + std = :.2f}"),
184204
):
185205
anno = dict(text=label, yshift=-10, xshift=-5, xanchor="right")
186-
line_width = 1 if x_pos == mean else 0.5
206+
line_width = 3 if x_pos == mean else 2
187207
fig.add_vline(x=x_pos, line=dict(width=line_width, dash="dash"), annotation=anno)
188208

189209
fig.show()
190-
191-
save_fig(fig, f"{SITE_FIGS}/hist-wbm-hull-dist.svelte")
192-
# save_fig(fig, "./figs/hist-wbm-hull-dist.svg", width=1000, height=500)
193-
save_fig(fig, f"{PDF_FIGS}/hist-wbm-hull-dist.pdf")
210+
suffix = {
211+
each_true_col: "hull-dist",
212+
e_form_col: "e-form",
213+
e_form_raw_col: "e-form-uncorrected",
214+
}[e_col]
215+
img_name = f"hist-wbm-{suffix}"
216+
save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
217+
# save_fig(fig, f"./figs/{img_name}.svg", width=800, height=500)
218+
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=600, height=300)
194219

195220

196221
# %%

data/wbm/fetch_process_wbm_dataset.py

+18-13
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,14 @@
1919
from pymatviz.io import save_fig
2020
from tqdm import tqdm
2121

22-
from matbench_discovery import PDF_FIGS, SITE_FIGS, formula_col, id_col, today
22+
from matbench_discovery import (
23+
PDF_FIGS,
24+
SITE_FIGS,
25+
e_form_raw_col,
26+
formula_col,
27+
id_col,
28+
today,
29+
)
2330
from matbench_discovery.data import DATA_FILES
2431
from matbench_discovery.energy import get_e_form_per_atom
2532

@@ -39,7 +46,7 @@
3946

4047

4148
module_dir = os.path.dirname(__file__)
42-
e_form_col = "e_form_per_atom_wbm"
49+
e_form_wbm_col = "e_form_per_atom_wbm"
4350

4451

4552
# %% links to google drive files received via email from 1st author Hai-Chen Wang
@@ -296,7 +303,7 @@ def increment_wbm_material_id(wbm_id: str) -> str:
296303
"nsites": "n_sites",
297304
"vol": "volume",
298305
"e": "uncorrected_energy",
299-
"e_form": e_form_col,
306+
"e_form": e_form_wbm_col,
300307
"e_hull": "e_above_hull_wbm",
301308
"gap": "bandgap_pbe",
302309
"id": id_col,
@@ -440,15 +447,15 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
440447

441448
# %% remove suspicious formation energy outliers
442449
e_form_cutoff = 5
443-
n_too_stable = sum(df_summary[e_form_col] < -e_form_cutoff)
450+
n_too_stable = sum(df_summary[e_form_wbm_col] < -e_form_cutoff)
444451
print(f"{n_too_stable = }") # n_too_stable = 502
445-
n_too_unstable = sum(df_summary[e_form_col] > e_form_cutoff)
452+
n_too_unstable = sum(df_summary[e_form_wbm_col] > e_form_cutoff)
446453
print(f"{n_too_unstable = }") # n_too_unstable = 22
447454

448455
e_form_hist, e_form_bins = np.histogram(
449-
df_summary[e_form_col], bins=300, range=(-5.5, 5.5)
456+
df_summary[e_form_wbm_col], bins=300, range=(-5.5, 5.5)
450457
)
451-
x_label = {e_form_col: "WBM uncorrected formation energy (eV/atom)"}[e_form_col]
458+
x_label = {e_form_wbm_col: "WBM uncorrected formation energy (eV/atom)"}[e_form_wbm_col]
452459
fig = px.bar(
453460
x=e_form_bins[:-1], # [:-1] to drop last bin edge which is not needed
454461
y=e_form_hist,
@@ -485,7 +492,7 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
485492
# %%
486493
assert len(df_summary) == len(df_wbm) == 257_487
487494

488-
query_str = f"{-e_form_cutoff} < {e_form_col} < {e_form_cutoff}"
495+
query_str = f"{-e_form_cutoff} < {e_form_wbm_col} < {e_form_cutoff}"
489496
dropped_ids = sorted(set(df_summary.index) - set(df_summary.query(query_str).index))
490497
assert len(dropped_ids) == 502 + 22
491498
assert dropped_ids[:3] == "wbm-1-12142 wbm-1-12143 wbm-1-12144".split()
@@ -569,8 +576,6 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
569576
# first make sure source and target dfs have matching indices
570577
assert sum(df_wbm.index != df_summary.index) == 0
571578

572-
e_form_col = "e_form_per_atom_uncorrected"
573-
574579
for row in tqdm(df_wbm.itertuples(), total=len(df_wbm)):
575580
mat_id, cse, formula = row.Index, row.cse, row.formula_from_cse
576581
assert mat_id == cse.entry_id, f"{mat_id=} != {cse.entry_id=}"
@@ -585,11 +590,11 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
585590
assert (
586591
abs(e_form - e_form_ppd) < 1e-4
587592
), f"{mat_id}: {e_form=:.3} != {e_form_ppd=:.3} (diff={e_form - e_form_ppd:.3}))"
588-
df_summary.loc[cse.entry_id, e_form_col] = e_form
593+
df_summary.loc[cse.entry_id, e_form_raw_col] = e_form
589594

590595

591-
df_summary[e_form_col.replace("uncorrected", "mp2020_corrected")] = (
592-
df_summary[e_form_col] + df_summary["e_correction_per_atom_mp2020"]
596+
df_summary[e_form_raw_col.replace("uncorrected", "mp2020_corrected")] = (
597+
df_summary[e_form_raw_col] + df_summary["e_correction_per_atom_mp2020"]
593598
)
594599

595600

data/wbm/figs/hist-wbm-hull-dist.svg

+1-1
Loading

data/wbm/readme.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -90,18 +90,18 @@ The number of stable materials (according to the MP convex hull which is spanned
9090

9191
The WBM test set and even more so the MP training set are heavily oxide dominated. The WBM test set is about 75% larger than the MP training set and also more chemically diverse, containing a higher fraction of transition metals, post-transition metals and metalloids. Our goal in picking such a large diverse test set is future-proofing. Ideally, this data will provide a challenging materials discovery test bed even for large foundational ML models in the future.
9292

93-
Below: Element counts for WBM test set consisting of 256,963 WBM `ComputedStructureEntries`
94-
9593
<slot name="wbm-elements-heatmap">
9694
<img src="./figs/wbm-elements.svg" alt="Periodic table log heatmap of WBM elements">
9795
</slot>
9896

99-
Below: Element counts for MP training set consisting of 154,719 `ComputedStructureEntries`
97+
The WBM test set consists of 256,963 `ComputedStructureEntries`.
10098

10199
<slot name="mp-elements-heatmap">
102100
<img src="./figs/mp-elements.svg" alt="Periodic table log heatmap of MP elements">
103101
</slot>
104102

103+
The MP training set consists of 154,719 `ComputedStructureEntries`.
104+
105105
<slot name="mp-trj-elements-heatmap" />
106106

107107
## 📊 &thinsp; Symmetry Statistics

matbench_discovery/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
init_struct_col = "initial_structure"
4545
struct_col = "structure"
4646
e_form_col = "formation_energy_per_atom"
47+
e_form_raw_col = "e_form_per_atom_uncorrected"
4748
formula_col = "formula"
4849
stress_col = "stress"
4950
stress_trace_col = "stress_trace"

matbench_discovery/plots.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,8 @@ def rolling_mae_vs_hull_dist(
483483
y=(1, dft_acc, dft_acc, 1) if show_dft_acc else (1, 0, 1),
484484
name=triangle_anno,
485485
fillcolor="red",
486+
# remove triangle border
487+
line=dict(color="rgba(0,0,0,0)"),
486488
**scatter_kwds,
487489
)
488490
fig.add_annotation(
@@ -535,14 +537,10 @@ def rolling_mae_vs_hull_dist(
535537

536538
from matbench_discovery.preds import model_styles
537539

538-
for idx, trace in enumerate(fig.data):
540+
for trace in fig.data:
539541
if style := model_styles.get(trace.name):
540542
ls, _marker, color = style
541543
trace.line = dict(color=color, dash=ls, width=2)
542-
else:
543-
trace.line = dict(
544-
color=plotly_colors[idx], dash=plotly_line_styles[idx], width=3
545-
)
546544
# marker_spacing = 2
547545
# fig.add_scatter(
548546
# x=trace.x[::marker_spacing],

matbench_discovery/preds.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030
each_pred_col = "e_above_hull_pred"
3131
model_mean_each_col = "Mean prediction all models"
3232
model_mean_err_col = "Mean error all models"
33-
model_std_col = "Std. dev. over models"
33+
model_std_each_col = "Std. dev. over models"
3434

35-
for col in (model_mean_each_col, model_mean_err_col, model_std_col):
35+
for col in (model_mean_each_col, model_mean_err_col, model_std_each_col):
3636
quantity_labels[col] = f"{col} {ev_per_atom}"
3737

3838

@@ -211,7 +211,7 @@ def load_df_wbm_with_preds(
211211
)
212212

213213
# important: do df_each_pred.std(axis=1) before inserting model_mean_each_col into df
214-
df_preds[model_std_col] = df_each_pred.std(axis=1)
214+
df_preds[model_std_each_col] = df_each_pred.std(axis=1)
215215
df_each_pred[model_mean_each_col] = df_preds[model_mean_each_col] = df_each_pred.mean(
216216
axis=1
217217
)

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
66
name = "matbench-discovery"
77
version = "1.0.0"
88
description = "A benchmark for machine learning energy models on inorganic crystal stability prediction from unrelaxed structures"
9-
authors = [{ name = "Janosh Riebesell", email = "janosh@lbl.gov" }]
9+
authors = [{ name = "Janosh Riebesell", email = "janosh[email protected]" }]
1010
readme = "readme.md"
1111
license = { file = "license" }
1212
keywords = [
@@ -60,7 +60,7 @@ running-models = [
6060

6161
"aviary@git+https://github.com/CompRhys/aviary",
6262
"m3gnet",
63-
"mace@git+https://github.com/ACEsuit/mace",
63+
"mace-torch",
6464
"maml",
6565
"megnet",
6666
]

readme.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
<h4 align="center" class="toc-exclude">
77

8-
[![arXiv](https://img.shields.io/badge/arXiv-2308.14920-blue)](https://arxiv.org/abs/2308.14920)
8+
[![arXiv](https://img.shields.io/badge/arXiv-2308.14920-blue?logo=arxiv&logoColor=white)](https://arxiv.org/abs/2308.14920)
99
[![Tests](https://github.com/janosh/matbench-discovery/actions/workflows/test.yml/badge.svg)](https://github.com/janosh/matbench-discovery/actions/workflows/test.yml)
1010
[![GitHub Pages](https://github.com/janosh/matbench-discovery/actions/workflows/gh-pages.yml/badge.svg)](https://github.com/janosh/matbench-discovery/actions/workflows/gh-pages.yml)
1111
[![Requires Python 3.9+](https://img.shields.io/badge/Python-3.9+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads)

scripts/model_figs/analyze_model_disagreement.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
each_true_col,
1919
model_mean_each_col,
2020
model_mean_err_col,
21-
model_std_col,
21+
model_std_each_col,
2222
)
2323

2424
__author__ = "Janosh Riebesell"
@@ -56,8 +56,7 @@
5656
fig = df_plot.plot.scatter(
5757
x=each_true_col,
5858
y=model_mean_each_col,
59-
color=model_std_col,
60-
size="n_sites",
59+
color=model_std_each_col,
6160
backend="plotly",
6261
hover_name=id_col,
6362
hover_data=[formula_col],
@@ -71,11 +70,14 @@
7170
fig.layout.coloraxis.colorbar.update(title_side="right", thickness=14)
7271
fig.layout.margin.update(l=0, r=30, b=0, t=60)
7372
add_identity_line(fig)
73+
label = {"all": "structures"}.get(material_cls, material_cls)
7474
fig.layout.title.update(
75-
text=f"{n_structs} largest {material_cls} model errors: Predicted vs.<br>"
76-
"DFT hull distance colored by model disagreement",
75+
text=f"{n_structs} {material_cls} with largest hull distance errors<br>"
76+
"colored by model disagreement, sized by number of sites",
7777
x=0.5,
7878
)
79+
# size markers by the square root of each structure's number of sites
80+
fig.data[0].marker.size = df_plot["n_sites"] ** 0.5 * 3
7981
# tried setting error_y=model_std_each_col but looks bad
8082
# fig.update_traces(
8183
# error_y=dict(color="rgba(255,255,255,0.2)", width=3, thickness=2)

scripts/model_figs/parity_energy_models.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
df_melt,
6161
bin_by_cols=[e_true_col, e_pred_col],
6262
group_by_cols=[facet_col],
63-
n_bins=200,
63+
n_bins=300,
6464
bin_counts_col=(bin_cnt_col := "bin counts"),
6565
)
6666
df_bin = df_bin.reset_index()
@@ -162,7 +162,8 @@
162162
# pick from https://plotly.com/python/builtin-colorscales
163163
color_continuous_scale="agsunset",
164164
)
165-
165+
# decrease marker size
166+
fig.update_traces(marker=dict(size=2))
166167
# manually set colorbar ticks and labels (needed after log1p transform)
167168
tick_vals = [1, 10, 100, 1000, 10_000]
168169
fig.layout.coloraxis.colorbar.update(
@@ -181,9 +182,8 @@
181182
assert model in df_preds, f"Unexpected {model=} not in {list(df_preds)=}"
182183
# add MAE and R2 to subplot titles
183184
MAE, R2 = df_metrics[model][["MAE", "R2"]]
184-
fig.layout.annotations[
185-
idx - 1
186-
].text = f"{model} · {MAE=:.2f} · R<sup>2</sup>={R2:.2f}"
185+
sub_title = f"{model} · {MAE=:.2f} · R<sup>2</sup>={R2:.2f}"
186+
fig.layout.annotations[idx - 1].text = sub_title
187187

188188
# remove subplot x and y axis titles
189189
fig.layout[f"xaxis{idx}"].title.text = ""
@@ -222,7 +222,7 @@
222222
yshift=-15 * sign_y,
223223
text=label,
224224
showarrow=False,
225-
font=dict(size=16, color=color),
225+
font=dict(size=14, color=color),
226226
row="all",
227227
col="all",
228228
)
@@ -245,9 +245,10 @@
245245
# fig.update_layout(yaxis=dict(scaleanchor="x", scaleratio=1))
246246

247247
axis_titles = dict(xref="paper", yref="paper", showarrow=False)
248+
portrait = n_rows > n_cols
248249
fig.add_annotation( # x-axis title
249250
x=0.5,
250-
y=-0.06,
251+
y=-0.06 if portrait else -0.18,
251252
text=x_title,
252253
**axis_titles,
253254
)
@@ -259,10 +260,10 @@
259260
**axis_titles,
260261
)
261262

262-
fig.layout.height = 230 * n_rows
263+
fig.layout.update(height=230 * n_rows, width=180 * n_cols)
263264
fig.layout.coloraxis.colorbar.update(orientation="h", thickness=9, len=0.5, y=1.05)
264265
# fig.layout.width = 1100
265-
fig.layout.margin.update(l=40, r=10, t=30, b=60)
266+
fig.layout.margin.update(l=40, r=10, t=30 if portrait else 10, b=60 if portrait else 10)
266267
fig.update_xaxes(matches=None)
267268
fig.update_yaxes(matches=None)
268269
fig.show()

0 commit comments

Comments
 (0)