Commit 4013bb1

add descriptions to all model metadata.yml
- use Turbo as better initial color map on ptable heatmaps (heatmap now changeable, required sveriodic-table update)
- add CGCNN+P metrics to model-stats.json
- update model-metrics.svelte table
- compile_metrics.py: import df_metrics, df_wbm from matbench_discovery.preds
- remove dates from figure file names
1 parent f9d4d04 commit 4013bb1


41 files changed: +868 -277 lines

data/wbm/analysis.py

+1 -1

@@ -112,7 +112,7 @@
 fig.update_layout(title=dict(text=title, x=0.5, y=0.95))

 fig.update_layout(showlegend=False, paper_bgcolor="rgba(0,0,0,0)")
-fig.update_xaxes(title_text="WBM energy above MP convex hull (eV/atom)")
+fig.update_xaxes(title="WBM energy above MP convex hull (eV/atom)")

 for x_pos, label in zip(
     [mean, mean + std, mean - std],
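Both spellings are valid plotly API: `update_xaxes(title=...)` accepts a plain string (coerced to `title.text`) or a dict, while `title_text` targets just the text field, so this change is cosmetic. A minimal sketch for reference:

import plotly.express as px

fig = px.histogram(x=[0.1, 0.2, 0.2, 0.5])
# equivalent ways to label the x-axis; `title` also accepts dict(text=..., font=...)
fig.update_xaxes(title="WBM energy above MP convex hull (eV/atom)")
fig.update_xaxes(title_text="WBM energy above MP convex hull (eV/atom)")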

data/wbm/fetch_process_wbm_dataset.py

+5 -4

@@ -21,7 +21,7 @@
 from pymatviz.utils import save_fig
 from tqdm import tqdm

-from matbench_discovery import ROOT, today
+from matbench_discovery import FIGS, ROOT, today
 from matbench_discovery.energy import get_e_form_per_atom
 from matbench_discovery.plots import pio

@@ -436,6 +436,7 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
 fig = df_summary.hist(
     x="e_form_per_atom_wbm", backend="plotly", log_y=True, range_x=[-5.5, 5.5]
 )
+fig_compressed = False
 fig.add_vline(x=e_form_cutoff, line=dict(dash="dash"))
 fig.add_vline(x=-e_form_cutoff, line=dict(dash="dash"))
 fig.add_annotation(

@@ -458,13 +459,13 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
 # %%
 # no need to store all 250k x values in plot, leads to 1.7 MB file, subsample every 10th
 # point is enough to see the distribution
-if not fig.data[0].compressed:
-    fig.data[0].compressed = True
+if not fig_compressed:
+    fig_compressed = True
     # keep only every 10th data point, round to 3 decimal places to reduce file size
     fig.data[0].x = [round(x, 3) for x in fig.data[0].x[::10]]

 # recommended to upload SVG to vecta.io/nano afterwards for compression
-img_path = f"{module_dir}/2022-12-07-hist-wbm-e-form-per-atom"
+img_path = f"{FIGS}/hist-wbm-e-form-per-atom"
 # save_fig(fig, f"{img_path}.svg", width=800, height=300)
 save_fig(fig, f"{img_path}.svelte")
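The `fig.data[0].compressed` flag had to go because plotly graph objects validate attribute assignment, so setting a made-up `compressed` property on a trace raises an error; a plain module-level flag sidesteps that. A self-contained sketch of the same down-sampling trick, with made-up data:

import plotly.express as px

fig = px.histogram(x=[x / 1000 for x in range(250_000)])
fig_compressed = False  # plain flag since plotly traces reject unknown attributes

if not fig_compressed:
    fig_compressed = True
    # keep every 10th x value, rounded to 3 decimals, to shrink the saved file
    fig.data[0].x = [round(x, 3) for x in fig.data[0].x[::10]]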

data/wbm/readme.md

+2 -2

@@ -25,7 +25,7 @@ The full set of processing steps used to curate the WBM test set from the raw da

 <caption>WBM Formation energy distribution. 524 materials outside green dashed lines were discarded.<br />(zoom out on this plot to see discarded samples)</caption>
 <slot name="hist-e-form-per-atom">
-  <img src="./figs/2022-12-07-hist-wbm-e-form-per-atom.svg" alt="WBM formation energy histogram indicating outlier cutoffs">
+  <img src="./figs/wbm-e-form-per-atom.svg" alt="WBM formation energy histogram indicating outlier cutoffs">
 </slot>

 - apply the [`MaterialsProject2020Compatibility`](https://pymatgen.org/pymatgen.entries.compatibility.html#pymatgen.entries.compatibility.MaterialsProject2020Compatibility) energy correction scheme to the formation energies

@@ -99,5 +99,5 @@ The number of stable materials (according to the MP convex hull which is spanned
 > Note: [According to the authors](https://www.nature.com/articles/s41524-020-00481-6#Sec2), the stability rate w.r.t. the more complete hull constructed from the combined train and test set (MP + WBM) for the first 3 rounds of elemental substitution is 18,479 out of 189,981 crystals ($\approx$ 9.7%).

 <slot name="wbm-each-hist">
-  <img src="./figs/2023-01-26-wbm-each-hist.svg" alt="WBM energy above MP convex hull distribution">
+  <img src="./figs/wbm-each-hist.svg" alt="WBM energy above MP convex hull distribution">
 </slot>
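The energy-correction step this readme references is pymatgen's standard MP2020 scheme. A hedged sketch, assuming `entries` is a list of pymatgen ComputedEntry objects carrying the required calculation metadata:

from pymatgen.entries.compatibility import MaterialsProject2020Compatibility

# applies anion and GGA/GGA+U mixing corrections, dropping incompatible entries
corrected_entries = MaterialsProject2020Compatibility().process_entries(entries)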

matbench_discovery/data.py

+1 -10

@@ -184,11 +184,7 @@ def glob_to_df(


 def load_df_wbm_preds(
-    models: Sequence[str],
-    pbar: bool = True,
-    id_col: str = "material_id",
-    return_model_dfs: bool = False,
-    **kwargs: Any,
+    models: Sequence[str], pbar: bool = True, id_col: str = "material_id", **kwargs: Any
 ) -> pd.DataFrame:
     """Load WBM summary dataframe with model predictions from disk.

@@ -197,8 +193,6 @@ def load_df_wbm_preds(
             matbench_discovery.data.PRED_FILENAMES.
         pbar (bool, optional): Whether to show progress bar. Defaults to True.
         id_col (str, optional): Column to set as df.index. Defaults to "material_id".
-        return_model_dfs (bool, optional): Whether to return dict of dataframes for each
-            model dfs. Defaults to False.
         **kwargs: Keyword arguments passed to glob_to_df().

     Raises:

@@ -218,9 +212,6 @@ def load_df_wbm_preds(
         df = glob_to_df(pattern, pbar=False, **kwargs).set_index(id_col)
         dfs[model_name] = df

-    if return_model_dfs:
-        return dfs
-
     df_out = df_wbm.copy()
     for model_name, df in dfs.items():
         model_key = model_name.lower().replace(" + ", "_").replace(" ", "_")
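With `return_model_dfs` gone, `load_df_wbm_preds()` always returns a single merged dataframe. A hedged usage sketch (the model names are illustrative; valid keys come from PRED_FILENAMES):

from matbench_discovery.data import load_df_wbm_preds

# one row per WBM material, one prediction column per requested model
df_preds = load_df_wbm_preds(["Wrenformer", "MEGNet"], pbar=False)
print(df_preds.columns)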

matbench_discovery/metrics.py

+16 -17

@@ -51,15 +51,17 @@ def classify_stable(


 def stable_metrics(
-    true: Sequence[float], pred: Sequence[float], stability_threshold: float = 0
+    each_true: Sequence[float],
+    each_pred: Sequence[float],
+    stability_threshold: float = 0,
 ) -> dict[str, float]:
     """
     Get a dictionary of stability prediction metrics. Mostly binary classification
     metrics, but also MAE, RMSE and R2.

     Args:
-        true (list[float]): true energy values
-        pred (list[float]): predicted energy values
+        each_true (list[float]): true energy above convex hull
+        each_pred (list[float]): predicted energy above convex hull
         stability_threshold (float): Where to place stability threshold relative to
             convex hull in eV/atom, usually 0 or 0.1 eV. Defaults to 0.

@@ -71,34 +73,31 @@ def stable_metrics(
         dict[str, float]: dictionary of classification metrics with keys DAF, Precision,
             Recall, Accuracy, F1, TPR, FPR, TNR, FNR, MAE, RMSE, R2.
     """
-    true_pos, false_neg, false_pos, true_neg = classify_stable(
-        true, pred, stability_threshold
-    )
-
-    n_true_pos, n_false_pos, n_true_neg, n_false_neg = map(
-        sum, (true_pos, false_pos, true_neg, false_neg)
+    n_true_pos, n_false_neg, n_false_pos, n_true_neg = map(
+        sum, classify_stable(each_true, each_pred, stability_threshold)
     )

     n_total_pos = n_true_pos + n_false_neg
     n_total_neg = n_true_neg + n_false_pos
-    prevalence = n_total_pos / len(true)  # null rate
-    precision = n_true_pos / (n_true_pos + n_false_pos)
+    # prevalence: dummy discovery rate of selecting randomly from all materials
+    prevalence = n_total_pos / len(each_true)
+    precision = n_true_pos / (n_true_pos + n_false_pos)  # model's discovery rate
     recall = n_true_pos / n_total_pos

-    is_nan = np.isnan(true) | np.isnan(pred)
-    true, pred = np.array(true)[~is_nan], np.array(pred)[~is_nan]
+    is_nan = np.isnan(each_true) | np.isnan(each_pred)
+    each_true, each_pred = np.array(each_true)[~is_nan], np.array(each_pred)[~is_nan]

     return dict(
         DAF=precision / prevalence,
         Precision=precision,
         Recall=recall,
-        Accuracy=(n_true_pos + n_true_neg) / len(true),
+        Accuracy=(n_true_pos + n_true_neg) / len(each_true),
         F1=2 * (precision * recall) / (precision + recall),
         TPR=n_true_pos / n_total_pos,
         FPR=n_false_pos / n_total_neg,
         TNR=n_true_neg / n_total_neg,
         FNR=n_false_neg / n_total_pos,
-        MAE=np.abs(true - pred).mean(),
-        RMSE=((true - pred) ** 2).mean() ** 0.5,
-        R2=r2_score(true, pred),
+        MAE=np.abs(each_true - each_pred).mean(),
+        RMSE=((each_true - each_pred) ** 2).mean() ** 0.5,
+        R2=r2_score(each_true, each_pred),
     )
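The renamed arguments make the expected inputs explicit: energies above the convex hull, where values at or below `stability_threshold` count as stable. A toy call with made-up numbers, showing the headline metric DAF = precision / prevalence:

from matbench_discovery.metrics import stable_metrics

metrics = stable_metrics(
    each_true=[-0.1, 0.2, -0.3, 0.4],  # 2 of 4 materials truly stable
    each_pred=[-0.2, 0.3, 0.1, 0.5],  # model recovers 1, with no false positives
)
# precision = 1.0 and prevalence = 0.5, so DAF = 2.0 (2x better than random picking)
print(metrics["DAF"], metrics["Precision"], metrics["Recall"])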

matbench_discovery/plots.py

+18 -14

@@ -20,8 +20,6 @@
 __author__ = "Janosh Riebesell"
 __date__ = "2022-08-05"

-WhichEnergy = Literal["true", "pred"]
-AxLine = Literal["x", "y", "xy", ""]
 Backend = Literal["matplotlib", "plotly"]

 # --- start global plot settings

@@ -58,22 +56,20 @@
 )
 px.defaults.labels = quantity_labels | model_labels

-# https://plotly.com/python-api-reference/generated/plotly.graph_objects.layout
+# color list https://plotly.com/python-api-reference/generated/plotly.graph_objects.layout
 colorway = (
     "lightseagreen",
     "orange",
     "lightsalmon",
     "dodgerblue",
-    "aquamarine",
-    "purple",
-    "firebrick",
 )
 clf_labels = ("True Positive", "False Negative", "False Positive", "True Negative")
-clf_color_map = dict(zip(clf_labels, colorway))
+clf_colors = ("lightseagreen", "orange", "lightsalmon", "dodgerblue")
+clf_color_map = dict(zip(clf_labels, clf_colors))

 global_layout = dict(
     # colorway=px.colors.qualitative.Pastel,
-    colorway=colorway,
+    # colorway=colorway,
     margin=dict(l=30, r=20, t=60, b=20),
     paper_bgcolor="rgba(0,0,0,0)",
     # plot_bgcolor="rgba(0,0,0,0)",

@@ -101,7 +97,7 @@ def hist_classified_stable_vs_hull_dist(
     each_true_col: str,
     each_pred_col: str,
     ax: plt.Axes = None,
-    which_energy: WhichEnergy = "true",
+    which_energy: Literal["true", "pred"] = "true",
     stability_threshold: float | None = 0,
     x_lim: tuple[float | None, float | None] = (-0.7, 0.7),
     rolling_acc: float | None = 0.02,

@@ -133,7 +129,7 @@ def hist_classified_stable_vs_hull_dist(
         (in eV / atom). Same as true energy to convex hull plus predicted minus true
             formation energy.
         ax (plt.Axes, optional): matplotlib axes to plot on.
-        which_energy (WhichEnergy, optional): Whether to use the true (DFT) hull
+        which_energy ('true' | 'pred', optional): Whether to use the true (DFT) hull
             distance or the model's predicted hull distance for the histogram.
         stability_threshold (float, optional): set stability threshold as distance to
             convex hull in eV/atom, usually 0 or 0.1 eV.

@@ -376,7 +372,7 @@ def rolling_mae_vs_hull_dist(

     window_bar_anno = f"rolling window={2 * window * 1000:.0f} meV"
     dummy_mae = (e_above_hull_true - e_above_hull_true.mean()).abs().mean()
-    legend_title = f"dummy MAE = {dummy_mae:.2f} eV/atom"
+    dummy_mae_text = f"dummy MAE = {dummy_mae:.2f} eV/atom"

     if backend == "matplotlib":
         # assert df_rolling_err.isna().sum().sum() == 0, "NaNs in df_rolling_err"

@@ -430,6 +426,9 @@ def rolling_mae_vs_hull_dist(
             horizontalalignment="right",
         )

+        ax.axhline(dummy_mae, color="tab:blue", linestyle="--", linewidth=0.5)
+        ax.text(dummy_mae, 0.1, dummy_mae_text)
+
         ax.text(
             0, 0.13, r"MAE > $|E_\mathrm{above\ hull}|$", horizontalalignment="center"
         )

@@ -456,7 +455,7 @@ def rolling_mae_vs_hull_dist(
         )

         ax.layout.legend.update(
-            title=legend_title,
+            title="",
             x=1,
             y=0,
             xanchor="right",

@@ -484,6 +483,11 @@ def rolling_mae_vs_hull_dist(
             showarrow=False,
             yref="paper",
         )
+        ax.add_hline(
+            y=dummy_mae,
+            line=dict(dash="dash", width=0.5),
+            annotation_text=dummy_mae_text,
+        )
         if show_dft_acc:
             ax.add_scatter(
                 x=(-dft_acc, dft_acc, 0, -dft_acc),

@@ -536,7 +540,7 @@ def cumulative_precision_recall(
     metrics: Sequence[str] = ("Precision", "Recall"),
     stability_threshold: float = 0,  # set stability threshold as distance to convex
     # hull in eV / atom, usually 0 or 0.1 eV
-    project_end_point: AxLine = "xy",
+    project_end_point: Literal["x", "y", "xy", ""] = "xy",
     optimal_recall: str | None = "Optimal Recall",
     show_n_stable: bool = True,
     backend: Backend = "plotly",

@@ -692,7 +696,7 @@ def cumulative_precision_recall(
         **kwargs,
     )

-    line_kwds = dict(color="white", dash="dash", width=0.5)
+    line_kwds = dict(dash="dash", width=0.5)
     for idx, anno in enumerate(fig.layout.annotations):
         anno.text = anno.text.split("=")[1]
         anno.font.size = 16
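The plotly branch now marks the dummy MAE with a dashed horizontal line rather than repurposing the legend title. A standalone sketch of the same add_hline call, with a placeholder figure and value:

import plotly.express as px

fig = px.line(y=[0.05, 0.08, 0.12, 0.09])
dummy_mae = 0.12  # placeholder; in the repo this is the MAD of e_above_hull_true
fig.add_hline(
    y=dummy_mae,
    line=dict(dash="dash", width=0.5),
    annotation_text=f"dummy MAE = {dummy_mae:.2f} eV/atom",
)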

matbench_discovery/preds.py

+8 -10

@@ -2,7 +2,7 @@

 import pandas as pd

-from matbench_discovery.data import load_df_wbm_preds
+from matbench_discovery.data import PRED_FILENAMES, load_df_wbm_preds
 from matbench_discovery.metrics import stable_metrics

 """Centralize data-loading and computing metrics for plotting scripts"""

@@ -18,20 +18,18 @@
 each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
 each_pred_col = "e_above_hull_pred"

-df_wbm = load_df_wbm_preds(models).round(3)
-
-for col in [e_form_col, each_true_col]:
-    assert col in df_wbm, f"{col=} not in {list(df_wbm)=}"
+df_wbm = load_df_wbm_preds(list(PRED_FILENAMES)).round(3)
+drop_cols = {*PRED_FILENAMES} - {*models}


 df_metrics = pd.DataFrame()
-for model in models:
+for model in list(PRED_FILENAMES):
     df_metrics[model] = stable_metrics(
         df_wbm[each_true_col],
         df_wbm[each_true_col] + df_wbm[model] - df_wbm[e_form_col],
     )

-assert df_metrics.T.MAE.between(0, 0.2).all(), "MAE not in range"
-assert df_metrics.T.R2.between(0.1, 1).all(), "R2 not in range"
-assert df_metrics.T.RMSE.between(0, 0.25).all(), "RMSE not in range"
-assert df_metrics.isna().sum().sum() == 0, "NaNs in metrics"
+
+df_each_pred = pd.DataFrame()
+for model in df_metrics.T.MAE.sort_values().index:
+    df_each_pred[model] = df_wbm[each_true_col] + df_wbm[model] - df_wbm[e_form_col]
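df_each_pred relies on the identity that a formation-energy model shifts the predicted hull distance by exactly its formation-energy error, since the convex hull itself stays fixed. A worked example with made-up numbers:

# each_pred = each_true + (e_form_pred - e_form_true)
e_form_true, e_form_pred, each_true = -1.20, -1.05, 0.03  # eV/atom, made up
each_pred = each_true + e_form_pred - e_form_true
assert round(each_pred, 2) == 0.18  # the 0.15 eV/atom error carries over exactly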

models/bowsr/metadata.yml

+2

@@ -33,4 +33,6 @@ hyperparams:
   n_iter: 100

 notes:
+  description: BOWSR is a Bayesian optimizer with symmetry constraints using a graph deep learning energy model to perform "DFT-free" relaxations of crystal structures.
+  long: The authors show that this iterative approach improves the accuracy of ML-predicted formation energies over single-shot predictions.
   training: Uses same version of MEGNet as standalone MEGNet.

models/cgcnn/metadata.yml

+8

@@ -24,6 +24,10 @@
 hyperparams:
   Ensemble Size: 10

+notes:
+  description: Published in 2017, CGCNN was the first crystal graph convolutional neural network to directly learn 8 different DFT-computed material properties from a graph representing the atoms and bonds in a crystal.
+  long: It showed that, just like in other areas of ML, given large training sets, embeddings that outperform human-engineered features can be learned directly from the data.
+
 - model_name: CGCNN+P
   model_version: 0.1.0 # the aviary version
   matbench_discovery_version: 1.0

@@ -54,3 +58,7 @@
 hyperparams:
   Ensemble Size: 10
   Perturbations: 5
+
+notes:
+  description: This work proposes simple, physically motivated structure perturbations to augment CGCNN's training data of relaxed structures with structures resembling unrelaxed ones but mapped to the same DFT final energy.
+  long: From this, the model should learn to map structures to their nearest energy basin, which is supported by a lowering of the energy error on unrelaxed structures.

models/m3gnet/metadata.yml

+3

@@ -22,6 +22,8 @@
   pandas: 1.5.1
 trained_on_benchmark: false
 notes:
+  description: M3GNet is a GNN-based universal (as in full periodic table) interatomic potential for materials, trained on up to 3-body interactions in the initial, middle and final frames of MP DFT relaxations.
+  long: It thereby learns to emulate structure relaxation, MD simulations and property prediction of materials across diverse chemical spaces.
   training: Using pre-trained model released with paper. Was only trained on a subset of 62,783 MP relaxation trajectories in the 2018 database release (see [related issue](https://github.com/materialsvirtuallab/m3gnet/issues/20#issuecomment-1207087219)).

 - model_name: M3GNet + MEGNet

@@ -58,4 +60,5 @@
   pandas: 1.5.1
 trained_on_benchmark: false
 notes:
+  description: This combination of models uses M3GNet to relax initial structures, which are then passed to MEGNet to predict the formation energy.
   training: Using pre-trained model released with paper. Was only trained on a subset of 62,783 MP relaxation trajectories in the 2018 database release (see [related issue](https://github.com/materialsvirtuallab/m3gnet/issues/20#issuecomment-1207087219)).

models/megnet/metadata.yml

+2

@@ -29,5 +29,7 @@ requirements:
   numpy: 1.24.0
   pandas: 1.5.1
 trained_on_benchmark: false
+
 notes:
+  description: MatErials Graph Network is another GNN for material properties of relaxed structures which showed that learned element embeddings encode periodic chemical trends and can be transfer-learned from large datasets (formation energies) to property predictions on smaller datasets (band gaps, elastic moduli).
   training: Using pre-trained model released with paper. Was only trained on `MP-crystals-2018.6.1` dataset [available on Figshare](https://figshare.com/articles/Graphs_of_materials_project/7451351).

models/voronoi/metadata.yml

+4 -1

@@ -14,11 +14,14 @@ authors:
     orcid: https://orcid.org/0000-0003-2248-474X
 repo: https://github.com/janosh/matbench-discovery
 doi: https://doi.org/10.1103/PhysRevB.96.024104
-preprint: https://arxiv.org/abs/2106.11132
 requirements:
   matminer: 0.8.0
   scikit-learn: 1.1.2
   pymatgen: 2022.10.22
   numpy: 1.24.0
   pandas: 1.5.1
 trained_on_benchmark: true
+
+notes:
+  description: A random forest trained to map a combination of composition-based Magpie features and structure-based relaxation-invariant Voronoi tessellation features (bond angles, coordination numbers, ...) to DFT formation energies.
+  long: This is an old model that predates most deep learning for materials but significantly improved over Coulomb matrix and partial radial distribution function methods. It therefore serves as a good baseline to see what modern ML buys us.

models/wrenformer/metadata.yml

+4

@@ -28,3 +28,7 @@ trained_on_benchmark: true

 hyperparams:
   Ensemble Size: 10
+
+notes:
+  description: Wrenformer is a standard PyTorch Transformer Encoder trained to learn material embeddings from composition, space group and Wyckoff positions in a structure.
+  long: It builds on [Roost](https://doi.org/10.1038/s41467-020-19964-7) and [Wren](https://doi.org/10.1126/sciadv.abn4117) by being a fast structure-free model that is still able to distinguish polymorphs through symmetry.
