Skip to content

Commit fed968f

Browse files
committed
add wandb scatter plot logging to slurm_array_megnet_wbm.py
add maml, megnet, m3gnet-dgl to setup.py extras_require "running-models"
plot density_scatter() at end of use_cgcnn_ensemble.py
1 parent ffa075e commit fed968f

File tree

4 files changed

+63
-41
lines changed

4 files changed

+63
-41
lines changed

models/cgcnn/use_cgcnn_ensemble.py

+12 -1
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
from aviary.cgcnn.model import CrystalGraphConvNet
1111
from aviary.deploy import predict_from_wandb_checkpoints
1212
from pymatgen.core import Structure
13+
from pymatviz import density_scatter
1314
from torch.utils.data import DataLoader
1415
from tqdm import tqdm
1516

@@ -29,7 +30,7 @@
2930
module_dir = os.path.dirname(__file__)
3031
today = f"{datetime.now():%Y-%m-%d}"
3132
ensemble_id = "cgcnn-e_form-ensemble-1"
32-
run_name = f"{today}-{ensemble_id}-IS2RE"
33+
run_name = f"{ensemble_id}-IS2RE"
3334

3435
slurm_submit(
3536
job_name=run_name,
@@ -82,3 +83,13 @@
8283
)
8384

8485
df.round(6).to_csv(f"{module_dir}/{today}-{run_name}-preds.csv", index=False)
86+
87+
88+
# %%
89+
print(f"{runs[0].url=}")
90+
ax = density_scatter(
91+
df=df.query("e_form_per_atom_mp2020_corrected < 10"),
92+
x="e_form_per_atom_mp2020_corrected",
93+
y="e_form_per_atom_mp2020_corrected_pred_1",
94+
)
95+
# ax.figure.savefig(f"{ROOT}/tmp/{today}-{run_name}-scatter-preds.png", dpi=300)

models/m3gnet/slurm_array_m3gnet_wbm.py

+5 -5
Original file line number | Diff line number | Diff line change
@@ -56,10 +56,10 @@
5656
print(f"Job started running {timestamp}")
5757
print(f"{version('m3gnet') = }")
5858

59-
json_out_path = f"{out_dir}/{slurm_array_task_id}.json.gz"
59+
out_path = f"{out_dir}/{slurm_array_task_id}.json.gz"
6060

61-
if os.path.isfile(json_out_path):
62-
raise SystemExit(f"{json_out_path = } already exists, exciting early")
61+
if os.path.isfile(out_path):
62+
raise SystemExit(f"{out_path = } already exists, exciting early")
6363

6464
warnings.filterwarnings(action="ignore", category=UserWarning, module="pymatgen")
6565
warnings.filterwarnings(action="ignore", category=UserWarning, module="tensorflow")
@@ -125,6 +125,6 @@
125125
df_output = pd.DataFrame(relax_results).T
126126
df_output.index.name = "material_id"
127127

128-
df_output.reset_index().to_json(json_out_path, default_handler=as_dict_handler)
128+
df_output.reset_index().to_json(out_path, default_handler=as_dict_handler)
129129

130-
wandb.log_artifact(json_out_path, type=f"m3gnet-wbm-{task_type}")
130+
wandb.log_artifact(out_path, type=f"m3gnet-wbm-{task_type}")

models/megnet/slurm_array_megnet_wbm.py

+45 -35
Original file line number | Diff line number | Diff line change
@@ -5,13 +5,13 @@
55
from datetime import datetime
66
from importlib.metadata import version
77

8-
import numpy as np
98
import pandas as pd
109
import wandb
1110
from megnet.utils.models import load_model
1211
from tqdm import tqdm
1312

1413
from matbench_discovery import ROOT
14+
from matbench_discovery.plot_scripts import df_wbm
1515
from matbench_discovery.slurm import slurm_submit
1616

