Skip to content

Commit a6bfa74

Browse files
committed
plot easy vs hard structures (for all models) norm of SiteStats fingerprint difference before/after relaxation
1 parent b8a18d8 commit a6bfa74

File tree

10 files changed

+204
-59
lines changed

10 files changed

+204
-59
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]
77

88
repos:
99
- repo: https://github.com/charliermarsh/ruff-pre-commit
10-
rev: v0.0.252
10+
rev: v0.0.255
1111
hooks:
1212
- id: ruff
1313
args: [--fix]

matbench_discovery/data.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class DataFiles(Files):
6363
"wbm/2022-10-19-wbm-computed-structure-entries.json.bz2"
6464
)
6565
wbm_initial_structures = "wbm/2022-10-19-wbm-init-structs.json.bz2"
66-
wbm_computed_structure_entries_plus_init_structs = (
66+
wbm_cses_plus_init_structs = (
6767
"wbm/2022-10-19-wbm-computed-structure-entries+init-structs.json.bz2"
6868
)
6969
wbm_summary = "wbm/2022-10-19-wbm-summary.csv"

matbench_discovery/plots.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -754,8 +754,8 @@ def cumulative_precision_recall(
754754
align="left",
755755
)
756756
fig.layout.legend.title = ""
757-
fig.update_xaxes(showticklabels=True, title="")
758-
fig.update_yaxes(showticklabels=True, title="")
757+
fig.update_xaxes(showticklabels=True, title="", matches=None)
758+
fig.update_yaxes(showticklabels=True, title="", matches=None)
759759

760760
return fig, df_cum
761761

models/m3gnet/pre_vs_post_m3gnet_relaxation.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@
2323

2424

2525
# %%
26-
df_wbm = pd.read_json(
27-
DATA_FILES.wbm_computed_structure_entries_plus_init_structs
28-
).set_index("material_id")
26+
df_wbm = pd.read_json(DATA_FILES.wbm_cses_plus_init_structs).set_index("material_id")
2927

3028
df_summary = pd.read_csv(DATA_FILES.wbm_summary).set_index("material_id")
3129

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ running-models = ["aviary", "m3gnet", "maml", "megnet"]
5454
3d-structures = ["crystaltoolkit"]
5555

5656
[tool.setuptools.packages]
57-
find = { include = ["matbench_discovery"] }
57+
find = { include = ["matbench_discovery*"], exclude = ["tests*"] }
5858

5959
[tool.setuptools.package-data]
6060
matbench_discovery = ["data/mp/*.json"]
@@ -66,7 +66,7 @@ universal = true
6666
target-version = "py39"
6767
select = [
6868
"B", # flake8-bugbear
69-
"C4", # flake8-comprehensions
69+
"C40", # flake8-comprehensions
7070
"D", # pydocstyle
7171
"E", # pycodestyle
7272
"F", # pyflakes

scripts/difficult_structures.py

+180-24
Original file line numberDiff line numberDiff line change
@@ -5,56 +5,128 @@
55

66

77
# %%
8+
import itertools
9+
810
import matplotlib.pyplot as plt
11+
import numpy as np
912
import pandas as pd
10-
from pymatgen.core import Structure
11-
from pymatviz import plot_structure_2d, ptable_heatmap_plotly
13+
from matminer.featurizers.site import CrystalNNFingerprint
14+
from matminer.featurizers.structure import SiteStatsFingerprint
15+
from pymatgen.core import Composition, Element, Structure
16+
from pymatviz import count_elements, plot_structure_2d, ptable_heatmap_plotly
17+
from tqdm import tqdm
1218

13-
from matbench_discovery import ROOT
19+
from matbench_discovery import MODELS, ROOT
1420
from matbench_discovery.data import DATA_FILES
21+
from matbench_discovery.data import df_wbm as df_summary
1522
from matbench_discovery.metrics import classify_stable
16-
from matbench_discovery.preds import df_each_err, df_each_pred, df_preds, each_true_col
23+
from matbench_discovery.preds import (
24+
df_each_err,
25+
df_each_pred,
26+
df_metrics,
27+
df_preds,
28+
each_true_col,
29+
)
1730

1831
__author__ = "Janosh Riebesell"
1932
__date__ = "2023-02-15"
2033

2134
df_each_err[each_true_col] = df_preds[each_true_col]
22-
mean_ae_col = "All models mean absolute error (eV/atom)"
35+
mean_ae_col = "All models MAE (eV/atom)"
2336
df_each_err[mean_ae_col] = df_preds[mean_ae_col] = df_each_err.abs().mean(axis=1)
2437

2538

2639
# %%
27-
df_cse = pd.read_json(DATA_FILES.wbm_computed_structure_entries).set_index(
28-
"material_id"
29-
)
40+
df_wbm = pd.read_json(DATA_FILES.wbm_cses_plus_init_structs).set_index("material_id")
3041

3142

3243
# %%
3344
n_rows, n_cols = 5, 4
34-
for which in ("best", "worst"):
45+
for good_bad, init_final in itertools.product(("best", "worst"), ("initial", "final")):
3546
fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows))
36-
n_axs = len(axs.flat)
47+
n_structs = len(axs.flat)
48+
struct_col = {
49+
"initial": "initial_structure",
50+
"final": "computed_structure_entry",
51+
}[init_final]
3752

