Skip to content

Commit d564ade

Browse files
committed
refactor data loading in model test scripts
1 parent 387184c commit d564ade

File tree

7 files changed

+61
-75
lines changed

7 files changed

+61
-75
lines changed

models/bowsr/test_bowsr.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from maml.apps.bowsr.model.megnet import MEGNet
1313
from maml.apps.bowsr.optimizer import BayesianOptimizer
1414
from pymatgen.core import Structure
15-
from pymatgen.entries.computed_entries import ComputedStructureEntry
1615
from tqdm import tqdm
1716

1817
from matbench_discovery import DEBUG, timestamp, today
@@ -39,7 +38,11 @@
3938
job_name = f"bowsr-{energy_model}-wbm-{task_type}{'-debug' if DEBUG else ''}"
4039
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
4140

42-
data_path = DATA_FILES.wbm_initial_structures
41+
data_path = {
42+
"IS2RE": DATA_FILES.wbm_initial_structures,
43+
"RS2RE": DATA_FILES.wbm_computed_structure_entries,
44+
}[task_type]
45+
4346

4447
slurm_vars = slurm_submit(
4548
job_name=job_name,
@@ -73,7 +76,7 @@
7376
# %%
7477
df_wbm = pd.read_json(data_path).set_index("material_id")
7578

76-
df_this_job: pd.DataFrame = np.array_split(df_wbm, slurm_array_task_count)[
79+
df_in: pd.DataFrame = np.array_split(df_wbm, slurm_array_task_count)[
7780
slurm_array_task_id - 1
7881
]
7982

@@ -90,7 +93,7 @@
9093
run_params = dict(
9194
bayes_optim_kwargs=bayes_optim_kwargs,
9295
data_path=data_path,
93-
df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)),
96+
df=dict(shape=str(df_in.shape), columns=", ".join(df_in)),
9497
energy_model=energy_model,
9598
maml_version=version("maml"),
9699
energy_model_version=version(energy_model),
@@ -106,16 +109,12 @@
106109
# %%
107110
model = MEGNet()
108111
relax_results: dict[str, dict[str, Any]] = {}
112+
input_col = {"IS2RE": "initial_structure", "RS2RE": "relaxed_structure"}[task_type]
109113

110-
if task_type == "IS2RE":
111-
structures = df_this_job.initial_structure.map(Structure.from_dict).to_dict()
112-
elif task_type == "RS2RE":
113-
structures = df_this_job.cse.map(
114-
lambda x: ComputedStructureEntry.from_dict(x).structure
115-
).to_dict()
116-
else:
117-
raise ValueError(f"Unknown {task_type = }")
114+
if task_type == "RS2RE":
115+
df_in[input_col] = [x["structure"] for x in df_in.computed_structure_entry]
118116

117+
structures = df_in[input_col].map(Structure.from_dict).to_dict()
119118

120119
for material_id in tqdm(structures, desc="Main loop", disable=None):
121120
structure = structures[material_id]

models/cgcnn/test_cgcnn.py

+7-14
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from torch.utils.data import DataLoader
1515
from tqdm import tqdm
1616

17-
from matbench_discovery import CHECKPOINT_DIR, DEBUG, WANDB_PATH, today
17+
from matbench_discovery import CHECKPOINT_DIR, DEBUG, ROOT, WANDB_PATH, today
1818
from matbench_discovery.data import DATA_FILES, df_wbm
1919
from matbench_discovery.plots import wandb_scatter
2020
from matbench_discovery.slurm import slurm_submit
@@ -45,19 +45,12 @@
4545

4646

4747
# %%
48-
if task_type == "IS2RE":
49-
data_path = DATA_FILES.wbm_initial_structures
50-
# or for debug
51-
# data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json-1k-samples.bz2"
52-
# created with:
53-
# df = df.sample(1000)
54-
# df.reset_index().to_json(data_path.replace(".json", "-1k-samples.json"))
55-
input_col = "initial_structure"
56-
elif task_type == "RS2RE":
57-
data_path = DATA_FILES.wbm_computed_structure_entries
58-
input_col = "relaxed_structure"
59-
else:
60-
raise ValueError(f"Unexpected {task_type=}")
48+
data_path = {
49+
"IS2RE": DATA_FILES.wbm_initial_structures,
50+
"RS2RE": DATA_FILES.wbm_computed_structure_entries,
51+
"IS2RE-debug": f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json-1k-samples.bz2",
52+
}[task_type + "-debug" if DEBUG else ""]
53+
input_col = {"IS2RE": "initial_structure", "RS2RE": "relaxed_structure"}[task_type]
6154

6255
df = pd.read_json(data_path).set_index("material_id")
6356

models/chgnet/join_chgnet_results.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
# %%
3434
module_dir = os.path.dirname(__file__)
3535
task_type = "IS2RE"
36-
date = "2023-03-02"
36+
date = "2023-03-04"
3737
glob_pattern = f"{date}-chgnet-wbm-{task_type}*/*.json.gz"
3838
file_paths = sorted(glob(f"{module_dir}/{glob_pattern}"))
3939
print(f"Found {len(file_paths):,} files for {glob_pattern = }")
@@ -128,7 +128,7 @@
128128

129129
# %%
130130
ax = density_scatter(
131-
df=df_chgnet, x="e_form_per_atom_chgnet", y="e_form_per_atom_chgnet_megnet"
131+
df=df_chgnet, x="e_form_per_atom_chgnet_megnet", y="e_form_per_atom_chgnet"
132132
)
133133

134134

models/chgnet/test_chgnet.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import wandb
2020
from chgnet.model import StructOptimizer
2121
from pymatgen.core import Structure
22-
from pymatgen.entries.computed_entries import ComputedStructureEntry
2322
from tqdm import tqdm
2423

2524
from matbench_discovery import DEBUG, timestamp, today
@@ -69,7 +68,7 @@
6968
df_in = pd.read_json(data_path).set_index("material_id")
7069
e_pred_col = "chgnet_energy"
7170

72-
df_this_job: pd.DataFrame = np.array_split(df_in, slurm_array_task_count)[
71+
df_in: pd.DataFrame = np.array_split(df_in, slurm_array_task_count)[
7372
slurm_array_task_id - 1
7473
]
7574

@@ -79,7 +78,7 @@
7978
numpy_version=version("numpy"),
8079
torch_version=version("torch"),
8180
task_type=task_type,
82-
df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)),
81+
df=dict(shape=str(df_in.shape), columns=", ".join(df_in)),
8382
slurm_vars=slurm_vars,
8483
)
8584

