Skip to content

Commit fbe847b

Browse files
committed
Run model combos (m3gnet|chgnet)+megnet through test_megnet.py
instead of join_(m3gnet|chgnet)_results.py. Update all plots with CHGNet results.
1 parent 01658ad commit fbe847b

21 files changed

+471
-475
lines changed

matbench_discovery/metrics.py

+10-4
Original file line number | Diff line number | Diff line change
@@ -86,17 +86,23 @@ def stable_metrics(
8686
is_nan = np.isnan(each_true) | np.isnan(each_pred)
8787
each_true, each_pred = np.array(each_true)[~is_nan], np.array(each_pred)[~is_nan]
8888

89+
TPR = recall
90+
FPR = n_false_pos / n_total_neg
91+
TNR = n_true_neg / n_total_neg
92+
FNR = n_false_neg / n_total_pos
93+
# sanity check: false positives + true negatives = all negatives
94+
assert FPR + TNR == 1
95+
# sanity check: true positives + false negatives = all positives
96+
assert TPR + FNR == 1
97+
8998
return dict(
9099
F1=2 * (precision * recall) / (precision + recall),
91100
R2=r2_score(each_true, each_pred),
92101
DAF=precision / prevalence,
93102
Precision=precision,
94103
Recall=recall,
104+
**dict(TPR=TPR, FPR=FPR, TNR=TNR, FNR=FNR),
95105
Accuracy=(n_true_pos + n_true_neg) / len(each_true),
96-
TPR=n_true_pos / n_total_pos,
97-
FPR=n_false_pos / n_total_neg,
98-
TNR=n_true_neg / n_total_neg,
99-
FNR=n_false_neg / n_total_pos,
100106
MAE=np.abs(each_true - each_pred).mean(),
101107
RMSE=((each_true - each_pred) ** 2).mean() ** 0.5,
102108
)

matbench_discovery/plots.py

+2
Original file line number | Diff line number | Diff line change
@@ -486,6 +486,8 @@ def rolling_mae_vs_hull_dist(
486486
yanchor="bottom",
487487
title_font=dict(size=15),
488488
)
489+
# change tooltip precision to 2 decimal places
490+
ax.update_traces(hovertemplate="x = %{x:.2f} eV/atom<br>y = %{y:.2f} eV/atom")
489491
ax.layout.xaxis.title.text = "E<sub>above MP hull</sub> (eV/atom)"
490492
ax.layout.yaxis.title.text = "rolling MAE (eV/atom)"
491493
ax.update_xaxes(range=x_lim)

matbench_discovery/preds.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -32,16 +32,16 @@ class PredFiles(Files):
3232
# bowsr optimizer coupled with original megnet
3333
bowsr_megnet = "bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv"
3434
# default CHGNet model from publication with 400,438 params
35-
chgnet = "chgnet/2023-03-04-chgnet-wbm-IS2RE.csv"
36-
chgnet_megnet = "chgnet/2023-03-04-chgnet-wbm-IS2RE.csv"
35+
chgnet = "chgnet/2023-03-06-chgnet-wbm-IS2RE.csv"
36+
# chgnet_megnet = "chgnet/2023-03-04-chgnet-wbm-IS2RE.csv"
3737
# CGCNN 10-member ensemble
3838
cgcnn = "cgcnn/2023-01-26-test-cgcnn-wbm-IS2RE/cgcnn-ensemble-preds.csv"
3939
# cgcnn 10-member ensemble with 5-fold training set perturbations
4040
cgcnn_p = "cgcnn/2023-02-05-cgcnn-perturb=5.csv"
4141
# original m3gnet straight from publication, not re-trained
4242
m3gnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
4343
# m3gnet-relaxed structures fed into megnet for formation energy prediction
44-
m3gnet_megnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
44+
# m3gnet_megnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
4545
# original megnet straight from publication, not re-trained
4646
megnet = "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv"
4747
# magpie composition+voronoi tessellation structure features + sklearn random forest

models/bowsr/test_bowsr.py

+3-7
Original file line number | Diff line number | Diff line change
@@ -72,13 +72,9 @@
7272
print(f"{data_path = }")
7373
print(f"{out_path = }")
7474

75-
76-
# %%
77-
df_wbm = pd.read_json(data_path).set_index("material_id")
78-
79-
df_in: pd.DataFrame = np.array_split(df_wbm, slurm_array_task_count)[
80-
slurm_array_task_id - 1
81-
]
75+
df_in: pd.DataFrame = np.array_split(
76+
pd.read_json(data_path).set_index("material_id"), slurm_array_task_count
77+
)[slurm_array_task_id - 1]
8278

8379

8480
# %%

models/chgnet/join_chgnet_results.py

+2-43
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,6 @@
1111
from glob import glob
1212

1313
import pandas as pd
14-
from megnet.utils.models import load_model
15-
from pymatgen.core import Structure
1614
from pymatviz import density_scatter
1715
from tqdm import tqdm
1816

@@ -62,53 +60,14 @@
6260

6361

6462
# %%
65-
ax = density_scatter(x=df_wbm[e_form_col], y=df_wbm[e_form_chgnet_col])
66-
67-
68-
# %% load 2019 MEGNet formation energy model
69-
megnet_mp_e_form = load_model("Eform_MP_2019")
70-
megnet_e_form_preds: dict[str, float] = {}
71-
72-
73-
# %% predict formation energies on chgnet relaxed structure with MEGNet
74-
for material_id, struct in tqdm(
75-
df_chgnet.chgnet_structure.items(), total=len(df_chgnet)
76-
):
77-
if material_id in megnet_e_form_preds:
78-
continue
79-
try:
80-
if isinstance(struct, dict):
81-
struct = Structure.from_dict(struct)
82-
[e_form_per_atom] = megnet_mp_e_form.predict_structure(struct)
83-
megnet_e_form_preds[material_id] = e_form_per_atom
84-
except Exception as exc:
85-
print(f"Failed to predict {material_id=}: {exc}")
86-
87-
e_form_megnet_col = "e_form_per_atom_chgnet_megnet"
88-
# remove legacy MP corrections that MEGNet was trained on and apply newer MP2020
89-
# corrections instead
90-
df_chgnet[e_form_megnet_col] = (
91-
pd.Series(megnet_e_form_preds)
92-
- df_wbm.e_correction_per_atom_mp_legacy
93-
+ df_wbm.e_correction_per_atom_mp2020
94-
)
95-
96-
assert (
97-
n_isna := df_chgnet.e_form_per_atom_chgnet_megnet.isna().sum()
98-
) < 10, f"too many missing MEGNet preds: {n_isna}"
99-
100-
101-
# %%
102-
ax = density_scatter(df=df_chgnet, x=e_form_chgnet_col, y=e_form_megnet_col)
103-
ax = density_scatter(df=df_chgnet, x=e_form_col, y=e_form_megnet_col)
63+
ax = density_scatter(df=df_wbm, x=e_form_col, y=e_form_chgnet_col)
10464

10565

10666
# %%
10767
out_path = f"{module_dir}/{today}-chgnet-wbm-{task_type}.json.gz"
10868
df_chgnet = df_chgnet.round(4)
109-
df_chgnet.reset_index().to_json(out_path, default_handler=as_dict_handler)
110-
11169
df_chgnet.select_dtypes("number").to_csv(out_path.replace(".json.gz", ".csv"))
70+
df_chgnet.reset_index().to_json(out_path, default_handler=as_dict_handler)
11271

11372
# in_path = f"{module_dir}/2023-03-04-chgnet-wbm-IS2RE.json.gz"
11473
# df_chgnet = pd.read_csv(in_path.replace(".json.gz", ".csv")).set_index("material_id")

models/chgnet/metadata.yml

+1-1
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
model_name: [CHGNet, CHGNet + MEGNet]
1+
model_name: CHGNet
22
model_version: 0.0.1
33
matbench_discovery_version: 1.0
44
date_added: "2023-03-03"

models/chgnet/test_chgnet.py

+4-5
Original file line number | Diff line number | Diff line change
@@ -61,13 +61,12 @@
6161
}[task_type]
6262
print(f"\nJob started running {timestamp}")
6363
print(f"{data_path=}")
64-
df_in = pd.read_json(data_path).set_index("material_id")
6564
e_pred_col = "chgnet_energy"
6665
max_steps = 2000
6766

