Skip to content

Commit 13b1173

Browse files
committed
fix per-model KDE in scatter_e_above_hull_models.py, add color bar since color value is now meaningful
1 parent 5f59a90 commit 13b1173

File tree

7 files changed

+100
-27
lines changed

7 files changed

+100
-27
lines changed

citation.cff

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ authors:
2626
- given-names: Philipp
2727
family-names: Benner
2828
affiliation: German Federal Institute of Materials Research and Testing (BAM)
29+
orcid: 0000-0002-0912-8137
2930
affil_key: 3
3031
github: https://github.com/pbenner
3132
- given-names: Kristin

models/mace/analyze_mace.py

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Investigate MACE energy underpredictions."""
2+
3+
4+
# %%
5+
import os
6+
7+
import pandas as pd
8+
from pymatviz import density_scatter, ptable_heatmap_plotly, spacegroup_sunburst
9+
10+
from matbench_discovery import plots as plots
11+
from matbench_discovery.data import df_wbm
12+
from matbench_discovery.preds import PRED_FILES
13+
14+
__author__ = "Janosh Riebesell"
15+
__date__ = "2023-07-23"
16+
17+
module_dir = os.path.dirname(__file__)
18+
id_col = "material_id"
19+
target_col = "e_form_per_atom_mp2020_corrected"
20+
pred_col = "e_form_per_atom_mace"
21+
22+
23+
# %%
24+
df_mace = pd.read_csv(PRED_FILES.MACE).set_index(id_col)
25+
df_mace[list(df_wbm)] = df_wbm
26+
27+
wyckoff_col, spg_col = "wyckoff_spglib", "spacegroup"
28+
df_mace[spg_col] = df_wbm[wyckoff_col].str.split("_").str[2].astype(int)
29+
30+
31+
# %%
32+
density_scatter(df=df_mace, x=target_col, y=pred_col)
33+
34+
35+
# %%
36+
df_bad = df_mace.query(f"{target_col} - {pred_col} > 2")
37+
38+
ax = density_scatter(df=df_bad, x=target_col, y=pred_col)
39+
ax.set(title=f"{len(df_bad):,} MACE severe energy underpredictions")
40+
41+
42+
# %%
43+
fig = ptable_heatmap_plotly(df_bad.formula)
44+
title = f"Elements in {len(df_bad):,} MACE severe energy underpredictions"
45+
fig.layout.title.update(text=title, x=0.4, y=0.95)
46+
fig.show()
47+
48+
49+
# %%
50+
fig = spacegroup_sunburst(df_bad[spg_col], title="MACE spacegroups")
51+
title = f"Spacegroup sunburst of {len(df_bad):,} MACE severe energy underpredictions"
52+
fig.layout.title.update(text=title, x=0.5)
53+
fig.show()
54+
55+
56+
"""
57+
Space groups of MACE underpredictions look unremarkable but unusually heavy in Silicon,
58+
Lanthanide and heavy alkali metals.
59+
"""

scripts/model_figs/scatter_e_above_hull_models.py

+29-11
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import plotly.express as px
1212
import scipy.stats
1313
from pymatviz.utils import add_identity_line, bin_df_cols, save_fig
14+
from tqdm import tqdm
1415

1516
from matbench_discovery import FIGS, PDF_FIGS
1617
from matbench_discovery.metrics import classify_stable
@@ -119,35 +120,52 @@
119120
# save_fig(fig, f"{img_path}.svelte")
120121

121122

122-
# %% plot all models in separate subplots
123-
n_cols = 2
124-
n_rows = math.ceil(len(models) / n_cols)
123+
# %%
124+
clr_col, cnt_col = "density", "counts"
125+
# compute KDE for each model's predictions separately
126+
for model in (pbar := tqdm(models)):
127+
pbar.set_description(f"KDE for {model=}")
128+
129+
xy = df_preds[[each_true_col, model]].dropna().T
130+
model_kde = scipy.stats.gaussian_kde(xy)
131+
132+
model_rows = df_bin[df_bin[facet_col] == model]
133+
xy_binned = model_rows[[each_true_col, each_pred_col]].T
134+
density = model_kde(xy_binned)
135+
n_preds = len(df_preds[model].dropna())
136+
df_bin.loc[model_rows.index, cnt_col] = density / density.sum() * n_preds
125137

138+
df_bin[clr_col] = np.log1p(df_bin[cnt_col]).round(2)
126139

127-
def get_density(xs: np.ndarray, ys: np.ndarray) -> np.ndarray:
128-
"""Get kernel density estimate for each (x, y) point."""
129-
return scipy.stats.gaussian_kde([xs, ys])([xs, ys])
130140

141+
# %% scatter plot of DFT vs predicted hull distance with each model in separate subplot
142+
n_cols = 2
143+
n_rows = math.ceil(len(models) / n_cols)
131144

132-
# scatter plot of DFT vs predicted hull distance
133145
fig = px.scatter(
134146
df_bin,
135147
x=each_true_col,
136148
y=each_pred_col,
137149
facet_col=facet_col,
138150
facet_col_wrap=n_cols,
139-
color=get_density(df_bin[each_true_col], df_bin[each_pred_col]),
151+
color=clr_col,
140152
facet_col_spacing=0.02,
141153
facet_row_spacing=0.04,
142154
hover_data=hover_cols,
143155
hover_name=df_preds.index.name,
144156
# color=clf_col,
145-
color_discrete_map=clf_color_map,
157+
# color_discrete_map=clf_color_map,
146158
# opacity=0.4,
147159
range_x=(domain := (-4, 7)),
148160
range_y=domain,
149161
category_orders={facet_col: legend_order},
150-
color_continuous_scale="turbo",
162+
color_continuous_scale="turbo", # "thermal"
163+
)
164+
165+
# manually set colorbar ticks and labels (needed after log1p transform)
166+
tick_vals = [1, 10, 100, 1000, 10_000]
167+
fig.layout.coloraxis.colorbar.update(
168+
tickvals=np.log1p(tick_vals), ticktext=list(map(str, tick_vals))
151169
)
152170

153171
x_title = fig.layout.xaxis.title.text # used in annotations below
@@ -229,7 +247,7 @@ def get_density(xs: np.ndarray, ys: np.ndarray) -> np.ndarray:
229247
**axis_titles,
230248
)
231249
fig.layout.height = 200 * n_rows
232-
fig.layout.coloraxis.showscale = False
250+
fig.layout.coloraxis.colorbar.update(orientation="h", thickness=9, len=0.5, y=1.05)
233251
# fig.layout.width = 1100
234252
fig.layout.margin.update(l=40, r=10, t=30, b=60)
235253
fig.update_xaxes(matches=None)

site/package.json

+10-10
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,27 @@
1919
"devDependencies": {
2020
"@iconify/svelte": "^3.1.4",
2121
"@rollup/plugin-yaml": "^4.1.1",
22-
"@sveltejs/adapter-static": "^2.0.2",
23-
"@sveltejs/kit": "^1.22.3",
22+
"@sveltejs/adapter-static": "^2.0.3",
23+
"@sveltejs/kit": "^1.22.4",
2424
"@sveltejs/vite-plugin-svelte": "^2.4.3",
25-
"@typescript-eslint/eslint-plugin": "^6.2.0",
26-
"@typescript-eslint/parser": "^6.2.0",
25+
"@typescript-eslint/eslint-plugin": "^6.2.1",
26+
"@typescript-eslint/parser": "^6.2.1",
2727
"d3-scale-chromatic": "^3.0.0",
2828
"elementari": "^0.2.2",
29-
"eslint": "^8.45.0",
29+
"eslint": "^8.46.0",
3030
"eslint-plugin-svelte": "^2.32.4",
31-
"hastscript": "^7.2.0",
31+
"hastscript": "^8.0.0",
3232
"highlight.js": "^11.8.0",
3333
"js-yaml": "^4.1.0",
3434
"katex": "^0.16.8",
3535
"mdsvex": "^0.11.0",
36-
"prettier": "^3.0.0",
37-
"prettier-plugin-svelte": "^3.0.1",
36+
"prettier": "^3.0.1",
37+
"prettier-plugin-svelte": "^3.0.3",
3838
"rehype-autolink-headings": "^6.1.1",
3939
"rehype-katex-svelte": "^1.2.0",
4040
"rehype-slug": "^5.1.0",
4141
"remark-math": "3.0.0",
42-
"svelte": "^4.1.1",
42+
"svelte": "^4.1.2",
4343
"svelte-check": "^3.4.6",
4444
"svelte-multiselect": "^10.1.0",
4545
"svelte-preprocess": "^5.0.4",
@@ -48,7 +48,7 @@
4848
"svelte2tsx": "^0.6.19",
4949
"tslib": "^2.6.1",
5050
"typescript": "5.1.6",
51-
"vite": "^4.4.7"
51+
"vite": "^4.4.8"
5252
},
5353
"prettier": {
5454
"semi": false,

site/src/app.css

-4
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,6 @@ strong a[aria-hidden='true'] {
174174
left: 0;
175175
}
176176

177-
/* for /api/[slug] */
178-
kbd {
179-
padding: 0 1ex 0 0;
180-
}
181177
aside.toc.desktop {
182178
position: fixed;
183179
top: 3em;

site/src/figs/each-scatter-models-5x2.svelte

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/routes/api/+page.svelte

-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
/* select all but first module h1s */
99
:global(h1[id^='module-']:not(:nth-of-type(2))) {
1010
margin: 2em 0 0;
11-
font-size: 2em !important;
1211
}
1312
:global(h1 > kbd) {
1413
font-size: 25pt;

0 commit comments

Comments
 (0)