38-
errs = (
39-
df_each_err.mean_ae.nsmallest(n_axs)
40-
if which == "best"
41-
else df_each_err.mean_ae.nlargest(n_axs)
53+
errs = {
54+
"best": df_each_err[mean_ae_col].nsmallest(n_structs),
55+
"worst": df_each_err[mean_ae_col].nlargest(n_structs),
56+
}[good_bad]
57+
title = (
58+
f"{good_bad.title()} {len(errs)} {init_final} structures (across "
59+
f"{len(list(df_each_pred))} models)\nErrors in (ev/atom)"
4260
)
43-
title = f"{which} {len(errs)} structures (across {len(list(df_each_pred))} models)"
44-
fig.suptitle(title, fontsize=16, fontweight="bold", y=0.95)
61+
fig.suptitle(title, fontsize=20, fontweight="bold", y=1.05)
4562

46-
for idx, (ax, (id, err)) in enumerate(zip(axs.flat, errs.items()), 1):
47-
struct = Structure.from_dict(
48-
df_cse.computed_structure_entry.loc[id]["structure"]
49-
)
63+
for idx, (ax, (id, error)) in enumerate(zip(axs.flat, errs.items()), 1):
64+
struct = df_wbm[struct_col].loc[id]
65+
if init_final == "relaxed":
66+
struct = struct["structure"]
67+
struct = Structure.from_dict(struct)
5068
plot_structure_2d(struct, ax=ax)
5169
_, spg_num = struct.get_space_group_info()
5270
formula = struct.composition.reduced_formula
5371
ax.set_title(
54-
f"{idx}. {formula} (spg={spg_num})\n{id} {err=:.2f}", fontweight="bold"
72+
f"{idx}. {formula} (spg={spg_num})\n{id} {error=:.2f}", fontweight="bold"
5573
)
74+
out_path = f"{ROOT}/tmp/figures/{good_bad}-{len(errs)}-structures-{init_final}.webp"
75+
fig.savefig(out_path, dpi=300)
76+
77+
78+
# %%
79+
n_structs = 100
80+
worst_ids = df_each_err[mean_ae_col].nlargest(n_structs).index.tolist()
81+
best_ids = df_each_err[mean_ae_col].nsmallest(n_structs).index.tolist()
82+
83+
best_init_structs = df_wbm.initial_structure.loc[best_ids].map(Structure.from_dict)
84+
worst_init_structs = df_wbm.initial_structure.loc[worst_ids].map(Structure.from_dict)
85+
best_final_structs = df_wbm.computed_structure_entry.loc[best_ids].map(
86+
lambda cse: Structure.from_dict(cse["structure"])
87+
)
88+
worst_final_structs = df_wbm.computed_structure_entry.loc[worst_ids].map(
89+
lambda cse: Structure.from_dict(cse["structure"])
90+
)
91+
92+
93+
# %%
94+
cnn_fp = CrystalNNFingerprint.from_preset("ops")
95+
site_stats_fp = SiteStatsFingerprint(
96+
cnn_fp, stats=("mean", "std_dev", "minimum", "maximum")
97+
)
98+
99+
worst_fp_diff_norms = (
100+
worst_final_structs.map(site_stats_fp.featurize).map(np.array)
101+
- worst_init_structs.map(site_stats_fp.featurize).map(np.array)
102+
).map(np.linalg.norm)
56103

57-
fig.savefig(f"{ROOT}/tmp/figures/{which}-{len(errs)}-structures.webp", dpi=300)
104+
best_fp_diff_norms = (
105+
best_final_structs.map(site_stats_fp.featurize).map(np.array)
106+
- best_init_structs.map(site_stats_fp.featurize).map(np.array)
107+
).map(np.linalg.norm)
108+
109+
df_fp = pd.DataFrame(
110+
[worst_fp_diff_norms.values, best_fp_diff_norms.values],
111+
index=["highest-error structures", "lowest-error structures"],
112+
).T
113+
114+
115+
# %%
116+
fig = df_fp.plot.hist(backend="plotly", nbins=50, barmode="overlay", opacity=0.8)
117+
title = (
118+
f"SiteStatsFingerprint norm-diff between initial/final {n_structs}<br>"
119+
f"highest/lowest-error structures (mean over {len(list(df_each_pred))} models)"
120+
)
121+
fig.layout.title.update(text=title, font_size=20, xanchor="center", x=0.5)
122+
fig.layout.legend.update(
123+
title="", yanchor="top", y=0.98, xanchor="right", x=0.98, font_size=16
124+
)
125+
fig.layout.xaxis.title = "|SSFP<sub>initial</sub> - SSFP<sub>final</sub>|"
126+
fig.show()
127+
fig.write_image(
128+
f"{ROOT}/tmp/figures/init-final-fp-diff-norms.webp", width=1000, scale=2
129+
)
58130

59131

60132
# %% plotly scatter plot of largest model errors with points sized by mean error and
@@ -99,8 +171,92 @@
99171

100172

101173
# %%
102-
ptable_heatmap_plotly(df_preds[df_preds.all_false_pos].formula, colorscale="Viridis")
103-
ptable_heatmap_plotly(df_preds[df_preds.all_false_neg].formula, colorscale="Viridis")
174+
elem_counts: dict[str, pd.Series] = {}
175+
for col in ("all_false_neg", "all_false_pos"):
176+
elem_counts[col] = elem_counts.get(col, count_elements(df_preds.query(col).formula))
177+
fig = ptable_heatmap_plotly(elem_counts[col], font_size=10)
178+
fig.layout.title = col
179+
fig.show()
180+
181+
182+
# %% scatter plot error by element against prevalence in training set
183+
df_mp = pd.read_csv(DATA_FILES.mp_energies, na_filter=False).set_index("material_id")
184+
# compute number of samples per element in training set
185+
# counting element occurrences not weighted by composition, assuming models don't learn
186+
# much more about iron and oxygen from Fe2O3 than from FeO
187+
188+
count_col = "MP Occurrences"
189+
df_elem_err = count_elements(df_mp.formula_pretty, count_mode="occurrence").to_frame(
190+
name=count_col
191+
)
192+
193+
title = "Number of MP structures containing each element"
194+
fig = df_elem_err[count_col].plot.bar(backend="plotly", title=title)
195+
fig.update_layout(showlegend=False)
196+
fig.show()
197+
198+
fig = ptable_heatmap_plotly(df_elem_err[count_col], font_size=10)
199+
fig.layout.title.update(text=title, x=0.35, y=0.9, font_size=20)
200+
fig.show()
201+
202+
203+
# %% map average model error onto elements
204+
df_summary["fractional_composition"] = [
205+
Composition(comp).fractional_composition for comp in tqdm(df_summary.formula)
206+
]
207+
208+
df_frac_comp = pd.json_normalize(
209+
[comp.as_dict() for comp in df_summary["fractional_composition"]]
210+
).set_index(df_summary.index)
211+
assert all(
212+
df_frac_comp.sum(axis=1).round(6) == 1
213+
), "composition fractions don't sum to 1"
214+
215+
(len(df_frac_comp) - df_frac_comp.isna().sum()).sort_values().plot.bar(backend="plotly")
216+
217+
# df_frac_comp = df_frac_comp.dropna(axis=1, thresh=100) # remove Xe with only 1 entry
218+
219+
220+
# %%
221+
for model in (*df_metrics, mean_ae_col):
222+
df_elem_err[model] = (
223+
df_frac_comp * df_each_err[model].abs().values[:, None]
224+
).mean()
225+
fig = ptable_heatmap_plotly(
226+
df_elem_err[model],
227+
precision=".2f",
228+
fill_value=None,
229+
cbar_max=0.2,
230+
colorscale="Turbo",
231+
)
232+
fig.layout.title.update(text=model, x=0.35, y=0.9, font_size=20)
233+
fig.show()
234+
235+
236+
# %%
237+
df_elem_err.to_json(f"{MODELS}/per-element/per-element-model-each-errors.json")
238+
239+
240+
# %%
241+
df_elem_err["elem_name"] = [Element(el).long_name for el in df_elem_err.index]
242+
fig = df_elem_err.plot.scatter(
243+
x=count_col,
244+
y=mean_ae_col,
245+
backend="plotly",
246+
hover_name="elem_name",
247+
text=df_elem_err.index.where(
248+
(df_elem_err[mean_ae_col] > 0.04) | (df_elem_err[count_col] > 10_000)
249+
),
250+
title="Correlation between element-error and element-occurrence in<br>training "
251+
f"set: {df_elem_err[mean_ae_col].corr(df_elem_err[count_col]):.2f}",
252+
hover_data={mean_ae_col: ":.2f", count_col: ":,.0f"},
253+
)
254+
255+
fig.update_traces(textposition="top center")
256+
fig.show()
257+
258+
# save_fig(fig, f"{ROOT}/tmp/figures/element-occu-vs-err.webp", scale=2)
259+
# save_fig(fig, f"{ROOT}/tmp/figures/element-occu-vs-err.pdf")
104260

105261

106262
# %%

site/package.json

+9-9
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
"@sveltejs/adapter-static": "^2.0.1",
2323
"@sveltejs/kit": "^1.11.0",
2424
"@sveltejs/vite-plugin-svelte": "^2.0.3",
25-
"@typescript-eslint/eslint-plugin": "^5.54.1",
26-
"@typescript-eslint/parser": "^5.54.1",
25+
"@typescript-eslint/eslint-plugin": "^5.55.0",
26+
"@typescript-eslint/parser": "^5.55.0",
2727
"elementari": "^0.1.0",
28-
"eslint": "^8.35.0",
28+
"eslint": "^8.36.0",
2929
"eslint-plugin-svelte3": "^4.0.0",
3030
"hastscript": "^7.2.0",
3131
"js-yaml": "^4.1.0",
@@ -38,14 +38,14 @@
3838
"rehype-slug": "^5.1.0",
3939
"remark-math": "3.0.0",
4040
"svelte": "^3.56.0",
41-
"svelte-check": "^3.1.0",
41+
"svelte-check": "^3.1.4",
4242
"svelte-multiselect": "^8.5.0",
43-
"svelte-preprocess": "^5.0.1",
44-
"svelte-toc": "^0.5.2",
45-
"svelte-zoo": "^0.3.4",
46-
"svelte2tsx": "^0.6.3",
43+
"svelte-preprocess": "^5.0.2",
44+
"svelte-toc": "^0.5.3",
45+
"svelte-zoo": "^0.4.3",
46+
"svelte2tsx": "^0.6.9",
4747
"tslib": "^2.5.0",
48-
"typescript": "^4.9.5",
48+
"typescript": "5.0.1-rc",
4949
"vite": "^4.1.4"
5050
},
5151
"prettier": {

site/src/app.css

+3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
--sms-focus-border: 0.1px solid white;
2525
--sms-active-color: cornflowerblue;
2626
}
27+
html {
28+
scroll-behavior: smooth;
29+
}
2730
body {
2831
background: var(--night);
2932
font-family: -apple-system, BlinkMacSystemFont, Roboto, sans-serif;

site/src/routes/+layout.svelte

+5-9
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import { repository } from '$site/package.json'
66
import { CmdPalette } from 'svelte-multiselect'
77
import Toc from 'svelte-toc'
8-
import { GitHubCorner } from 'svelte-zoo'
8+
import { GitHubCorner, PrevNext } from 'svelte-zoo'
99
import '../app.css'
1010
1111
const routes = Object.keys(import.meta.glob(`./*/+page.{svx,svelte,md}`)).map(
@@ -16,10 +16,6 @@
1616
$page.url.pathname === `/api` ? `h1, ` : ``
1717
}h2, h3, h4):not(.toc-exclude)`
1818
19-
$: current_route_idx = routes.findIndex((route) => route === $page.url.pathname)
20-
// get prev/next route with wrap-around
21-
$: next_route = routes[(current_route_idx + 1) % routes.length]
22-
$: prev_route = routes[(current_route_idx - 1 + routes.length) % routes.length]
2319
$: description = {
2420
'/': `Benchmarking machine learning energy models for materials discovery.`,
2521
'/about-the-data': `Details about provenance, chemistry and energies in the benchmark's train and test set.`,
@@ -64,10 +60,10 @@
6460

6561
<slot />
6662

67-
<section>
68-
<a href={prev_route} class="link">&laquo; {prev_route}</a>
69-
<a href={next_route} class="link">{next_route} &raquo;</a>
70-
</section>
63+
<PrevNext items={routes} current={$page.url.pathname} let:item={href}>
64+
<a {href} class="link" slot="next">{href} &raquo;</a>
65+
<a {href} class="link" slot="prev">&laquo; {href}</a>
66+
</PrevNext>
7167
</main>
7268

7369
<Footer />

site/static/prism-vsc-dark-plus.css

-8
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,6 @@ pre[class*='language-'] {
2727
hyphens: none;
2828
}
2929

30-
pre[class*='language-']::-moz-selection,
31-
pre[class*='language-'] ::-moz-selection,
32-
code[class*='language-']::-moz-selection,
33-
code[class*='language-'] ::-moz-selection {
34-
text-shadow: none;
35-
background: rgba(29, 59, 83, 0.99);
36-
}
37-
3830
pre[class*='language-']::selection,
3931
pre[class*='language-'] ::selection,
4032
code[class*='language-']::selection,

0 commit comments

Comments
 (0)