Skip to content

Commit 6dd4398

Browse files
committed
add structure perturbation to train_cgcnn.py
including new matbench_discovery/structure.py module with tests in tests/test_structure.py
1 parent 4f1e5b6 commit 6dd4398

File tree

6 files changed

+125
-20
lines changed

6 files changed

+125
-20
lines changed

matbench_discovery/structure.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import numpy as np
2+
from pymatgen.core import Structure
3+
4+
__author__ = "Janosh Riebesell"
5+
__date__ = "2022-12-02"
6+
7+
np.random.seed(0) # ensure reproducible structure perturbations
8+
9+
10+
def perturb_structure(struct: Structure, gamma: float = 1.5) -> Structure:
11+
"""Perturb the atomic coordinates of a pymatgen structure
12+
13+
Args:
14+
struct (Structure): pymatgen structure to be perturbed
15+
16+
Returns:
17+
Structure: Perturbed structure
18+
"""
19+
perturbed = struct.copy()
20+
for site in perturbed:
21+
magnitude = np.random.weibull(gamma)
22+
vec = np.random.randn(3) # TODO maybe make func recursive to deal with 0-vector
23+
vec /= np.linalg.norm(vec) # unit vector
24+
site.coords += vec * magnitude
25+
site.to_unit_cell(in_place=True)
26+
27+
return perturbed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# %%
2+
import numpy as np
3+
import pandas as pd
4+
from pymatgen.core import Lattice, Structure
5+
from pymatviz import plot_structure_2d
6+
7+
from matbench_discovery.plots import plt
8+
from matbench_discovery.structure import perturb_structure
9+
10+
__author__ = "Janosh Riebesell"
11+
__date__ = "2022-12-02"
12+
13+
14+
# %%
15+
ax = pd.Series(np.random.weibull(1.5, 100000)).hist(bins=100)
16+
title = "Distribution of perturbation magnitudes"
17+
ax.set(xlabel="magnitude of perturbation", ylabel="count", title=title)
18+
19+
20+
# %%
21+
struct = Structure(
22+
lattice=Lattice.cubic(5),
23+
species=("Fe", "O"),
24+
coords=((0, 0, 0), (0.5, 0.5, 0.5)),
25+
)
26+
27+
ax = plot_structure_2d(struct)
28+
ax.set(title=f"Original structure: {struct.formula}")
29+
ax.set_aspect("equal")
30+
31+
32+
# %%
33+
fig, axs = plt.subplots(3, 4, figsize=(12, 10))
34+
for idx, ax in enumerate(axs.flat, 1):
35+
plot_structure_2d(perturb_structure(struct), ax=ax)
36+
ax.set(title=f"perturbation {idx}")

models/cgcnn/test_cgcnn.py

+9-13
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323
__date__ = "2022-08-15"
2424

2525
"""
26-
Script that downloads checkpoints for an ensemble of CGCNN models trained on all MP
26+
Download WandB checkpoints for an ensemble of CGCNN models trained on all MP
2727
formation energies, then makes predictions on some dataset, prints ensemble metrics and
2828
saves predictions to CSV.
2929
"""
3030

31-
task_type = "RS2RE"
31+
task_type = "IS2RE"
3232
debug = "slurm-submit" in sys.argv
3333
job_name = f"test-cgcnn-wbm-{task_type}{'-debug' if DEBUG else ''}"
3434
module_dir = os.path.dirname(__file__)
@@ -58,16 +58,15 @@
5858

5959
target_col = "e_form_per_atom_mp2020_corrected"
6060
df[target_col] = df_wbm[target_col]
61-
assert target_col in df, f"{target_col=} not in {list(df)}"
6261
if task_type == "RS2RE":
6362
df[input_col] = [x["structure"] for x in df.computed_structure_entry]
6463
assert input_col in df, f"{input_col=} not in {list(df)}"
6564

6665
df[input_col] = [Structure.from_dict(x) for x in tqdm(df[input_col], disable=None)]
6766

6867
filters = {
69-
"created_at": {"$gt": "2022-11-22", "$lt": "2022-11-23"},
70-
"display_name": {"$regex": "^cgcnn-robust"},
68+
"created_at": {"$gt": "2022-12-03", "$lt": "2022-12-04"},
69+
"display_name": {"$regex": "^train-cgcnn-robust-augment=3-"},
7170
}
7271
runs = wandb.Api().runs("janosh/matbench-discovery", filters=filters)
7372

@@ -92,19 +91,15 @@
9291
slurm_vars=slurm_vars,
9392
)
9493

95-
9694
wandb.init(project="matbench-discovery", name=job_name, config=run_params)
9795