@@ -90,15 +89,12 @@
9089
# %%
9190
chgnet = StructOptimizer() # load default pre-trained CHGNnet model
9291
relax_results: dict[str, dict[str, Any]] = {}
92+
input_col = {"IS2RE": "initial_structure", "RS2RE": "relaxed_structure"}[task_type]
9393

94-
if task_type == "IS2RE":
95-
structures = df_this_job.initial_structure.map(Structure.from_dict).to_dict()
96-
elif task_type == "RS2RE":
97-
df_this_job.cse = df_this_job.cse.map(ComputedStructureEntry.from_dict)
98-
structures = df_this_job.cse.map(lambda x: x.structure).to_dict()
99-
else:
100-
raise ValueError(f"Unknown {task_type = }")
94+
if task_type == "RS2RE":
95+
df_in[input_col] = [x["structure"] for x in df_in.computed_structure_entry]
10196

97+
structures = df_in[input_col].map(Structure.from_dict).to_dict()
10298

10399
for material_id in tqdm(structures, disable=None):
104100
if material_id in relax_results:

models/m3gnet/test_m3gnet.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import wandb
1919
from m3gnet.models import Relaxer
2020
from pymatgen.core import Structure
21-
from pymatgen.entries.computed_entries import ComputedStructureEntry
2221
from tqdm import tqdm
2322

