Skip to content

Commit db34c09

Browse files
committed
add scripts/difficult_structures.py
1 parent a42472c commit db34c09

File tree

6 files changed

+144
-33
lines changed

6 files changed

+144
-33
lines changed

models/cgcnn/metadata.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
Ensemble Size: 10
2626

2727
notes:
28-
description: Published in 2017, CGCNN was the first crystal graph convolutional neural network to directly learn 8 different DFT-computed material properties from a graph representing the atoms and bonds in a crystal.
28+
description: Published in 2017, CGCNN was the first crystal graph convolutional neural network to directly learn 8 different DFT-computed material properties from a graph representing the atoms and bonds in a crystal. ![Illustration of the crystal graph convolutional neural networks](https://researchgate.net/profile/Tian-Xie-11/publication/320726915/figure/fig1/AS:635258345119746@1528468800829/Illustration-of-the-crystal-graph-convolutional-neural-networks-a-Construction-of-the.png)
2929
long: It showed that just like in other areas of ML, given large training sets, embeddings that outperform human-engineered features could be learned directly from the data.
3030

3131
- model_name: CGCNN+P
@@ -60,5 +60,5 @@
6060
Perturbations: 5
6161

6262
notes:
63-
description: This work proposes simple, physically motivated structure perturbations to augment CGCNN's training data of relaxed structures with structures resembling unrelaxed ones but mapped to the same DFT final energy.
64-
long: From this the model should learn to map structures to their nearest energy basin which is supported by a lowering of the energy error on unrelaxed structures.
63+
description: This work proposes simple structure perturbations to augment CGCNN's training data of relaxed structures with randomly perturbed ones resembling unrelaxed structures that are mapped to the same DFT final energy during training. ![Step function PES](https://media.springernature.com/full/springer-static/image/art%3A10.1038%2Fs41524-022-00891-8/MediaObjects/41524_2022_891_Fig1_HTML.png?as=webp)
64+
long: The model is essentially taught the potential energy surface (PES) is a step-function that maps each valley to its local minimum. The expectation is that during testing on unrelaxed structures, the model will predict the energy of the nearest basin in the PES. The authors confirm this by demonstrating a lowering of the energy error on unrelaxed structures.

scripts/compile_metrics.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -121,37 +121,36 @@
121121

122122

123123
# %%
124+
higher_is_better = ["DAF", "R²", "Precision", "Recall", "F1", "Accuracy", "TPR", "TNR"]
125+
lower_is_better = ["MAE", "RMSE", "FNR", "FPR"]
124126
styler = (
125127
df_metrics.T.rename(columns={"R2": "R²"})
128+
# append arrow up/down to table headers to indicate higher/lower metric is better
129+
# .rename(columns=lambda x: x + " ↑" if x in higher_is_better else x + " ↓")
126130
.style.format(precision=2)
127-
.background_gradient(
128-
cmap="viridis_r", # lower is better so reverse color map
129-
subset=["MAE", "RMSE", "FNR", "FPR"],
130-
)
131+
# reverse color map if lower=better
132+
.background_gradient(cmap="viridis_r", subset=lower_is_better)
131133
# .background_gradient(
132134
# cmap="viridis_r",
133135
# subset=[time_col],
134136
# gmap=np.log10(df_stats[time_col].to_numpy()), # for log scaled color map
135137
# )
136-
.background_gradient(
137-
cmap="viridis", # higher is better
138-
subset=["DAF", "R²", "Precision", "Recall", "F1", "Accuracy", "TPR", "TNR"],
139-
)
138+
.background_gradient(cmap="viridis", subset=higher_is_better)
140139
)
141-
142140
styles = {
143141
"": "font-family: sans-serif; border-collapse: collapse;",
144-
"td, th": "border: 1px solid #ddd; text-align: left; padding: 8px; white-space: nowrap;",
142+
"td, th": "border: none; padding: 4px 6px; white-space: nowrap;",
143+
"th": "border: 1px solid; border-width: 1px 0; text-align: left;",
145144
}
146145
styler.set_table_styles([dict(selector=sel, props=styles[sel]) for sel in styles])
147146
styler.set_uuid("")
148147

149148

150149
# %% export model metrics as styled HTML table
151150
# insert svelte {...props} forwarding to the table element
152-
html = styler.to_html().replace("<table", "<table {...$$props}")
151+
html_table = styler.to_html().replace("<table", "<table {...$$props}")
153152
with open(f"{FIGS}/metrics-table.svelte", "w") as file:
154-
file.write(html)
153+
file.write(html_table)
155154

156155

157156
# %%

scripts/difficult_structures.py

+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# %%
2+
import matplotlib.pyplot as plt
3+
import pandas as pd
4+
from pymatgen.core import Structure
5+
from pymatviz import plot_structure_2d, ptable_heatmap_plotly
6+
7+
from matbench_discovery import ROOT
8+
from matbench_discovery.metrics import classify_stable
9+
from matbench_discovery.preds import df_each_err, df_each_pred, df_wbm, each_true_col
10+
11+
__author__ = "Janosh Riebesell"
12+
__date__ = "2023-02-15"
13+
14+
df_each_err[each_true_col] = df_wbm[each_true_col]
15+
mean_ae_col = "All models mean absolute error (eV/atom)"
16+
df_each_err[mean_ae_col] = df_wbm[mean_ae_col] = df_each_err.abs().mean(axis=1)
17+
18+
19+
# %%
20+
cse_path = f"{ROOT}/data/wbm/2022-10-19-wbm-computed-structure-entries.json.bz2"
21+
df_cse = pd.read_json(cse_path).set_index("material_id")
22+
23+
24+
# %%
25+
n_rows, n_cols = 5, 4
26+
for which in ("best", "worst"):
27+
fig, axs = plt.subplots(n_rows, n_cols, figsize=(3 * n_rows, 4 * n_cols))
28+
n_axs = len(axs.flat)
29+
30+
errs = (
31+
df_each_err.mean_ae.nsmallest(n_axs)
32+
if which == "best"
33+
else df_each_err.mean_ae.nlargest(n_axs)
34+
)
35+
title = f"{which} {len(errs)} structures (across {len(list(df_each_pred))} models)"
36+
fig.suptitle(title, fontsize=16, fontweight="bold", y=0.95)
37+
38+
for idx, (ax, (id, err)) in enumerate(zip(axs.flat, errs.items()), 1):
39+
struct = Structure.from_dict(
40+
df_cse.computed_structure_entry.loc[id]["structure"]
41+
)
42+
plot_structure_2d(struct, ax=ax)
43+
_, spg_num = struct.get_space_group_info()
44+
formula = struct.composition.reduced_formula
45+
ax.set_title(
46+
f"{idx}. {formula} (spg={spg_num})\n{id} {err=:.2f}", fontweight="bold"
47+
)
48+
49+
fig.savefig(f"{ROOT}/tmp/figures/{which}-{len(errs)}-structures.webp", dpi=300)
50+
51+
52+
# %% plotly scatter plot of largest model errors with points sized by mean error and
53+
# colored by true stability
54+
fig = df_wbm.nlargest(200, mean_ae_col).plot.scatter(
55+
x=each_true_col,
56+
y=mean_ae_col,
57+
color=each_true_col,
58+
size=mean_ae_col,
59+
backend="plotly",
60+
)
61+
fig.layout.coloraxis.colorbar.update(
62+
title="DFT distance to convex hull (eV/atom)",
63+
title_side="top",
64+
yanchor="bottom",
65+
y=1,
66+
xanchor="center",
67+
x=0.5,
68+
orientation="h",
69+
thickness=12,
70+
)
71+
fig.show()
72+
73+
74+
# %% find materials that were misclassified by all models
75+
for model in df_each_pred:
76+
true_pos, false_neg, false_pos, true_neg = classify_stable(
77+
df_each_pred[model], df_wbm[each_true_col]
78+
)
79+
df_wbm[f"{model}_true_pos"] = true_pos
80+
df_wbm[f"{model}_false_neg"] = false_neg
81+
df_wbm[f"{model}_false_pos"] = false_pos
82+
df_wbm[f"{model}_true_neg"] = true_neg
83+
84+
85+
df_wbm["all_true_pos"] = df_wbm.filter(like="_true_pos").all(axis=1)
86+
df_wbm["all_false_neg"] = df_wbm.filter(like="_false_neg").all(axis=1)
87+
df_wbm["all_false_pos"] = df_wbm.filter(like="_false_pos").all(axis=1)
88+
df_wbm["all_true_neg"] = df_wbm.filter(like="_true_neg").all(axis=1)
89+
90+
df_wbm.filter(like="all_").sum()
91+
92+
93+
# %%
94+
ptable_heatmap_plotly(df_wbm[df_wbm.all_false_pos].formula, colorscale="Viridis")
95+
ptable_heatmap_plotly(df_wbm[df_wbm.all_false_neg].formula, colorscale="Viridis")
96+
97+
98+
# %%
99+
df_each_err.abs().mean().sort_values()
100+
df_each_err.abs().mean(axis=1).nlargest(25)
101+
102+
103+
# %% get mean distance to convex hull for each classification
104+
df_wbm.query("all_true_pos").describe()
105+
df_wbm.query("all_false_pos").describe()

scripts/prc_roc_curves_models.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,14 @@
6666
anno.text = anno.text.split("=", 1)[1] # remove Model= from subplot titles
6767

6868
fig.layout.coloraxis.colorbar.update(
69-
x=1, y=1, xanchor="right", yanchor="top", thickness=14, len=0.2, title_side="right"
69+
x=1,
70+
y=1,
71+
xanchor="right",
72+
yanchor="top",
73+
thickness=14,
74+
lenmode="pixels",
75+
len=210,
76+
title_side="right",
7077
)
7178
fig.add_shape(type="line", x0=0, y0=0, x1=1, y1=1, line=line, row="all", col="all")
7279
fig.add_annotation(text="No skill", x=0.5, y=0.5, showarrow=False, yshift=-10)

site/src/figs/metrics-table.svelte

+9-6
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,20 @@
44
border-collapse: collapse;
55
}
66
#T_ td {
7-
border: 1px solid #ddd;
8-
text-align: left;
9-
padding: 8px;
7+
border: none;
8+
padding: 4px 6px;
109
white-space: nowrap;
1110
}
1211
#T_ th {
13-
border: 1px solid #ddd;
14-
text-align: left;
15-
padding: 8px;
12+
border: none;
13+
padding: 4px 6px;
1614
white-space: nowrap;
1715
}
16+
#T_ th {
17+
border: 1px solid;
18+
border-width: 1px 0;
19+
text-align: left;
20+
}
1821
#T__row0_col0, #T__row0_col2, #T__row0_col3, #T__row0_col5, #T__row0_col10, #T__row1_col7, #T__row1_col8, #T__row3_col4, #T__row3_col6, #T__row3_col9, #T__row5_col1, #T__row5_col11 {
1922
background-color: #440154;
2023
color: #f1f1f1;

