Skip to content

Commit 7946b5e

Browse files
committed
add scripts/compute_struct_fingerprints.py to generate matminer SiteStats fingerprints for all MP+WBM structures
make EACH scatter plots interactive by binning points and only showing 1 representative per bin. added benefit over only plotting every 50th point: we now show all outliers refactor hist_classified_stable_vs_hull_dist for smaller file size via new kwarg n_bins: int = 200 replace site/static/figs/hist-(pred|pred)-energy-vs-hull-dist-models.webp with interactive site/src/figs/hist-clf-(true|pred)-hull-dist-models.svelte show Accuracy+TNR+TPR on /models page, hide FPR+FNR
1 parent c9fed5a commit 7946b5e

16 files changed

+353
-206
lines changed

matbench_discovery/plots.py

+119-102
Large diffs are not rendered by default.

matbench_discovery/preds.py

+18-11
Original file line numberDiff line numberDiff line change
@@ -29,23 +29,29 @@ class PredFiles(Files):
2929
_root = f"{ROOT}/models/"
3030
_key_map = model_labels # remap model keys below to pretty plot labels (see Files)
3131

32-
# bowsr optimizer coupled with original megnet
32+
# BOWSR optimizer coupled with original megnet
3333
bowsr_megnet = "bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv"
3434
# default CHGNet model from publication with 400,438 params
3535
chgnet = "chgnet/2023-03-06-chgnet-wbm-IS2RE.csv"
36-
chgnet_megnet = "chgnet/2023-03-04-chgnet-wbm-IS2RE.csv"
36+
# CHGNet-relaxed structures fed into MEGNet for formation energy prediction
37+
# chgnet_megnet = "chgnet/2023-03-04-chgnet-wbm-IS2RE.csv"
38+
3739
# CGCnn 10-member ensemble
3840
cgcnn = "cgcnn/2023-01-26-test-cgcnn-wbm-IS2RE/cgcnn-ensemble-preds.csv"
39-
# cgcnn 10-member ensemble with 5-fold training set perturbations
41+
# CGCnn 10-member ensemble with 5-fold training set perturbations
4042
cgcnn_p = "cgcnn/2023-02-05-cgcnn-perturb=5.csv"
41-
# original m3gnet straight from publication, not re-trained
43+
44+
# original M3GNet straight from publication, not re-trained
4245
m3gnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
43-
# m3gnet-relaxed structures fed into megnet for formation energy prediction
44-
m3gnet_megnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
45-
# original megnet straight from publication, not re-trained
46+
# M3GNet-relaxed structures fed into MEGNet for formation energy prediction
47+
# m3gnet_megnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
48+
49+
# original MEGNet straight from publication, not re-trained
4650
megnet = "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv"
47-
# magpie composition+voronoi tessellation structure features + sklearn random forest
51+
52+
# Magpie composition+Voronoi tessellation structure features + sklearn random forest
4853
voronoi_rf = "voronoi/2022-11-27-train-test/e-form-preds-IS2RE.csv"
54+
4955
# wrenformer 10-member ensemble
5056
wrenformer = "wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
5157