68-
df_in: pd.DataFrame = np.array_split(df_in, slurm_array_task_count)[
69-
slurm_array_task_id - 1
70-
]
67+
df_in: pd.DataFrame = np.array_split(
68+
pd.read_json(data_path).set_index("material_id"), slurm_array_task_count
69+
)[slurm_array_task_id - 1]
7170

7271
run_params = dict(
7372
data_path=data_path,
@@ -124,7 +123,7 @@
124123
].reset_index()
125124
)
126125

127-
title = f"CHGNet {task_type} ({len(df_wbm):,})"
126+
title = f"CHGNet {task_type} ({len(df_out):,})"
128127
wandb_scatter(table, fields=dict(x="uncorrected_energy", y=e_pred_col), title=title)
129128

130129
wandb.log_artifact(out_path, type=f"chgnet-wbm-{task_type}")

models/m3gnet/join_m3gnet_results.py

+1-43
Original file line number | Diff line number | Diff line change
@@ -11,15 +11,14 @@
1111
from glob import glob
1212

1313
import pandas as pd
14-
from megnet.utils.models import load_model
1514
from pymatgen.core import Structure
1615
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
1716
from pymatgen.entries.computed_entries import ComputedStructureEntry
1817
from pymatviz import density_scatter
1918
from tqdm import tqdm
2019

