Skip to content

Commit 6101995

Browse files
committed
compare_cse_vs_ce_mp_2020_corrections.py code for materialsproject/pymatgen#2730
1 parent 6450ebb commit 6101995

5 files changed

+79
-27
lines changed

data/wbm/compare_cse_vs_ce_mp_2020_corrections.py

+56-7
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# %%
2+
import gzip
3+
import json
14
import warnings
25
from datetime import datetime
36

@@ -12,6 +15,7 @@
1215
from matbench_discovery import ROOT
1316
from matbench_discovery.energy import get_e_form_per_atom
1417
from matbench_discovery.plot_scripts import df_wbm
18+
from matbench_discovery.plots import plt
1519

1620
"""
1721
NOTE MaterialsProject2020Compatibility takes structural information into account when
@@ -47,8 +51,10 @@
4751
get_e_form_per_atom(entry) for entry in tqdm(cses)
4852
]
4953

50-
df_wbm["mp2020_cse_correction"] = [cse.correction for cse in tqdm(cses)]
51-
df_wbm["mp2020_ce_correction"] = [ce.correction for ce in tqdm(ces)]
54+
df_wbm["mp2020_cse_correction_per_atom"] = [
55+
cse.correction_per_atom for cse in tqdm(cses)
56+
]
57+
df_wbm["mp2020_ce_correction_per_atom"] = [ce.correction_per_atom for ce in tqdm(ces)]
5258

5359

5460
# %%
@@ -81,21 +87,25 @@
8187

8288

8389
# %%
90+
ax = plt.gca()
8491
for key, df_anion in df_ce_ne_cse.groupby("anion"):
8592
ax = df_anion.plot.scatter(
86-
ax=locals().get("ax"),
87-
x="mp2020_cse_correction",
88-
y="mp2020_ce_correction",
93+
ax=ax,
94+
x="mp2020_cse_correction_per_atom",
95+
y="mp2020_ce_correction_per_atom",
8996
label=f"{key} ({len(df_anion):,})",
9097
color=dict(oxide="orange", sulfide="teal").get(key, "blue"),
91-
title=f"Outliers in formation energy from CSE vs CE ({len(df_ce_ne_cse):,}"
92-
f" / {len(df_wbm):,} = {len(df_ce_ne_cse) / len(df_wbm):.1%})",
98+
title=f"CSE vs CE corrections for ({len(df_ce_ne_cse):,} / {len(df_wbm):,} = "
99+
f"{len(df_ce_ne_cse) / len(df_wbm):.1%})\n outliers of largest difference",
93100
)
94101

95102
ax.axline((0, 0), slope=1, color="gray", linestyle="dashed", zorder=-1)
96103

104+
# ax.figure.savefig(f"{ROOT}/tmp/{today}-ce-vs-cse-corrections-outliers.pdf")
105+
97106

98107
# %%
108+
ax = plt.gca()
99109
for key, df_anion in df_ce_ne_cse.groupby("anion"):
100110
ax = df_anion.plot.scatter(
101111
ax=locals().get("ax"),
@@ -113,3 +123,42 @@
113123
# different formation energies are oxides or sulfides for which MP 2020 compat takes
114124
# into account structural information to make more accurate corrections.
115125
# ax.figure.savefig(f"{ROOT}/tmp/{today}-ce-vs-cse-outliers.pdf")
126+
127+
128+
# %% below code resulted in
129+
# https://github.com/materialsproject/pymatgen/issues/2730
130+
wbm_step_2_34803 = (
131+
df_ce_ne_cse.e_form_per_atom_mp2020_from_cse
132+
- df_ce_ne_cse.e_form_per_atom_mp2020_from_ce
133+
).idxmax()
134+
idx = df_wbm.index.get_loc(wbm_step_2_34803)
135+
cse_mp2020, cse_legacy = cses[idx].copy(), cses[idx].copy()
136+
ce_mp2020, ce_legacy = ces[idx].copy(), ces[idx].copy()
137+
138+
139+
with gzip.open(f"{ROOT}/tmp/cse-wbm-step-2-34803.json.zip", "w") as f:
140+
f.write(cse_mp2020.to_json().encode("utf-8"))
141+
142+
with gzip.open(f"{ROOT}/tmp/cse-wbm-step-2-34803.json.zip") as f:
143+
cse = ComputedStructureEntry.from_dict(json.load(f))
144+
145+
cse_mp2020 = cse.copy()
146+
cse_legacy = cse.copy()
147+
ce_mp2020 = ComputedEntry.from_dict(cse.to_dict())
148+
ce_legacy = ce_mp2020.copy()
149+
150+
151+
MaterialsProject2020Compatibility().process_entry(cse_mp2020)
152+
MaterialsProject2020Compatibility().process_entry(ce_mp2020)
153+
MaterialsProjectCompatibility().process_entry(cse_legacy)
154+
MaterialsProjectCompatibility().process_entry(ce_legacy)
155+
156+
print(f"{cse_mp2020.correction=:.4}")
157+
print(f"{ce_mp2020.correction=:.4}")
158+
print(f"{cse_legacy.correction=:.4}")
159+
print(f"{ce_legacy.correction=:.4}")
160+
161+
print(f"{cse_mp2020.energy_adjustments=}\n")
162+
print(f"{ce_mp2020.energy_adjustments=}\n")
163+
print(f"{cse_legacy.energy_adjustments=}\n")
164+
print(f"{ce_legacy.energy_adjustments=}\n")

models/bowsr/slurm_array_bowsr_wbm.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,21 @@
2727
"""
2828