2423
from matbench_discovery import DEBUG, timestamp, today
@@ -61,12 +60,15 @@
6160

6261

6362
# %%
64-
data_path = DATA_FILES.wbm_computed_structure_entries_plus_init_structs
63+
data_path = {
64+
"IS2RE": DATA_FILES.wbm_initial_structures,
65+
"RS2RE": DATA_FILES.wbm_computed_structure_entries,
66+
}[task_type]
6567
print(f"\nJob started running {timestamp}")
6668
print(f"{data_path=}")
6769
df_wbm = pd.read_json(data_path).set_index("material_id")
6870

69-
df_this_job: pd.DataFrame = np.array_split(df_wbm, slurm_array_task_count)[
71+
df_in: pd.DataFrame = np.array_split(df_wbm, slurm_array_task_count)[
7072
slurm_array_task_id - 1
7173
]
7274

@@ -75,7 +77,7 @@
7577
m3gnet_version=version("m3gnet"),
7678
numpy_version=version("numpy"),
7779
task_type=task_type,
78-
df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)),
80+
df=dict(shape=str(df_in.shape), columns=", ".join(df_in)),
7981
slurm_vars=slurm_vars,
8082
)
8183

@@ -86,15 +88,12 @@
8688
# %%
8789
megnet = Relaxer() # load default pre-trained M3GNet model
8890
relax_results: dict[str, dict[str, Any]] = {}
91+
input_col = {"IS2RE": "initial_structure", "RS2RE": "relaxed_structure"}[task_type]
8992

90-
if task_type == "IS2RE":
91-
structures = df_this_job.initial_structure.map(Structure.from_dict).to_dict()
92-
elif task_type == "RS2RE":
93-
df_this_job.cse = df_this_job.cse.map(ComputedStructureEntry.from_dict)
94-
structures = df_this_job.cse.map(lambda x: x.structure).to_dict()
95-
else:
96-
raise ValueError(f"Unknown {task_type = }")
93+
if task_type == "RS2RE":
94+
df_in[input_col] = [x["structure"] for x in df_in.computed_structure_entry]
9795

96+
structures = df_in[input_col].map(Structure.from_dict).to_dict()
9897

9998
for material_id in tqdm(structures, disable=None):
10099
if material_id in relax_results:

models/megnet/test_megnet.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import wandb
1616
from megnet.utils.models import load_model
1717
from pymatgen.core import Structure
18-
from pymatgen.entries.computed_entries import ComputedStructureEntry
1918
from sklearn.metrics import r2_score
2019
from tqdm import tqdm
2120

@@ -50,13 +49,16 @@
5049
if os.path.isfile(out_path):
5150
raise SystemExit(f"{out_path = } already exists, exciting early")
5251

53-
data_path = DATA_FILES.wbm_initial_structures
52+
data_path = {
53+
"IS2RE": DATA_FILES.wbm_initial_structures,
54+
"RS2RE": DATA_FILES.wbm_computed_structure_entries,
55+
}[task_type]
5456
print(f"\nJob started running {timestamp}")
5557
print(f"{data_path=}")
5658
e_form_col = "e_form_per_atom_mp2020_corrected"
5759
assert e_form_col in df_wbm, f"{e_form_col=} not in {list(df_wbm)=}"
5860

59-
df_wbm_structs = pd.read_json(data_path).set_index("material_id")
61+
df_in = pd.read_json(data_path).set_index("material_id")
6062
megnet_mp_e_form = load_model(model_name := "Eform_MP_2019")
6163

6264

@@ -68,21 +70,20 @@
6870
model_name=model_name,
6971
task_type=task_type,
7072
target_col=e_form_col,
71-
df=dict(shape=str(df_wbm_structs.shape), columns=", ".join(df_wbm_structs)),
73+
df=dict(shape=str(df_in.shape), columns=", ".join(df_in)),
7274
slurm_vars=slurm_vars,
7375
)
7476

