Skip to content

Commit 01658ad

Browse files
committed
fix join_chgnet_results.py by removing code to apply MP2020 corrections
reason: unlike M3GNet, which predicts raw DFT energies, CHGNet targets include MP2020 corrections. Hence we don't need to apply the corrections afterwards. Also: increase CHGNet max relax steps: 500 -> 2000; small train+test script refactor.
1 parent c8eaebd commit 01658ad

22 files changed

+252065
-252138
lines changed

matbench_discovery/metrics.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ def stable_metrics(
6464
stability_threshold (float): Where to place stability threshold relative to
6565
convex hull in eV/atom, usually 0 or 0.1 eV. Defaults to 0.
6666
67-
Note: Could be replaced by sklearn.metrics.classification_report() which takes
68-
binary labels. I.e. classification_report(true > 0, pred > 0, output_dict=True)
69-
should give equivalent results.
67+
Note: Should give equivalent classification metrics to sklearn.metrics.
68+
classification_report(each_true > 0, each_pred > 0, output_dict=True) which
69+
takes binary labels.
7070
7171
Returns:
7272
dict[str, float]: dictionary of classification metrics with keys DAF, Precision,

models/bowsr/join_bowsr_results.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,9 @@
3131
for file_path in tqdm(file_paths):
3232
if file_path in dfs:
3333
continue
34-
df = pd.read_json(file_path).set_index("material_id")
34+
dfs[file_path] = pd.read_json(file_path).set_index("material_id")
3535

36-
dfs[file_path] = df
3736

38-
39-
# %%
4037
df_bowsr = pd.concat(dfs.values()).round(4)
4138

4239

models/bowsr/test_bowsr.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
task_type = "IS2RE" # "RS2RE"
3131
module_dir = os.path.dirname(__file__)
32-
# set large job array size for fast testing/debugging
32+
# set large job array size for smaller data splits and faster testing/debugging
3333
slurm_array_task_count = 500
3434
# see https://stackoverflow.com/a/55431306 for how to change array throttling
3535
# post submission
@@ -95,9 +95,7 @@
9595
data_path=data_path,
9696
df=dict(shape=str(df_in.shape), columns=", ".join(df_in)),
9797
energy_model=energy_model,
98-
maml_version=version("maml"),
99-
energy_model_version=version(energy_model),
100-
numpy_version=version("numpy"),
98+
**{f"{dep}_version": version(dep) for dep in ("maml", "numpy", energy_model)},
10199
optimize_kwargs=optimize_kwargs,
102100
task_type=task_type,
103101
slurm_vars=slurm_vars,

models/cgcnn/test_cgcnn.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,7 @@
8484
run_params = dict(
8585
data_path=data_path,
8686
df=dict(shape=str(df.shape), columns=", ".join(df)),
87-
aviary_version=version("aviary"),
88-
numpy_version=version("numpy"),
89-
torch_version=version("torch"),
87+
**{f"{dep}_version": version(dep) for dep in ("aviary", "numpy", "torch")},
9088
ensemble_size=len(runs),
9189
task_type=task_type,
9290
target_col=e_form_col,

models/cgcnn/train_cgcnn.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,7 @@
102102
run_params = dict(
103103
data_path=data_path,
104104
batch_size=batch_size,
105-
aviary_version=version("aviary"),
106-
numpy_version=version("numpy"),
107-
torch_version=version("torch"),
105+
**{f"{dep}_version": version(dep) for dep in ("aviary", "numpy", "torch")},
108106
train_df=dict(shape=str(train_data.df.shape), columns=", ".join(train_df)),
109107
test_df=dict(shape=str(test_data.df.shape), columns=", ".join(test_df)),
110108
slurm_vars=slurm_vars,

models/chgnet/2023-03-04-chgnet-wbm-IS2RE.csv

+251,739-251,739
Large diffs are not rendered by default.

models/chgnet/join_chgnet_results.py

+20-43
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Concatenate chgnet results from multiple data files generated by slurm job array
1+
"""Concatenate CHGNet results from multiple data files generated by slurm job array
22
into single file.
33
"""
44

@@ -13,13 +13,11 @@
1313
import pandas as pd
1414
from megnet.utils.models import load_model
1515
from pymatgen.core import Structure
16-
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
17-
from pymatgen.entries.computed_entries import ComputedStructureEntry
1816
from pymatviz import density_scatter
1917
from tqdm import tqdm
2018

2119
from matbench_discovery import today
22-
from matbench_discovery.data import DATA_FILES, as_dict_handler
20+
from matbench_discovery.data import as_dict_handler
2321
from matbench_discovery.energy import get_e_form_per_atom
2422
from matbench_discovery.preds import df_wbm, e_form_col
2523

@@ -32,7 +30,7 @@
3230
# %%
3331
module_dir = os.path.dirname(__file__)
3432
task_type = "IS2RE"
35-
date = "2023-03-04"
33+
date = "2023-03-06"
3634
glob_pattern = f"{date}-chgnet-wbm-{task_type}*/*.json.gz"
3735
file_paths = sorted(glob(f"{module_dir}/{glob_pattern}"))
3836
print(f"Found {len(file_paths):,} files for {glob_pattern = }")
@@ -48,47 +46,23 @@
4846
# drop trajectory to save memory
4947
dfs[file_path] = df.drop(columns="chgnet_trajectory")
5048

51-
52-
# %%
5349
df_chgnet = pd.concat(dfs.values()).round(4)
5450

5551

56-
# %%
57-
df_cse = pd.read_json(DATA_FILES.wbm_computed_structure_entries).set_index(
58-
"material_id"
59-
)
60-
61-
df_cse["cse"] = [
62-
ComputedStructureEntry.from_dict(x) for x in tqdm(df_cse.computed_structure_entry)
63-
]
64-
65-
66-
# %% transfer CHGNet energies and relaxed structures WBM CSEs since MP2020 energy
67-
# corrections applied below are structure-dependent (for oxides and sulfides)
68-
cse: ComputedStructureEntry
69-
for row in tqdm(df_chgnet.itertuples(), total=len(df_chgnet)):
70-
mat_id, struct_dict, chgnet_energy, *_ = row
71-
chgnet_struct = Structure.from_dict(struct_dict)
72-
cse = df_cse.loc[mat_id, "cse"]
73-
cse._energy = chgnet_energy # cse._energy is the uncorrected energy
74-
cse._structure = chgnet_struct
75-
df_chgnet.loc[mat_id, "cse"] = cse
76-
77-
78-
# %% apply energy corrections to CSEs with CHGNet
79-
out = MaterialsProject2020Compatibility().process_entries(
80-
df_chgnet.cse, verbose=True, clean=True
81-
)
82-
assert len(out) == len(df_chgnet)
83-
84-
8552
# %% compute corrected formation energies
8653
e_form_chgnet_col = "e_form_per_atom_chgnet"
87-
df_chgnet[e_form_chgnet_col] = [get_e_form_per_atom(cse) for cse in tqdm(df_chgnet.cse)]
54+
df_chgnet["formula"] = df_wbm.formula
55+
df_chgnet[e_form_chgnet_col] = [
56+
get_e_form_per_atom(dict(energy=ene, composition=formula))
57+
for formula, ene in tqdm(
58+
df_chgnet.set_index("formula").chgnet_energy.items(), total=len(df_chgnet)
59+
)
60+
]
61+
df_wbm[e_form_chgnet_col] = df_chgnet[e_form_chgnet_col]
8862

8963

9064
# %%
91-
ax = density_scatter(x=df_wbm[e_form_col], y=df_chgnet[e_form_chgnet_col])
65+
ax = density_scatter(x=df_wbm[e_form_col], y=df_wbm[e_form_chgnet_col])
9266

9367

9468
# %% load 2019 MEGNet formation energy model
@@ -97,11 +71,14 @@
9771

9872

9973
# %% predict formation energies on chgnet relaxed structure with MEGNet
100-
for material_id, cse in tqdm(df_cse.cse.items(), total=len(df_cse)):
74+
for material_id, struct in tqdm(
75+
df_chgnet.chgnet_structure.items(), total=len(df_chgnet)
76+
):
10177
if material_id in megnet_e_form_preds:
10278
continue
10379
try:
104-
struct = cse.structure
80+
if isinstance(struct, dict):
81+
struct = Structure.from_dict(struct)
10582
[e_form_per_atom] = megnet_mp_e_form.predict_structure(struct)
10683
megnet_e_form_preds[material_id] = e_form_per_atom
10784
except Exception as exc:
@@ -118,7 +95,7 @@
11895

11996
assert (
12097
n_isna := df_chgnet.e_form_per_atom_chgnet_megnet.isna().sum()
121-
) < 10, f"{n_isna=}, expected 7 or similar"
98+
) < 10, f"too many missing MEGNet preds: {n_isna}"
12299

123100

124101
# %%
@@ -133,6 +110,6 @@
133110

134111
df_chgnet.select_dtypes("number").to_csv(out_path.replace(".json.gz", ".csv"))
135112

136-
# in_path = f"{module_dir}/2022-10-31-chgnet-wbm-IS2RE.json.gz"
137-
# df_chgnet_csv = pd.read_csv(in_path.replace(".json.gz", ".csv"))
113+
# in_path = f"{module_dir}/2023-03-04-chgnet-wbm-IS2RE.json.gz"
114+
# df_chgnet = pd.read_csv(in_path.replace(".json.gz", ".csv")).set_index("material_id")
138115
# df_chgnet = pd.read_json(in_path).set_index("material_id")

models/chgnet/metadata.yml

+4
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,12 @@ requirements:
3333
numpy: 1.24.0
3434
trained_on_benchmark: false
3535

36+
hyperparams:
37+
max_steps: 2000
38+
3639
notes:
3740
description: |
3841
The Crystal Hamiltonian Graph Neural Network (CHGNet) is a universal GNN-based interatomic potential trained on energies, forces, stresses and magnetic moments from the MP trajectory dataset containing ∼1.5 million inorganic structures.
3942
![CHGNet Pipeline](https://user-images.githubusercontent.com/30958850/222924937-1d09bbce-ee18-4b19-8061-ec689cd15887.svg)
4043
training: Using pre-trained model with 400,438 params released with preprint. Training set unreleased at time of writing.
44+
corrections: Unlike e.g. M3GNet which predicts raw DFT energies, CHGNet targets include MP2020 corrections. Hence no need to correct again.

models/chgnet/test_chgnet.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from __future__ import annotations
1111

1212
import os
13-
import warnings
1413
from importlib.metadata import version
1514
from typing import Any
1615

@@ -31,7 +30,7 @@
3130

3231
task_type = "IS2RE" # "RS2RE"
3332
module_dir = os.path.dirname(__file__)
34-
# set large job array size for fast testing/debugging
33+
# set large job array size for smaller data splits and faster testing/debugging
3534
slurm_array_task_count = 100
3635
job_name = f"chgnet-wbm-{task_type}{'-debug' if DEBUG else ''}"
3736
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
@@ -41,8 +40,8 @@
4140
out_dir=out_dir,
4241
partition="ampere",
4342
account="LEE-SL3-GPU",
44-
time="3:0:0",
45-
# array=f"1-{slurm_array_task_count}",
43+
time="6:0:0",
44+
array=f"1-{slurm_array_task_count}",
4645
slurm_flags="--nodes 1 --gpus-per-node 1",
4746
)
4847

@@ -54,9 +53,6 @@
5453
if os.path.isfile(out_path):
5554
raise SystemExit(f"{out_path = } already exists, exciting early")
5655

57-
warnings.filterwarnings(action="ignore", category=UserWarning, module="pymatgen")
58-
warnings.filterwarnings(action="ignore", category=UserWarning, module="tensorflow")
59-
6056

6157
# %%
6258
data_path = {
@@ -67,19 +63,19 @@
6763
print(f"{data_path=}")
6864
df_in = pd.read_json(data_path).set_index("material_id")
6965
e_pred_col = "chgnet_energy"
66+
max_steps = 2000
7067

7168
df_in: pd.DataFrame = np.array_split(df_in, slurm_array_task_count)[
7269
slurm_array_task_id - 1
7370
]
7471

7572
run_params = dict(
7673
data_path=data_path,
77-
chgnet_version=version("chgnet"),
78-
numpy_version=version("numpy"),
79-
torch_version=version("torch"),
74+
**{f"{dep}_version": version(dep) for dep in ("chgnet", "numpy", "torch")},
8075
task_type=task_type,
8176
df=dict(shape=str(df_in.shape), columns=", ".join(df_in)),
8277
slurm_vars=slurm_vars,
78+
max_steps=max_steps,
8379
)
8480

8581
run_name = f"{job_name}-{slurm_array_task_id}"
@@ -100,7 +96,9 @@
10096
if material_id in relax_results:
10197
continue
10298
try:
103-
relax_result = chgnet.relax(structures[material_id], verbose=False)
99+
relax_result = chgnet.relax(
100+
structures[material_id], verbose=False, steps=max_steps
101+
)
104102
except Exception as error:
105103
print(f"Failed to relax {material_id}: {error}")
106104
continue

models/m3gnet/join_m3gnet_results.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,9 @@
4545
if file_path in dfs:
4646
continue
4747
df = pd.read_json(file_path).set_index("material_id")
48-
df[f"m3gnet_energy_{task_type}"] = [
49-
x["energies"][-1][0] for x in df.m3gnet_trajectory
50-
]
5148
# drop trajectory to save memory
5249
dfs[file_path] = df.drop(columns="m3gnet_trajectory")
5350

54-
55-
# %%
5651
df_m3gnet = pd.concat(dfs.values()).round(4)
5752

5853

@@ -130,7 +125,7 @@
130125

131126
assert (
132127
n_isna := df_m3gnet.e_form_per_atom_m3gnet_megnet.isna().sum()
133-
) < 10, f"{n_isna=}, expected 7 or similar"
128+
) < 10, f"too many missing MEGNet preds: {n_isna}"
134129

135130

136131
# %%

models/m3gnet/test_m3gnet.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
task_type = "IS2RE" # "RS2RE"
3131
module_dir = os.path.dirname(__file__)
32-
# set large job array size for fast testing/debugging
32+
# set large job array size for smaller data splits and faster testing/debugging
3333
slurm_array_task_count = 100
3434
job_name = f"m3gnet-wbm-{task_type}{'-debug' if DEBUG else ''}"
3535
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
@@ -75,8 +75,7 @@
7575

7676
run_params = dict(
7777
data_path=data_path,
78-
m3gnet_version=version("m3gnet"),
79-
numpy_version=version("numpy"),
78+
**{f"{dep}_version": version(dep) for dep in ("m3gnet", "numpy")},
8079
task_type=task_type,
8180
df=dict(shape=str(df_in.shape), columns=", ".join(df_in)),
8281
slurm_vars=slurm_vars,

models/megnet/test_megnet.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,7 @@
6565
# %%
6666
run_params = dict(
6767
data_path=data_path,
68-
megnet_version=version("megnet"),
69-
numpy_version=version("numpy"),
68+
**{f"{dep}_version": version(dep) for dep in ("megnet", "numpy")},
7069
model_name=model_name,
7170
task_type=task_type,
7271
target_col=e_form_col,

models/voronoi/join_voronoi_features.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,8 @@
3030
for file_path in tqdm(file_paths):
3131
if file_path in dfs:
3232
continue
33-
df = pd.read_csv(file_path).set_index("material_id")
34-
dfs[file_path] = df
33+
dfs[file_path] = pd.read_csv(file_path).set_index("material_id")
3534

36-
37-
# %%
3835
df_features = pd.concat(dfs.values()).round(4)
3936

4037
ax = df_features.isna().sum().value_counts().T.plot.bar()

models/voronoi/train_test_voronoi_rf.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,7 @@
7272
train_path=train_path,
7373
test_path=test_path,
7474
mp_energies_path=DATA_FILES.mp_energies,
75-
scikit_learn_version=version("scikit-learn"),
76-
matminer_version=version("matminer"),
77-
numpy_version=version("numpy"),
75+
**{f"{dep}_version": version(dep) for dep in ("scikit-learn", "matminer", "numpy")},
7876
model_name=model_name,
7977
train_target_col=train_e_form_col,
8078
test_target_col=test_e_form_col,

models/voronoi/voronoi_featurize_dataset.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@
8181
input_col=input_col,
8282
slurm_vars=slurm_vars,
8383
out_path=out_path,
84-
matminer_version=version("matminer"),
85-
numpy_version=version("numpy"),
84+
**{f"{dep}_version": version(dep) for dep in ("matminer", "numpy")},
8685
)
8786

8887
wandb.init(project="matbench-discovery", name=run_name, config=run_params)

models/wrenformer/test_wrenformer.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@
7373
run_params = dict(
7474
data_path=data_path,
7575
df=dict(shape=str(df.shape), columns=", ".join(df)),
76-
aviary_version=version("aviary"),
77-
numpy_version=version("numpy"),
78-
torch_version=version("torch"),
76+
**{f"{dep}_version": version(dep) for dep in ("aviary", "numpy", "torch")},
7977
ensemble_size=len(runs),
8078
task_type=task_type,
8179
target_col=e_form_col,

models/wrenformer/train_wrenformer.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,7 @@
5959

6060
run_params = dict(
6161
data_path=data_path,
62-
aviary_version=version("aviary"),
63-
numpy_version=version("numpy"),
64-
torch_version=version("torch"),
62+
**{f"{dep}_version": version(dep) for dep in ("aviary", "numpy", "torch")},
6563
batch_size=batch_size,
6664
train_df=dict(shape=train_df.shape, columns=", ".join(train_df)),
6765
test_df=dict(shape=test_df.shape, columns=", ".join(test_df)),

0 commit comments

Comments
 (0)