1717
"""
@@ -23,14 +23,11 @@
2323
__author__ = "Janosh Riebesell"
2424
__date__ = "2022-11-14"
2525

26-
task_type = "IS2RE" # "RS2RE"
26+
task_type = "IS2RE"
2727
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
2828
today = timestamp.split("@")[0]
2929
module_dir = os.path.dirname(__file__)
30-
# set large job array size for fast testing/debugging
31-
slurm_array_task_count = 1
32-
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
33-
job_name = f"megnet-wbm-{task_type}-{slurm_job_id}"
30+
job_name = f"megnet-wbm-{task_type}"
3431
out_dir = f"{module_dir}/{today}-{job_name}"
3532

3633
slurm_vars = slurm_submit(
@@ -39,80 +36,93 @@
3936
partition="icelake-himem",
4037
account="LEE-SL3-CPU",
4138
time=(slurm_max_job_time := "12:0:0"),
42-
array=f"1-{slurm_array_task_count}",
4339
# TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed
4440
# https://stackoverflow.com/a/40982782
4541
pre_cmd="TF_CPP_MIN_LOG_LEVEL=2",
4642
)
4743

4844

4945
# %%
50-
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
51-
5246
print(f"Job started running {timestamp}")
5347

54-
json_out_path = f"{out_dir}/{slurm_array_task_id}.json.gz"
55-
if os.path.isfile(json_out_path):
56-
raise SystemExit(f"{json_out_path = } already exists, exciting early")
48+
out_path = f"{out_dir}/megnet-e-form-preds.csv"
49+
if os.path.isfile(out_path):
50+
raise SystemExit(f"{out_path = } already exists, exciting early")
5751

5852

5953
# %%
6054
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
6155
print(f"Loading from {data_path=}")
62-
df_wbm = pd.read_json(data_path).set_index("material_id")
56+
df_wbm_structs = pd.read_json(data_path).set_index("material_id")
6357

64-
df_this_job: pd.DataFrame = np.array_split(df_wbm, slurm_array_task_count)[
65-
slurm_array_task_id - 1
66-
]
6758

6859
megnet_mp_e_form = load_model(model_name := "Eform_MP_2019")
6960

61+
62+
# %%
7063
run_params = dict(
7164
data_path=data_path,
7265
megnet_version=version("megnet"),
7366
model_name=model_name,
7467
task_type=task_type,
7568
slurm_max_job_time=slurm_max_job_time,
76-
df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)),
69+
df=dict(shape=str(df_wbm_structs.shape), columns=", ".join(df_wbm_structs)),
7770
slurm_vars=slurm_vars,
7871
)
7972
if wandb.run is None:
8073
wandb.login()
8174

82-
wandb.init(
83-
project="matbench-discovery",
84-
name=f"{job_name}-{slurm_array_task_id}",
85-
config=run_params,
86-
)
75+
wandb.init(project="matbench-discovery", name=job_name, config=run_params)
8776

8877

8978
# %%
9079
if task_type == "IS2RE":
9180
from pymatgen.core import Structure
9281

93-
structures = df_this_job.initial_structure.map(Structure.from_dict)
82+
structures = df_wbm_structs.initial_structure.map(Structure.from_dict)
9483
elif task_type == "RS2RE":
9584
from pymatgen.entries.computed_entries import ComputedStructureEntry
9685

97-
df_this_job.cse = df_this_job.cse.map(ComputedStructureEntry.from_dict)
98-
structures = df_this_job.cse.map(lambda x: x.structure)
86+
df_wbm_structs.cse = df_wbm_structs.cse.map(ComputedStructureEntry.from_dict)
87+
structures = df_wbm_structs.cse.map(lambda x: x.structure)
9988
else:
10089
raise ValueError(f"Unknown {task_type = }")
10190

102-
megnet_preds = {}
103-
for material_id, structure in tqdm(structures.items(), disable=None):
104-
if material_id in megnet_preds:
91+
megnet_e_form_preds = {}
92+
for material_id, structure in tqdm(structures.items(), total=len(structures)):
93+
if material_id in megnet_e_form_preds:
10594
continue
106-
e_form_per_atom = megnet_mp_e_form.predict_structure(structure)[0]
107-
megnet_preds[material_id] = e_form_per_atom
95+
try:
96+
e_form_per_atom = megnet_mp_e_form.predict_structure(structure)[0]
97+
megnet_e_form_preds[material_id] = e_form_per_atom
98+
except Exception as exc:
99+
print(f"Failed to predict {material_id=}: {exc}")
100+
108101

102+
# %%
103+
print(f"{len(megnet_e_form_preds)=:,}")
104+
print(f"{len(structures)=:,}")
105+
print(f"missing: {len(structures) - len(megnet_e_form_preds):,}")
106+
out_col = "e_form_per_atom_megnet"
107+
df_wbm[out_col] = pd.Series(megnet_e_form_preds)
109108

110-
assert len(megnet_preds) == len(structures) == len(df_this_job)
111-
out_col = "megnet_e_form"
112-
df_this_job[out_col] = pd.Series(megnet_preds)
109+
df_wbm[out_col].reset_index().to_csv(out_path)
113110

114111

115112
# %%
116-
df_this_job[out_col].reset_index().to_json(json_out_path)
113+
fields = {"x": "e_form_per_atom_mp2020_corrected", "y": out_col}
114+
cols = list(fields.values())
115+
assert all(col in df_wbm for col in cols)
116+
117+
table = wandb.Table(dataframe=df_wbm[cols].reset_index())
118+
119+
MAE = (df_wbm[fields["x"]] - df_wbm[fields["y"]]).abs().mean()
120+
121+
scatter_plot = wandb.plot_table(
122+
vega_spec_name="janosh/scatter-parity",
123+
data_table=table,
124+
fields=fields,
125+
string_fields={"title": f"{model_name} {task_type} {MAE=:.4}"},
126+
)
117127

118-
wandb.log_artifact(json_out_path, type=f"m3gnet-wbm-{task_type}")
128+
wandb.log({"true_pred_scatter": scatter_plot})

models/voronoi/featurize_mp_wbm.py

+1
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@
5656
]
5757

5858

59+
# %%
5960
run_params = dict(
6061
data_path=data_path,
6162
slurm_max_job_time=slurm_max_job_time,

0 commit comments

Comments
 (0)