9896
cg_data = CrystalGraphData(
99-
df,
100-
task_dict={target_col: "regression"},
101-
structure_col=input_col,
102-
identifiers=["formula_from_cse"],
97+
df, task_dict={target_col: "regression"}, structure_col=input_col
10398
)
10499
data_loader = DataLoader(
105100
cg_data, batch_size=1024, shuffle=False, collate_fn=collate_batch
106101
)
107-
df, ensemble_metrics = predict_from_wandb_checkpoints(
102+
df_preds, ensemble_metrics = predict_from_wandb_checkpoints(
108103
runs,
109104
# dropping isolated-atom structs means len(cg_data.df) < len(df)
110105
cache_dir=CHECKPOINT_DIR,
@@ -114,9 +109,10 @@
114109
data_loader=data_loader,
115110
)
116111

117-
df.to_csv(f"{out_dir}/{job_name}-preds.csv", index=False)
112+
df_preds.to_csv(f"{out_dir}/{job_name}-preds.csv", index=False)
118113
pred_col = f"{target_col}_pred_ens"
119-
table = wandb.Table(dataframe=df[[target_col, pred_col]].reset_index())
114+
assert pred_col in df, f"{pred_col=} not in {list(df)}"
115+
table = wandb.Table(dataframe=df_preds[[target_col, pred_col]].reset_index())
120116

121117

122118
# %%

models/cgcnn/train_cgcnn.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
from aviary.train import df_train_test_split, train_model
99
from pymatgen.core import Structure
1010
from torch.utils.data import DataLoader
11-
from tqdm import tqdm
11+
from tqdm import tqdm, trange
1212

1313
from matbench_discovery import DEBUG, ROOT, timestamp, today
1414
from matbench_discovery.slurm import slurm_submit
15+
from matbench_discovery.structure import perturb_structure
1516

1617
"""
1718
Train a CGCNN ensemble on target_col of data_path.
@@ -24,7 +25,10 @@
2425
# %%
2526
epochs = 300
2627
target_col = "formation_energy_per_atom"
27-
job_name = f"train-cgcnn-robust-{target_col}{'-debug' if DEBUG else ''}"
28+
input_col = "structure"
29+
id_col = "material_id"
30+
augment = 3
31+
job_name = f"train-cgcnn-robust-{augment=}{'-debug' if DEBUG else ''}"
2832
print(f"{job_name=}")
2933
robust = "robust" in job_name.lower()
3034
ensemble_size = 10
@@ -35,7 +39,7 @@
3539
job_name=job_name,
3640
partition="ampere",
3741
account="LEE-SL3-GPU",
38-
time="8:0:0",
42+
time="12:0:0",
3943
array=f"1-{ensemble_size}",
4044
out_dir=out_dir,
4145
slurm_flags="--nodes 1 --gpus-per-node 1",
@@ -55,10 +59,18 @@
5559
data_path = f"{ROOT}/data/mp/2022-08-13-mp-energies.json.gz"
5660
# data_path = f"{ROOT}/data/mp/2022-08-13-mp-energies-1k-samples.json.gz"
5761
print(f"{data_path=}")
58-
df = pd.read_json(data_path).set_index("material_id", drop=False)
59-
df["structure"] = [Structure.from_dict(s) for s in tqdm(df.structure, disable=None)]
62+
df = pd.read_json(data_path).set_index(id_col)
63+
df[input_col] = [Structure.from_dict(s) for s in tqdm(df[input_col], disable=None)]
6064
assert target_col in df
6165

66+
df_aug = df.copy()
67+
structs = df_aug.pop(input_col)
68+
for idx in trange(augment, desc="Augmenting"):
69+
df_aug[input_col] = [perturb_structure(x) for x in structs]
70+
df = pd.concat([df, df_aug.set_index(f"{x}-aug={idx+1}" for x in df_aug.index)])
71+
72+
del df_aug
73+
6274
train_df, test_df = df_train_test_split(df, test_size=0.05)
6375

6476
print(f"{train_df.shape=}")
@@ -91,6 +103,8 @@
91103
train_df=dict(shape=str(train_data.df.shape), columns=", ".join(train_df)),
92104
test_df=dict(shape=str(test_data.df.shape), columns=", ".join(test_df)),
93105
slurm_vars=slurm_vars,
106+
augment=augment,
107+
input_col=input_col,
94108
)
95109

96110

@@ -108,9 +122,9 @@
108122
swa_start=swa_start,
109123
target_col=target_col,
110124
task_type=task_type,
125+
train_loader=train_loader,
111126
test_loader=test_loader,
112127
timestamp=timestamp,
113-
train_loader=train_loader,
114128
wandb_path="janosh/matbench-discovery",
115129
run_params=run_params,
116130
)

models/wrenformer/test_wrenformer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
task_type = "IS2RE"
2828
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-summary.csv"
2929
debug = "slurm-submit" in sys.argv
30-
job_name = f"test-wrenformer-wbm-IS2RE{'-debug' if DEBUG else ''}"
30+
job_name = f"test-wrenformer-wbm-{task_type}{'-debug' if DEBUG else ''}"
3131
module_dir = os.path.dirname(__file__)
3232
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
3333

tests/test_structure.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
4+
import pytest
5+
from pymatgen.core import Lattice, Structure
6+
7+
from matbench_discovery.structure import perturb_structure
8+
9+
10+
@pytest.fixture
11+
def struct() -> Structure:
12+
return Structure(
13+
lattice=Lattice.cubic(5),
14+
species=("Fe", "O"),
15+
coords=((0, 0, 0), (0.5, 0.5, 0.5)),
16+
)
17+
18+
19+
def test_perturb_structure(struct: Structure) -> None:
20+
np.random.seed(0)
21+
perturbed = perturb_structure(struct)
22+
assert len(perturbed) == len(struct)
23+
24+
for site, new in zip(struct, perturbed):
25+
assert site.specie == new.specie
26+
assert tuple(site.coords) != tuple(new.coords)
27+
28+
# test that the perturbation is reproducible
29+
np.random.seed(0)
30+
assert perturbed == perturb_structure(struct)
31+
# but different on subsequent calls
32+
assert perturb_structure(struct) != perturb_structure(struct)

0 commit comments

Comments
 (0)