2929
task_type = "IS2RE" # "RS2RE"
30-
today = f"{datetime.now():%Y-%m-%d}"
3130
module_dir = os.path.dirname(__file__)
3231
# --mem 12000 avoids slurmstepd: error: Detected 1 oom-kill event(s)
3332
# Some of your processes may have been killed by the cgroup out-of-memory handler.
3433
slurm_mem_per_node = 12000
3534
# set large job array size for fast testing/debugging
3635
slurm_array_task_count = 500
37-
out_dir = f"{module_dir}/{today}-bowsr-megnet-wbm-{task_type}"
36+
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
37+
today = timestamp.split("@")[0]
38+
job_name = f"bowsr-megnet-wbm-{task_type}"
39+
out_dir = f"{module_dir}/{today}-{job_name}"
3840

3941
data_path = f"{ROOT}/data/2022-10-19-wbm-init-structs.json.gz"
4042

4143
slurm_submit_python(
42-
job_name=f"bowsr-megnet-wbm-{task_type}",
44+
job_name=job_name,
4345
log_dir=out_dir,
4446
partition="icelake-himem",
4547
account="LEE-SL3-CPU",
@@ -57,7 +59,6 @@
5759
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
5860
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
5961
out_path = f"{out_dir}/{slurm_array_task_id}.json.gz"
60-
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
6162

6263
print(f"Job started running {timestamp}")
6364
print(f"{slurm_job_id = }")
@@ -164,4 +165,4 @@
164165

165166
df_output.reset_index().to_json(out_path, default_handler=as_dict_handler)
166167

167-
wandb.log_artifact(out_path, type=f"bowsr-megnet-wbm-{task_type}")
168+
wandb.log_artifact(out_path, type=job_name)

models/cgcnn/slurm_train_cgcnn_ensemble.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,12 @@
2525
# %%
2626
epochs = 300
2727
target_col = "formation_energy_per_atom"
28-
run_name = f"cgcnn-robust-{epochs=}-{target_col}"
28+
run_name = f"cgcnn-robust-{target_col}-{epochs=}"
2929
print(f"{run_name=}")
3030
robust = "robust" in run_name.lower()
3131
n_folds = 10
32-
today = f"{datetime.now():%Y-%m-%d}"
32+
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
33+
today = timestamp.split("@")[0]
3334
log_dir = f"{os.path.dirname(__file__)}/{today}-{run_name}"
3435

3536
slurm_submit_python(
@@ -60,7 +61,7 @@
6061
df["structure"] = [Structure.from_dict(s) for s in tqdm(df.structure, disable=None)]
6162
assert target_col in df
6263

63-
train_df, test_df = df_train_test_split(df, test_size=0.5)
64+
train_df, test_df = df_train_test_split(df, test_size=0.05)
6465

6566
train_data = CrystalGraphData(train_df, task_dict={target_col: task_type})
6667
train_loader = DataLoader(
@@ -85,14 +86,14 @@
8586
model = CrystalGraphConvNet(**model_params)
8687

8788
run_params = dict(
89+
data_path=data_path,
8890
batch_size=batch_size,
8991
train_df=dict(shape=train_data.df.shape, columns=", ".join(train_df)),
9092
test_df=dict(shape=test_data.df.shape, columns=", ".join(test_df)),
9193
)
9294

9395

9496
# %%
95-
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
9697
print(f"Job started running {timestamp}")
9798

9899
train_model(

models/m3gnet/slurm_array_m3gnet_wbm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
__date__ = "2022-08-15"
2727

2828
task_type = "IS2RE" # "RS2RE"
29-
today = f"{datetime.now():%Y-%m-%d}"
29+
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
30+
today = timestamp.split("@")[0]
3031
module_dir = os.path.dirname(__file__)
3132
# set large job array size for fast testing/debugging
3233
slurm_array_task_count = 100
@@ -51,7 +52,6 @@
5152
# %%
5253
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
5354
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
54-
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
5555

5656
print(f"Job started running {timestamp}")
5757
print(f"{slurm_job_id = }")

models/wrenformer/slurm_train_wrenformer_ensemble.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,15 @@
1818

1919
# %%
2020
epochs = 300
21-
target_col = "e_form"
22-
run_name = f"wrenformer-robust-mp+wbm-{epochs=}-{target_col}"
21+
data_path = f"{ROOT}/data/mp/2022-08-13-mp-energies.json.gz"
22+
target_col = "formation_energy_per_atom"
23+
# data_path = f"{ROOT}/data/2022-08-25-m3gnet-trainset-mp-2021-struct-energy.json.gz"
24+
# target_col = "mp_energy_per_atom"
25+
data_name = "m3gnet-trainset" if "m3gnet" in data_path else "mp"
26+
run_name = f"wrenformer-robust-{data_name}-{target_col}-{epochs=}"
2327
n_folds = 10
24-
today = f"{datetime.now():%Y-%m-%d}"
28+
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
29+
today = timestamp.split("@")[0]
2530
dataset = "mp"
2631
log_dir = f"{os.path.dirname(__file__)}/{dataset}/{today}-{run_name}"
2732

@@ -38,13 +43,8 @@
3843

3944
# %%
4045
learning_rate = 3e-4
41-
data_path = f"{ROOT}/data/mp/2022-08-13-mp-energies.json.gz"
42-
target_col = "energy_per_atom"
43-
# data_path = f"{ROOT}/data/2022-08-25-m3gnet-trainset-mp-2021-struct-energy.json.gz"
44-
# target_col = "mp_energy_per_atom"
4546
batch_size = 128
4647
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
47-
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
4848
input_col = "wyckoff_spglib"
4949

5050
print(f"Job started running {timestamp}")
@@ -54,9 +54,10 @@
5454
df = pd.read_json(data_path).set_index("material_id", drop=False)
5555
assert target_col in df, f"{target_col=} not in {list(df)}"
5656
assert input_col in df, f"{input_col=} not in {list(df)}"
57-
train_df, test_df = df_train_test_split(df, test_size=0.3)
57+
train_df, test_df = df_train_test_split(df, test_size=0.05)
5858

5959
run_params = dict(
60+
data_path=data_path,
6061
batch_size=batch_size,
6162
train_df=dict(shape=train_df.shape, columns=", ".join(train_df)),
6263
test_df=dict(shape=test_df.shape, columns=", ".join(test_df)),

0 commit comments

Comments
 (0)