site/src/routes/about-the-test-set/tmi/+page.svelte

+8-11
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
<script lang="ts">
22
import { ElemCountInset } from '$lib'
3-
import {
4-
ColorScaleSelect,
5-
PeriodicTable,
6-
TableInset,
7-
type ChemicalElement,
8-
} from 'elementari'
3+
import type { ChemicalElement } from 'elementari'
4+
import { ColorScaleSelect, PeriodicTable, TableInset } from 'elementari'
95
import { RadioButtons, Toggle } from 'svelte-zoo'
106
import type { Snapshot } from './$types'
117
@@ -27,7 +23,7 @@
2723
$: color_scale = selected[0]
2824
$: active_counts = elem_counts[filter]
2925
30-
const style = `display: flex; gap: 5pt; place-items: center; place-content: center;`
26+
const style = `display: flex; place-items: center; place-content: center;`
3127
3228
export const snapshot: Snapshot = {
3329
capture: () => ({ filter, log }),
@@ -42,8 +38,9 @@ Stuff that didn't make the cut into the main page describing the WBM test set.
4238

4339
<h2>WBM Element Counts for <code>{filter}</code></h2>
4440

45-
Filter WBM element counts by composition arity (how many elements in the formula) or batch
46-
index (which iteration of elemental substitution the structure was generated in).
41+
Filter WBM element counts by composition<strong>arity</strong> (how many elements in the
42+
formula) or <strong>batch index</strong> (which iteration of elemental substitution the
43+
structure was generated in).
4744

4845
<ColorScaleSelect bind:selected />
4946
<ul>
@@ -77,11 +74,11 @@ index (which iteration of elemental substitution the structure was generated in)
7774
display: flex;
7875
gap: 1ex;
7976
}
80-
strong {
77+
ul > li strong {
8178
background-color: rgba(255, 255, 255, 0.1);
8279
padding: 3pt 4pt;
8380
}
84-
strong.active {
81+
ul > li strong.active {
8582
background-color: teal;
8683
}
8784
</style>

0 commit comments

Comments
 (0)