2120
from matbench_discovery import today
22-
from matbench_discovery.data import DATA_FILES, as_dict_handler, df_wbm
21+
from matbench_discovery.data import DATA_FILES, as_dict_handler
2322
from matbench_discovery.energy import get_e_form_per_atom
2423

2524
__author__ = "Janosh Riebesell"
@@ -93,47 +92,6 @@
9392
)
9493

9594

96-
# %% load 2019 MEGNet formation energy model
97-
megnet_mp_e_form = load_model("Eform_MP_2019")
98-
megnet_e_form_preds: dict[str, float] = {}
99-
100-
101-
# %% predict formation energies on M3GNet relaxed structure with MEGNet
102-
for material_id, struct in tqdm(
103-
df_m3gnet.m3gnet_structure.items(), total=len(df_m3gnet)
104-
):
105-
if material_id in megnet_e_form_preds:
106-
continue
107-
try:
108-
if isinstance(struct, dict):
109-
struct = struct = Structure.from_dict(struct)
110-
df_m3gnet.loc[material_id, struct_col] = struct
111-
112-
[e_form_per_atom] = megnet_mp_e_form.predict_structure(struct)
113-
megnet_e_form_preds[material_id] = e_form_per_atom
114-
except Exception as exc:
115-
print(f"Failed to predict {material_id=}: {exc}")
116-
117-
pred_col_megnet = "e_form_per_atom_m3gnet_megnet"
118-
# remove legacy MP corrections that MEGNet was trained on and apply newer MP2020
119-
# corrections instead
120-
df_m3gnet[pred_col_megnet] = (
121-
pd.Series(megnet_e_form_preds)
122-
- df_wbm.e_correction_per_atom_mp_legacy
123-
+ df_wbm.e_correction_per_atom_mp2020
124-
)
125-
126-
assert (
127-
n_isna := df_m3gnet.e_form_per_atom_m3gnet_megnet.isna().sum()
128-
) < 10, f"too many missing MEGNet preds: {n_isna}"
129-
130-
131-
# %%
132-
ax = density_scatter(
133-
df=df_m3gnet, x="e_form_per_atom_m3gnet", y="e_form_per_atom_m3gnet_megnet"
134-
)
135-
136-
13795
# %%
13896
out_path = f"{module_dir}/{today}-m3gnet-wbm-{task_type}.json.gz"
13997
df_m3gnet = df_m3gnet.round(4)

models/m3gnet/test_m3gnet.py

+3-4
Original file line number | Diff line number | Diff line change
@@ -66,12 +66,11 @@
6666
}[task_type]
6767
print(f"\nJob started running {timestamp}")
6868
print(f"{data_path=}")
69-
df_wbm = pd.read_json(data_path).set_index("material_id")
7069
e_pred_col = "m3gnet_energy"
7170

72-
df_in: pd.DataFrame = np.array_split(df_wbm, slurm_array_task_count)[
73-
slurm_array_task_id - 1
74-
]
71+
df_in: pd.DataFrame = np.array_split(
72+
pd.read_json(data_path).set_index("material_id"), slurm_array_task_count
73+
)[slurm_array_task_id - 1]
7574