@@ -113,13 +119,14 @@ def load_df_wbm_with_preds(
113119

114120
# load WBM summary dataframe with all models' formation energy predictions (eV/atom)
115121
df_preds = load_df_wbm_with_preds().round(3)
116-
for combo in [["CHGNet", "M3GNet"]]:
117-
df_preds[" + ".join(combo)] = df_preds[combo].mean(axis=1)
122+
# for combo in [["CHGNet", "M3GNet"]]:
123+
# df_preds[" + ".join(combo)] = df_preds[combo].mean(axis=1)
124+
# PRED_FILES[" + ".join(combo)] = "combo"
118125

119126

120127
df_metrics = pd.DataFrame()
121128
df_metrics.index.name = "model"
122-
for model in [*PRED_FILES, "CHGNet + M3GNet"]:
129+
for model in PRED_FILES:
123130
df_metrics[model] = stable_metrics(
124131
df_preds[each_true_col],
125132
df_preds[each_true_col] + df_preds[model] - df_preds[e_form_col],

models/megnet/test_megnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@
5858
data_path = {
5959
"IS2RE": DATA_FILES.wbm_initial_structures,
6060
"RS2RE": DATA_FILES.wbm_computed_structure_entries,
61-
"chgnet_structure": PRED_FILES.CHGNet.replace(".csv", ".json.gz"),
62-
"m3gnet_structure": PRED_FILES.M3GNet.replace(".csv", ".json.gz"),
61+
"chgnet_structure": PRED_FILES.__dict__["CHGNet"].replace(".csv", ".json.gz"),
62+
"m3gnet_structure": PRED_FILES.__dict__["M3GNet"].replace(".csv", ".json.gz"),
6363
}[task_type]
6464
print(f"\nJob started running {timestamp}")
6565
print(f"{data_path=}")

models/voronoi/voronoi_featurize_dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
__date__ = "2022-10-31"
2525

2626

27-
data_name = "mp" # "mp"
27+
data_name = "mp"
2828
data_path = {
2929
"wbm": DATA_FILES.wbm_initial_structures,
3030
"mp": DATA_FILES.mp_computed_structure_entries,
+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""Analyze structures and composition with largest mean error across all models.
2+
Maybe there's some chemistry/region of materials space that all models struggle with?
3+
Might point to deficiencies in the data or models architecture.
4+
"""
5+
6+
7+
# %%
8+
import os
9+
import warnings
10+
11+
import numpy as np
12+
import pandas as pd
13+
from matminer.featurizers.site import CrystalNNFingerprint
14+
from matminer.featurizers.structure import SiteStatsFingerprint
15+
from pymatgen.core import Structure
16+
from tqdm import tqdm
17+
18+
from matbench_discovery import ROOT, timestamp
19+
from matbench_discovery.data import DATA_FILES
20+
from matbench_discovery.slurm import slurm_submit
21+
22+
__author__ = "Janosh Riebesell"
23+
__date__ = "2023-03-26"
24+
25+
warnings.filterwarnings(action="ignore", category=UserWarning, module="pymatgen")
26+
27+
28+
# %% compute all initial and final MP/WBM structure fingerprints
29+
data_name = "wbm"
30+
data_path = {
31+
"wbm": DATA_FILES.wbm_cses_plus_init_structs,
32+
"mp": DATA_FILES.mp_computed_structure_entries,
33+
}[data_name]
34+
35+
slurm_array_task_id = int(os.getenv("SLURM_ARRAY_TASK_ID", 0))
36+
slurm_array_task_count = 100
37+
38+
job_name = f"make-{data_name}-struct-fingerprints"
39+
out_dir = f"{ROOT}/data/{data_name}/structure-fingerprints"
40+
os.makedirs(out_dir, exist_ok=True)
41+
42+
slurm_vars = slurm_submit(
43+
job_name=job_name,
44+
out_dir=out_dir,
45+
partition="icelake-himem",
46+
account="LEE-SL3-CPU",
47+
time="6:0:0",
48+
array=f"1-{slurm_array_task_count}",
49+
)
50+
51+
52+
# %%
53+
out_path = f"{out_dir}/site-stats-{slurm_array_task_id}.json.gz"
54+
if os.path.isfile(out_path):
55+
raise SystemExit(f"{out_path = } already exists, exciting early")
56+
57+
print(f"\nJob started running {timestamp}")
58+
print(f"{out_path=}")
59+
60+
61+
# %%
62+
df_in: pd.DataFrame = np.array_split(
63+
pd.read_json(data_path).set_index("material_id"), slurm_array_task_count
64+
)[slurm_array_task_id - 1]
65+
66+
cnn_fp = CrystalNNFingerprint.from_preset("ops")
67+
# including "minimum" and "maximum" increases the fingerprint length from 61 to 122
68+
site_stats_fp = SiteStatsFingerprint(
69+
cnn_fp, stats=("mean", "std_dev", "minimum", "maximum")
70+
)
71+
72+
73+
# %%
74+
init_struct_col = "initial_structure"
75+
final_struct_col = "computed_structure_entry"
76+
init_fp_col = "initial_site_stats_fingerprint"
77+
final_fp_col = "final_site_stats_fingerprint"
78+
for struct_col, fp_col in (
79+
(init_struct_col, init_fp_col),
80+
(final_struct_col, final_fp_col),
81+
("entry", final_fp_col),
82+
):
83+
if struct_col not in df_in:
84+
continue
85+
df_in[fp_col] = None
86+
87+
for row in tqdm(df_in.itertuples(), total=len(df_in)):
88+
struct = getattr(row, struct_col)
89+
if "structure" in struct: # is a ComputedStructureEntry as dict
90+
struct = struct["structure"]
91+
struct = Structure.from_dict(struct)
92+
try:
93+
ss_fp = site_stats_fp.featurize(struct)
94+
df_in.at[row.Index, fp_col] = ss_fp
95+
except Exception as exc:
96+
print(f"{fp_col} for {row.Index} failed: {exc}")
97+
98+
df_in.filter(like="site_stats_fingerprint").to_json(out_path)

scripts/hist_classified_stable_vs_hull_dist.py

+14-31
Original file line numberDiff line numberDiff line change
@@ -12,56 +12,39 @@
1212
from pymatviz.utils import save_fig
1313

1414
from matbench_discovery import FIGS
15-
from matbench_discovery.metrics import stable_metrics
15+
from matbench_discovery.data import df_wbm
1616
from matbench_discovery.plots import hist_classified_stable_vs_hull_dist
17-
from matbench_discovery.preds import df_preds, e_form_col, each_pred_col, each_true_col
17+
from matbench_discovery.preds import df_each_pred, each_true_col
1818

1919
__author__ = "Rhys Goodall, Janosh Riebesell"
2020
__date__ = "2022-06-18"
2121

2222

2323
# %%
2424
model_name = "Wrenformer"
25+
model_name = "CHGNet"
26+
# model_name = "M3GNet"
27+
# model_name = "Voronoi RF"
2528
which_energy: Final = "true"
26-
# std_factor=0,+/-1,+/-2,... changes the criterion for material stability to
27-
# energy+std_factor*std. energy+std means predicted energy plus the model's uncertainty
28-
# in the prediction have to be on or below the convex hull to be considered stable. This
29-
# reduces the false positive rate, but increases the false negative rate. Vice versa for
30-
# energy-std. energy+std should be used for cautious exploration, energy-std for
31-
# exhaustive exploration.
32-
std_factor = 0
33-
34-
# TODO column names to compute standard deviation from are currently hardcoded
35-
# needs to be updated when adding non-aviary models with uncertainty estimation
36-
var_aleatoric = (df_preds.filter(like="_ale_") ** 2).mean(axis=1)
37-
var_epistemic = df_preds.filter(regex=r"_pred_\d").var(axis=1, ddof=0)
38-
std_total = (var_epistemic + var_aleatoric) ** 0.5
39-
std_total = df_preds[f"{model_name}_std"]
40-
df_preds[each_pred_col] = df_preds[each_true_col] + (
41-
(df_preds[model_name] + std_factor * std_total) - df_preds[e_form_col]
42-
)
29+
df_each_pred[each_true_col] = df_wbm[each_true_col]
30+
backend: Final = "plotly"
4331

4432
fig = hist_classified_stable_vs_hull_dist(
45-
df_preds,
33+
df_each_pred,
4634
each_true_col=each_true_col,
47-
each_pred_col=each_pred_col,
35+
each_pred_col=model_name,
4836
which_energy=which_energy,
4937
# stability_threshold=-0.05,
50-
# rolling_acc=0,
51-
backend="plotly",
38+
# rolling_acc=None,
39+
backend=backend,
5240
)
5341

54-
metrics = stable_metrics(df_preds[each_true_col], df_preds[each_pred_col])
55-
legend_title = f"DAF = {metrics['DAF']:.3}"
56-
57-
if hasattr(fig, "legend"): # matplotlib
58-
fig.legend(loc="upper left", frameon=False, title=legend_title)
59-
else: # plotly
60-
fig.layout.legend.title.text = legend_title
42+
if backend == "plotly":
43+
fig.layout.title = model_name
6144
fig.show()
6245

6346

6447
# %%
65-
img_path = f"{FIGS}/wren-wbm-hull-dist-hist-{which_energy=}"
48+
img_path = f"{FIGS}/hist-clf-{which_energy}-hull-dist-{model_name}"
6649
# save_fig(fig, f"{img_path}.svelte")
6750
save_fig(fig, f"{img_path}.webp")

scripts/hist_classified_stable_vs_hull_dist_models.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99

1010
from pymatviz.utils import save_fig
1111

12-
from matbench_discovery import ROOT, STATIC, today
13-
from matbench_discovery.plots import hist_classified_stable_vs_hull_dist, plt
12+
from matbench_discovery import FIGS, ROOT, today
13+
from matbench_discovery.plots import (
14+
hist_classified_stable_vs_hull_dist,
15+
plt,
16+
)
1417
from matbench_discovery.preds import df_metrics, df_preds, e_form_col, each_true_col
1518

1619
__author__ = "Janosh Riebesell"
@@ -58,6 +61,18 @@
5861
**kwds, # type: ignore[arg-type]
5962
)
6063

64+
# true_pos, false_neg, false_pos, true_neg = classify_stable(
65+
# df_melt[each_true_col], df_melt[each_pred_col], stability_threshold=0
66+
# )
67+
# import numpy as np
68+
69+
# df_melt[(clf_col := "classified")] = np.array(clf_labels)[
70+
# true_pos * 0 + false_neg * 1 + false_pos * 2 + true_neg * 3
71+
# ]
72+
# import pandas as pd
73+
74+
# pd.cut(df_melt[each_pred_col], bins=10).value_counts()
75+
6176

6277
# TODO add line showing the true hull distance histogram on each subplot
6378
show_metrics = False
@@ -91,11 +106,12 @@
91106
)
92107
anno.text = f"{model_name} · {F1=:.2f} · {FPR=:.2f} · {FNR=:.2f} · {DAF=:.2f}"
93108

109+
fig.layout.height = 1000
94110
fig.layout.margin.update(t=50, b=30, l=40, r=0)
95111
fig.layout.legend.update(
96-
y=1.15, xanchor="center", x=0.5, bgcolor="rgba(0,0,0,0)", orientation="h"
112+
y=1.1, xanchor="center", x=0.5, bgcolor="rgba(0,0,0,0)", orientation="h"
97113
)
98-
fig.update_yaxes(range=[0, 3_000], title_text=None)
114+
fig.update_yaxes(range=[0, 11_000], title_text=None)
99115

100116
# for trace in fig.data:
101117
# # no need to store all 250k x values in plot, leads to 1.7 MB file,
@@ -107,8 +123,8 @@
107123

108124

109125
# %%
110-
img_name = f"hist-{which_energy}-energy-vs-hull-dist-models"
111-
# save_fig(fig, f"{FIGS}/{img_name}.svelte")
126+
img_name = f"hist-clf-{which_energy}-hull-dist-models"
127+
save_fig(fig, f"{FIGS}/{img_name}.svelte")
112128
n_models = len(fig.layout.annotations)
113-
save_fig(fig, f"{STATIC}/{img_name}.webp", scale=3, height=100 * n_models)
129+
# save_fig(fig, f"{STATIC}/{img_name}.webp", scale=3, height=100 * n_models)
114130
save_fig(fig, f"{ROOT}/tmp/figures/{img_name}.pdf", height=550, width=600)

0 commit comments

Comments
 (0)