
Commit 0fad3bd

split analyze_model_failure_cases.py into two scripts, new one is analyze_elements.py
update a bunch of figures; add bar-element-counts-mp+wbm with/without normalization to /about-the-data/tmi
1 parent 0f2410d commit 0fad3bd

21 files changed: +439 −302 lines

matbench_discovery/preds.py (+1)

@@ -19,6 +19,7 @@
 e_form_col = "e_form_per_atom_mp2020_corrected"
 each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
 each_pred_col = "e_above_hull_pred"
+model_mean_err_col = "Mean over models"


 class PredFiles(Files):

readme.md (+3 −1)

@@ -23,4 +23,6 @@ In version 1 of this benchmark, we explore 8 models covering multiple methodologies

 We welcome contributions that add new models to the leaderboard through [GitHub PRs](https://github.com/janosh/matbench-discovery/pulls). See the [usage and contributing guide](https://janosh.github.io/matbench-discovery/contribute) for details.

-For a version 2 release of this benchmark, we plan to merge the current training and test sets into the new training set and acquire a much larger test set compared to the v1 test set of 257k structures.
+For a version 2 release of this benchmark, we plan to merge the current training and test sets into the new training set and acquire a much larger test set (potentially at meta-GGA level of theory) compared to the v1 test set of 257k structures. Anyone interested in joining this effort, please [open a GitHub discussion](https://github.com/janosh/matbench-discovery/discussions) or [reach out privately](mailto:[email protected]?subject=Matbench%20Discovery).
+
+For detailed results and analysis, check out the [paper](https://matbench-discovery.janosh.dev/paper) and [supplementary material](https://matbench-discovery.janosh.dev/si).

scripts/analyze_element_errors.py (new file, +246 lines)
"""Analyze structures and compositions with largest mean error across all models.

Maybe there's some chemistry/region of materials space that all models struggle with?
Might point to deficiencies in the data or in model architectures.
"""


# %%
import pandas as pd
import plotly.express as px
from pymatgen.core import Composition, Element
from pymatviz import count_elements, ptable_heatmap_plotly
from pymatviz.utils import bin_df_cols, save_fig
from sklearn.metrics import r2_score
from tqdm import tqdm

from matbench_discovery import FIGS, MODELS, ROOT
from matbench_discovery.data import DATA_FILES, df_wbm
from matbench_discovery.preds import (
    df_each_err,
    df_metrics,
    df_preds,
    each_pred_col,
    each_true_col,
    model_mean_err_col,
)

__author__ = "Janosh Riebesell"
__date__ = "2023-02-15"

# per-structure absolute error in predicted hull distance, averaged over all models
df_each_err[model_mean_err_col] = df_preds[model_mean_err_col] = df_each_err.abs().mean(
    axis=1
)

# %%
df_mp = pd.read_csv(DATA_FILES.mp_energies, na_filter=False).set_index("material_id")
# compute number of samples per element in training set
# counting element occurrences not weighted by composition, assuming models don't learn
# much more about iron and oxygen from Fe2O3 than from FeO

train_count_col = "MP Occurrences"
df_elem_err = count_elements(df_mp.formula_pretty, count_mode="occurrence").to_frame(
    name=train_count_col
)

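# Minimal sketch (hypothetical formulas, not part of the original commit) contrasting
# occurrence counting (count_mode="occurrence" above) with composition-weighted
# counting: Fe2O3 contributes one Fe in the former, two in the latter.
from collections import Counter

_formulas = ["Fe2O3", "FeO"]
_occurrence = Counter(el.symbol for formula in _formulas for el in Composition(formula))
assert _occurrence == Counter({"Fe": 2, "O": 2})  # each element counted once per formula
_weighted: Counter = Counter()
for formula in _formulas:
    for el, amt in Composition(formula).items():
        _weighted[el.symbol] += amt
assert _weighted == Counter({"Fe": 3.0, "O": 4.0})  # counts scale with stoichiometry
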
# %%
fig = ptable_heatmap_plotly(df_elem_err[train_count_col], font_size=10)
title = "Number of MP structures containing each element"
fig.layout.title.update(text=title, x=0.4, y=0.9)
fig.show()

# %% map average model error onto elements
frac_comp_col = "fractional composition"
df_wbm[frac_comp_col] = [
    Composition(comp).fractional_composition for comp in tqdm(df_wbm.formula)
]

df_frac_comp = pd.DataFrame(comp.as_dict() for comp in df_wbm[frac_comp_col]).set_index(
    df_wbm.index
)
assert all(
    df_frac_comp.sum(axis=1).round(6) == 1
), "composition fractions don't sum to 1"

# df_frac_comp = df_frac_comp.dropna(axis=1, thresh=100)  # remove Xe with only 1 entry

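# Aside (illustrative sketch with made-up values, not part of the original commit):
# `df.where(pd.isna, 1)` used below keeps NaN entries (absent elements) and replaces
# every non-NaN entry with 1, turning the fractional-composition matrix into a binary
# presence matrix whose column sums count structures containing each element.
_demo = pd.DataFrame({"Fe": [0.4, float("nan")], "O": [0.6, 1.0]})
assert _demo.where(pd.isna, 1).sum().to_dict() == {"Fe": 1.0, "O": 2.0}
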
# %%
for label, srs in (
    ("MP", df_elem_err[train_count_col]),
    ("WBM", df_frac_comp.where(pd.isna, 1).sum()),
):
    title = f"Number of {label} structures containing each element"
    srs = srs.sort_values().copy()
    # prefix each element symbol with its rank by descending count (1 = most common)
    srs.index = [f"{len(srs) - idx} {el}" for idx, el in enumerate(srs.index)]
    fig = srs.plot.bar(backend="plotly", title=title)
    fig.layout.update(showlegend=False)
    fig.show()

# %% plot structure counts for each element in MP and WBM in a grouped bar chart
df_struct_counts = pd.DataFrame(index=df_elem_err.index)
df_struct_counts["MP"] = df_elem_err[train_count_col]
df_struct_counts["WBM"] = df_frac_comp.where(pd.isna, 1).sum()
min_count = 10  # only show elements with more than 10 structures
df_struct_counts = df_struct_counts[df_struct_counts.sum(axis=1) > min_count]
normalized = False
if normalized:  # convert raw counts to percent of each dataset's total structure count
    df_struct_counts["MP"] /= len(df_mp) / 100
    df_struct_counts["WBM"] /= len(df_wbm) / 100
y_col = "percent" if normalized else "count"
fig = (
    df_struct_counts.reset_index()
    .melt(var_name="dataset", value_name=y_col, id_vars="symbol")
    .sort_values([y_col, "symbol"])
    .plot.bar(
        x="symbol",
        y=y_col,
        backend="plotly",
        title="Number of structures containing each element",
        color="dataset",
        barmode="group",
    )
)

fig.layout.update(bargap=0.1)
fig.layout.legend.update(x=0.02, y=0.98, font_size=16)
fig.show()
save_fig(fig, f"{FIGS}/bar-element-counts-mp+wbm-{normalized=}.svelte")

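# Note: the f-string self-documenting expression f"{normalized=}" (Python 3.8+)
# expands to "normalized=False" or "normalized=True", so rerunning this cell with both
# settings saves the with- and without-normalization variants of this figure side by
# side, as referenced in the commit message.
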
# %%
test_set_std_col = "Test set standard deviation (eV/atom)"
df_elem_err[test_set_std_col] = (
    df_frac_comp.where(pd.isna, 1) * df_wbm[each_true_col].values[:, None]
).std()

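# The broadcast product above assigns each structure's DFT hull distance to every
# element present in it (NaN elsewhere), so the column-wise .std() measures the spread
# of hull distances among test-set structures containing a given element.
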
# %%
fig = ptable_heatmap_plotly(
    df_elem_err[test_set_std_col], precision=".2f", colorscale="Inferno"
)
fig.show()

# %%
normalized = True
cs_range = (0, 0.5)  # same range for all plots
# cs_range = (None, None)  # different range for each plot
for model in (*df_metrics, model_mean_err_col):
    df_elem_err[model] = (
        df_frac_comp * df_each_err[model].abs().values[:, None]
    ).mean()
    # copy so the normalization below doesn't change series values in the df in place
    per_elem_err = df_elem_err[model].copy(deep=True)
    per_elem_err.name = f"{model} (eV/atom)"
    if normalized:
        per_elem_err /= df_elem_err[test_set_std_col]
        per_elem_err.name = f"{model} (normalized by test set std)"
    fig = ptable_heatmap_plotly(
        per_elem_err, precision=".2f", colorscale="Inferno", cscale_range=cs_range
    )
    fig.show()

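# Normalizing per-element errors by the test set's per-element standard deviation
# (computed above) yields a dimensionless score, making elements with very different
# hull-distance spreads easier to compare across models.
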
# %%
assert (df_elem_err.isna().sum() < 35).all()
df_elem_err.round(4).to_json(f"{MODELS}/per-element-model-each-errors.json")

# %% scatter plot error by element against prevalence in training set
# for checking correlation and R2 of elemental prevalence in MP training data vs.
# model error
df_elem_err["elem_name"] = [Element(el).long_name for el in df_elem_err.index]
R2 = r2_score(*df_elem_err[[train_count_col, model_mean_err_col]].dropna().values.T)
r_P = df_elem_err[model_mean_err_col].corr(df_elem_err[train_count_col])

fig = df_elem_err.plot.scatter(
    x=train_count_col,
    y=model_mean_err_col,
    backend="plotly",
    hover_name="elem_name",
    # only label points with large error or high training-set prevalence
    text=df_elem_err.index.where(
        (df_elem_err[model_mean_err_col] > 0.04)
        | (df_elem_err[train_count_col] > 6_000)
    ),
    title="Per-element error vs element-occurrence in MP training "
    f"set: r<sub>Pearson</sub>={r_P:.2f}, R<sup>2</sup>={R2:.2f}",
    hover_data={model_mean_err_col: ":.2f", train_count_col: ":,.0f"},
)
fig.update_traces(textposition="top center")  # place text above scatter points
fig.layout.title.update(xanchor="center", x=0.5)
fig.show()

# save_fig(fig, f"{FIGS}/element-prevalence-vs-error.svelte")
save_fig(fig, f"{ROOT}/tmp/figures/element-prevalence-vs-error.pdf")

# %% plot EACH errors against least prevalent element in structure (by occurrence in
# MP training set). this seems to correlate more with model error
n_examp_for_rarest_elem_col = "Examples for rarest element in structure"
df_wbm["composition"] = df_wbm.get("composition", df_wbm.formula.map(Composition))
# for each structure, look up the MP training-set occurrence count of its rarest element
df_wbm[n_examp_for_rarest_elem_col] = [
    df_elem_err.loc[list(map(str, Composition(formula)))][train_count_col].min()
    for formula in tqdm(df_wbm.formula)
]

# %%
df_melt = (
    df_each_err.abs()
    .reset_index()
    .melt(var_name="Model", value_name=each_pred_col, id_vars="material_id")
    .set_index("material_id")
)
df_melt[n_examp_for_rarest_elem_col] = df_wbm[n_examp_for_rarest_elem_col]

df_bin = bin_df_cols(df_melt, [n_examp_for_rarest_elem_col, each_pred_col], ["Model"])
df_bin = df_bin.reset_index().set_index("material_id")
df_bin["formula"] = df_wbm.formula

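# bin_df_cols (from pymatviz.utils) appears to group rows into 2d bins over the listed
# columns (separately per model here), thinning the ~257k test structures per model to
# a point count the faceted scatter plot below can render responsively.
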
# %%
fig = px.scatter(
    df_bin.reset_index(),
    x=n_examp_for_rarest_elem_col,
    y=each_pred_col,
    color="Model",
    facet_col="Model",
    facet_col_wrap=3,
    hover_data=dict(material_id=True, formula=True, Model=False),
    title="Absolute errors in model-predicted E<sub>above hull</sub> vs. occurrence "
    "count in MP training set<br>of least prevalent element in structure",
)
fig.layout.update(showlegend=False)
fig.layout.title.update(x=0.5, xanchor="center", y=0.95)
fig.layout.margin.update(t=100)
# remove per-facet axis labels; shared axis titles are added as annotations below
fig.update_xaxes(title="")
fig.update_yaxes(title="")
for anno in fig.layout.annotations:
    anno.text = anno.text.split("=")[1]  # strip "Model=" prefix from facet titles

fig.add_annotation(
    text="MP occurrence count of least prevalent element in structure",
    x=0.5,
    y=-0.18,
    xref="paper",
    yref="paper",
    showarrow=False,
)
fig.add_annotation(
    text="Absolute error in E<sub>above hull</sub>",
    x=-0.07,
    y=0.5,
    xref="paper",
    yref="paper",
    showarrow=False,
    textangle=-90,
)

fig.show()
save_fig(fig, f"{FIGS}/each-error-vs-least-prevalent-element-in-struct.svelte")