7577
wandb.init(project="matbench-discovery", name=job_name, config=run_params)
7678

7779

7880
# %%
79-
if task_type == "IS2RE":
80-
structures = df_wbm_structs.initial_structure.map(Structure.from_dict)
81-
elif task_type == "RS2RE":
82-
df_wbm_structs.cse = df_wbm_structs.cse.map(ComputedStructureEntry.from_dict)
83-
structures = df_wbm_structs.cse.map(lambda x: x.structure)
84-
else:
85-
raise ValueError(f"Unknown {task_type = }")
81+
input_col = {"IS2RE": "initial_structure", "RS2RE": "relaxed_structure"}[task_type]
82+
83+
if task_type == "RS2RE":
84+
df_in[input_col] = [x["structure"] for x in df_in.computed_structure_entry]
85+
86+
structures = df_in[input_col].map(Structure.from_dict).to_dict()
8687

8788
megnet_e_form_preds = {}
8889
for material_id in tqdm(structures, disable=None):

models/voronoi/voronoi_featurize_dataset.py

+13-15
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525

2626

2727
data_name = "mp" # "mp"
28-
if data_name == "wbm":
29-
data_path = DATA_FILES.wbm_initial_structures
30-
elif data_name == "mp":
31-
data_path = DATA_FILES.mp_computed_structure_entries
28+
data_path = {
29+
"wbm": DATA_FILES.wbm_initial_structures,
30+
"mp": DATA_FILES.mp_computed_structure_entries,
31+
}[data_name]
3232

3333
input_col = "initial_structure"
3434
# input_col = "relaxed_structure"
@@ -60,26 +60,24 @@
6060

6161
print(f"{data_path=}")
6262
df = pd.read_json(data_path).set_index("material_id")
63-
df_this_job: pd.DataFrame = np.array_split(df, slurm_array_task_count)[
63+
df_in: pd.DataFrame = np.array_split(df, slurm_array_task_count)[
6464
slurm_array_task_id - 1
6565
]
6666

6767
if data_name == "mp": # extract structure dicts from ComputedStructureEntry
68-
struct_dicts = [x["structure"] for x in df_this_job.entry]
68+
struct_dicts = [x["structure"] for x in df_in.entry]
6969
elif data_name == "wbm" and input_col == "relaxed_structure":
70-
struct_dicts = [x["structure"] for x in df_this_job.computed_structure_entry]
70+
struct_dicts = [x["structure"] for x in df_in.computed_structure_entry]
7171
elif data_name == "wbm" and input_col == "initial_structure":
72-
struct_dicts = df_this_job.initial_structure
72+
struct_dicts = df_in.initial_structure
7373

74-
df_this_job[input_col] = [
75-
Structure.from_dict(x) for x in tqdm(struct_dicts, disable=None)
76-
]
74+
df_in[input_col] = [Structure.from_dict(x) for x in tqdm(struct_dicts, disable=None)]
7775

7876

7977
# %%
8078
run_params = dict(
8179
data_path=data_path,
82-
df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)),
80+
df=dict(shape=str(df_in.shape), columns=", ".join(df_in)),
8381
input_col=input_col,
8482
slurm_vars=slurm_vars,
8583
out_path=out_path,
@@ -94,9 +92,9 @@
9492
# > No electronegativity for Ne. Setting to NaN. This has no physical meaning, ...
9593
warnings.filterwarnings(action="ignore", category=UserWarning, module="pymatgen")
9694

97-
df_features = featurizer.featurize_dataframe(
98-
df_this_job, input_col, ignore_errors=True
99-
)[featurizer.feature_labels()].round(4)
95+
df_features = featurizer.featurize_dataframe(df_in, input_col, ignore_errors=True)[
96+
featurizer.feature_labels()
97+
].round(4)
10098

10199

102100
# %%

0 commit comments

Comments
 (0)