7675
run_params = dict(
7776
data_path=data_path,

models/megnet/test_megnet.py

+15-4
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
import os
1212
from importlib.metadata import version
1313

14+
import numpy as np
1415
import pandas as pd
1516
import wandb
1617
from megnet.utils.models import load_model
@@ -21,15 +22,17 @@
2122
from matbench_discovery import DEBUG, timestamp, today
2223
from matbench_discovery.data import DATA_FILES, df_wbm
2324
from matbench_discovery.plots import wandb_scatter
25+
from matbench_discovery.preds import PRED_FILES
2426
from matbench_discovery.slurm import slurm_submit
2527

2628
__author__ = "Janosh Riebesell"
2729
__date__ = "2022-11-14"
2830

29-
task_type = "IS2RE"
31+
task_type = "chgnet_structure"
3032
module_dir = os.path.dirname(__file__)
3133
job_name = f"megnet-wbm-{task_type}{'-debug' if DEBUG else ''}"
3234
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
35+
slurm_array_task_count = 20
3336

3437
slurm_vars = slurm_submit(
3538
job_name=job_name,
@@ -38,27 +41,33 @@
3841
account="LEE-SL3-CPU",
3942
time="12:0:0",
4043
slurm_flags=("--mem", "30G"),
44+
array=f"1-{slurm_array_task_count}",
4145
# TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed
4246
# https://stackoverflow.com/a/40982782
4347
pre_cmd="TF_CPP_MIN_LOG_LEVEL=2",
4448
)
4549

4650

4751
# %%
52+
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
4853
out_path = f"{out_dir}/megnet-e-form-preds.csv"
4954
if os.path.isfile(out_path):
5055
raise SystemExit(f"{out_path = } already exists, exciting early")
5156

5257
data_path = {
5358
"IS2RE": DATA_FILES.wbm_initial_structures,
5459
"RS2RE": DATA_FILES.wbm_computed_structure_entries,
60+
"chgnet_structure": PRED_FILES.__dict__["CHGNet"].replace(".csv", ".json.gz"),
61+
"m3gnet_structure": PRED_FILES.__dict__["M3GNet"].replace(".csv", ".json.gz"),
5562
}[task_type]
5663
print(f"\nJob started running {timestamp}")
5764
print(f"{data_path=}")
5865
e_form_col = "e_form_per_atom_mp2020_corrected"
5966
assert e_form_col in df_wbm, f"{e_form_col=} not in {list(df_wbm)=}"
6067

61-
df_in = pd.read_json(data_path).set_index("material_id")
68+
df_in: pd.DataFrame = np.array_split(
69+
pd.read_json(data_path).set_index("material_id"), slurm_array_task_count
70+
)[slurm_array_task_id - 1]
6271
megnet_mp_e_form = load_model(model_name := "Eform_MP_2019")
6372

6473

@@ -77,15 +86,17 @@
7786

7887

7988
# %%
80-
input_col = {"IS2RE": "initial_structure", "RS2RE": "relaxed_structure"}[task_type]
89+
input_col = {"IS2RE": "initial_structure", "RS2RE": "relaxed_structure"}.get(
90+
task_type, task_type # input_col=task_type for CHGNet and M3GNet
91+
)
8192

8293
if task_type == "RS2RE":
8394
df_in[input_col] = [x["structure"] for x in df_in.computed_structure_entry]
8495

8596
structures = df_in[input_col].map(Structure.from_dict).to_dict()
8697

8798
megnet_e_form_preds = {}
88-
for material_id in tqdm(structures, disable=None):
99+
for material_id in tqdm(structures):
89100
if material_id in megnet_e_form_preds:
90101
continue
91102
try:

models/voronoi/voronoi_featurize_dataset.py

+3-4
Original file line number | Diff line number | Diff line change
@@ -59,10 +59,9 @@
5959
raise SystemExit(f"{out_path = } already exists, exciting early")
6060

6161
print(f"{data_path=}")
62-
df = pd.read_json(data_path).set_index("material_id")
63-
df_in: pd.DataFrame = np.array_split(df, slurm_array_task_count)[
64-
slurm_array_task_id - 1
65-
]
62+
df_in: pd.DataFrame = np.array_split(
63+
pd.read_json(data_path).set_index("material_id"), slurm_array_task_count
64+
)[slurm_array_task_id - 1]
6665

6766
if data_name == "mp": # extract structure dicts from ComputedStructureEntry
6867
struct_dicts = [x["structure"] for x in df_in.entry]

scripts/compile_metrics.py

+5-4
Original file line number | Diff line number | Diff line change
@@ -35,10 +35,10 @@
3535
),
3636
),
3737
"CHGNet": dict(
38-
n_runs=100,
38+
n_runs=102,
3939
filters=dict(
4040
display_name={"$regex": "chgnet-wbm-IS2RE-"},
41-
created_at={"$lt": "2023-03-03"},
41+
created_at={"$gt": "2023-03-05", "$lt": "2023-03-07"},
4242
),
4343
),
4444
"CGCNN": dict(
@@ -155,6 +155,8 @@
155155
}
156156
styler.set_table_styles([dict(selector=sel, props=styles[sel]) for sel in styles])
157157
styler.set_uuid("")
158+
# hide redundant metrics (TPR = Recall, FPR = 1 - TNR, FNR = 1 - TPR)
159+
styler.hide(["Recall", "FPR", "FNR"], axis=1)
158160

159161

160162
# %% export model metrics as styled HTML table
@@ -183,8 +185,7 @@
183185

184186
df_stats.attrs["Total Run Time"] = df_stats[time_col].sum()
185187

186-
stats_out = f"{MODELS}/model-stats.json"
187-
df_stats.round(2).to_json(stats_out, orient="index")
188+
df_stats.round(2).to_json(f"{MODELS}/model-stats.json", orient="index")
188189

189190

190191
# %% plot model run times as pie chart

scripts/cumulative_clf_metrics.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@
4848
)
4949
fig.update_traces(line=dict(width=3))
5050
for trace in fig.data:
51-
if trace.name in df_metrics.T.sort_values("F1").index[6:]:
51+
if trace.name in df_metrics.T.sort_values("F1").index[:-6]:
5252
trace.visible = "legendonly" # show only top models by default
5353
last_idx = pd.Series(trace.y).last_valid_index()
5454
last_x = trace.x[last_idx]

0 commit comments

Comments
 (0)