Commit 0a2284a

fix models/bowsr/test_bowsr.py passing unrelaxed, not bowsr-relaxed structure to megnet
1 parent 1b7f056 commit 0a2284a

8 files changed: +32 -13 lines

matbench_discovery/__init__.py (+1 -1)

@@ -15,7 +15,7 @@
 # directory to store model checkpoints downloaded from wandb cloud storage
 CHECKPOINT_DIR = f"{ROOT}/wandb/checkpoints"
 # wandb <entity>/<project name> to record new runs to
-WANDB_PATH = "materialsproject/matbench-discovery"
+WANDB_PATH = "janosh/matbench-discovery"

 timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
 today = timestamp.split("@")[0]
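
Note: WANDB_PATH is the "<entity>/<project name>" string used both when recording new runs and when querying past runs through the public API, so this entity change also redirects the run queries in the test scripts below. A minimal sketch of that usage (the filter here is illustrative only, not a value from this commit):

import wandb

WANDB_PATH = "janosh/matbench-discovery"  # "<entity>/<project name>"

# query previously logged runs under this entity/project via the public API
runs = wandb.Api().runs(WANDB_PATH, filters={"display_name": {"$regex": "^train-"}})
print(f"found {len(runs)} runs under {WANDB_PATH}")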

models/bowsr/test_bowsr.py (+5 -3)

@@ -32,10 +32,10 @@
 # Some of your processes may have been killed by the cgroup out-of-memory handler.
 slurm_mem_per_node = 12000
 # set large job array size for fast testing/debugging
-slurm_array_task_count = 1000
+slurm_array_task_count = 500
 # see https://stackoverflow.com/a/55431306 for how to change array throttling
 # post submission
-slurm_max_parallel = 50
+slurm_max_parallel = 100
 energy_model = "megnet"
 job_name = f"bowsr-{energy_model}-wbm-{task_type}{'-debug' if DEBUG else ''}"
 out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")

@@ -137,7 +137,9 @@
 structure_bowsr, energy_bowsr = optimizer.get_optimized_structure_and_energy()

 results = {
-    f"e_form_per_atom_bowsr_{energy_model}": model.predict_energy(structure),
+    f"e_form_per_atom_bowsr_{energy_model}": model.predict_energy(
+        structure_bowsr
+    ),
     "structure_bowsr": structure_bowsr,
     "energy_bowsr": energy_bowsr,
 }
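
The second hunk is the commit's headline fix: the unrelaxed input structure was being scored by MEGNet instead of the relaxed structure_bowsr returned by the BOWSR optimizer. A minimal sketch of the corrected flow; the import paths and optimizer keyword arguments are assumptions based on the upstream maml BOWSR example and are not taken from this diff:

from maml.apps.bowsr.model.megnet import MEGNet  # assumed import path
from maml.apps.bowsr.optimizer import BayesianOptimizer  # assumed import path
from pymatgen.core import Structure

structure = Structure.from_file("POSCAR")  # hypothetical unrelaxed input structure

model = MEGNet()  # MEGNet energy model wrapped for BOWSR
optimizer = BayesianOptimizer(model=model, structure=structure)
optimizer.set_bounds()  # assumed setup call
optimizer.optimize(n_init=100, n_iter=100)  # assumed kwargs

structure_bowsr, energy_bowsr = optimizer.get_optimized_structure_and_energy()

# the fix: predict on the BOWSR-relaxed structure, not the unrelaxed input
e_form_pred = model.predict_energy(structure_bowsr)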

models/cgcnn/test_cgcnn.py (+8 -3)

@@ -70,12 +70,16 @@
 df[input_col] = [Structure.from_dict(x) for x in tqdm(df[input_col], disable=None)]

 filters = {
-    "created_at": {"$gt": "2022-12-03", "$lt": "2022-12-04"},
-    "display_name": {"$regex": "^train-cgcnn-robust-augment=3-"},
+    # "display_name": {"$regex": "^train-cgcnn-robust-augment=3-"},
+    # "created_at": {"$gt": "2022-12-03", "$lt": "2022-12-04"},
+    "display_name": {"$regex": "^train-cgcnn-robust-augment=0-"},
+    "created_at": {"$gt": "2023-01-09", "$lt": "2023-01-10"},
 }
 runs = wandb.Api().runs(WANDB_PATH, filters=filters)
+assert (
+    len(runs) == 10
+), f"Expected 10 runs, got {len(runs)} filtering {WANDB_PATH=} with {filters=}"

-assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {filters=}"
 for idx, run in enumerate(runs):
     for key, val in run.config.items():
         if val == runs[0].config[key] or key.startswith(("slurm_", "timestamp")):

@@ -96,6 +100,7 @@
     input_col=input_col,
     wandb_run_filters=filters,
     slurm_vars=slurm_vars,
+    training_run_ids=[run.id for run in runs],
 )

 wandb.init(project="matbench-discovery", name=job_name, config=run_params)
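
Besides swapping the filters to the new augment=0 training runs, this hunk reformats the run-count assert and records the IDs of the matching training runs in run_params for provenance; the same pattern is applied to test_wrenformer.py below. A short sketch of the query pattern (values mirror the diff; the wandb public API accepts MongoDB-style operators such as $regex and $gt):

import wandb

WANDB_PATH = "janosh/matbench-discovery"
filters = {  # MongoDB-style query operators
    "display_name": {"$regex": "^train-cgcnn-robust-augment=0-"},
    "created_at": {"$gt": "2023-01-09", "$lt": "2023-01-10"},
}
runs = wandb.Api().runs(WANDB_PATH, filters=filters)

# fail fast if fewer/more than the expected 10 training runs match the filters
assert len(runs) == 10, f"Expected 10 runs, got {len(runs)}"

# record the matching training run IDs so the test run can be traced back to them
training_run_ids = [run.id for run in runs]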

models/cgcnn/train_cgcnn.py (+1 -1)

@@ -28,7 +28,7 @@
 target_col = "formation_energy_per_atom"
 input_col = "structure"
 id_col = "material_id"
-augment = 1 # 0 for no augmentation, n>1 means train on n perturbations of each crystal
+augment = 0 # 0 for no augmentation, n>1 means train on n perturbations of each crystal
 # in the training set all assigned the same original target energy
 job_name = f"train-cgcnn-robust-{augment=}{'-debug' if DEBUG else ''}"
 print(f"{job_name=}")
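
Setting augment = 0 disables the structure-perturbation augmentation: per the comment above, augment = n > 1 trains on n perturbed copies of each crystal, all assigned the original target energy. A hedged sketch of what such augmentation could look like using pymatgen's Structure.perturb; the repo's actual implementation may differ:

from pymatgen.core import Structure


def augment_structures(
    structures: list[Structure], targets: list[float], augment: int, distance: float = 0.05
) -> tuple[list[Structure], list[float]]:
    """Return `augment` randomly perturbed copies of each structure, each paired
    with the unchanged original target (illustrative only)."""
    aug_structures, aug_targets = [], []
    for structure, target in zip(structures, targets):
        for _ in range(augment):
            perturbed = structure.copy()
            # displace every site by a random vector of length `distance` (Angstrom)
            perturbed.perturb(distance)
            aug_structures.append(perturbed)
            aug_targets.append(target)  # same original target energy for every copy
    return aug_structures, aug_targets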

models/wrenformer/test_wrenformer.py (+4 -1)

@@ -56,8 +56,10 @@
     "display_name": {"$regex": "wrenformer-robust"},
 }
 runs = wandb.Api().runs(WANDB_PATH, filters=filters)
+assert (
+    len(runs) == 10
+), f"Expected 10 runs, got {len(runs)} filtering {WANDB_PATH=} with {filters=}"

-assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {filters=}"
 for idx, run in enumerate(runs):
     for key, val in run.config.items():
         if val == runs[0].config[key] or key.startswith(("slurm_", "timestamp")):

@@ -78,6 +80,7 @@
     input_col=input_col,
     wandb_run_filters=filters,
     slurm_vars=slurm_vars,
+    training_run_ids=[run.id for run in runs],
 )

 wandb.init(project="matbench-discovery", name=job_name, config=run_params)

scripts/metrics_table.py (+11 -3)

@@ -27,7 +27,7 @@
     ),
 ),
 "Voronoi RF": dict(
-    n_runs=70,
+    n_runs=68,
     filters=dict(
         created_at={"$gt": "2022-11-17", "$lt": "2022-11-28"},
        display_name={"$regex": "voronoi-features"},

@@ -83,12 +83,20 @@

 assert len(runs) == n_runs, f"found {len(runs)=} for {model}, expected {n_runs}"

-run_time = sum(run.summary.get("_wandb", {}).get("runtime", 0) for run in runs)
+each_run_time = [run.summary.get("_wandb", {}).get("runtime", 0) for run in runs]
+
+run_time_total = sum(each_run_time)
 # NOTE we assume all jobs have the same metadata here
 metadata = requests.get(runs[0].file("wandb-metadata.json").url).json()

 n_gpu, n_cpu = metadata.get("gpu_count", 0), metadata.get("cpu_count", 0)
-run_times[model] = {"Run time": run_time, "Hardware": f"GPU: {n_gpu}, CPU: {n_cpu}"}
+run_times[model] = {
+    "Run time": run_time_total,
+    "Hardware": f"GPU: {n_gpu}, CPU: {n_cpu}",
+}
+
+
+ax = (pd.Series(each_run_time) / 3600).hist(bins=100)


 # on 2022-11-28:
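
The runtime bookkeeping now keeps the per-run wall times (each_run_time) rather than only their sum, so the total can still be reported per model while a histogram of individual run durations (in hours) is also plotted. A quick sketch of that aggregation, with dummy values standing in for the wandb run summaries:

import pandas as pd

each_run_time = [3600, 5400, 7200]  # per-run wall times in seconds (dummy values)

run_time_total = sum(each_run_time)  # total seconds reported per model
print(f"total: {run_time_total / 3600:.1f} h over {len(each_run_time)} runs")

# histogram of individual run durations, converted from seconds to hours
ax = (pd.Series(each_run_time) / 3600).hist(bins=100)
ax.set(xlabel="run time (h)", ylabel="run count")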

site/src/app.css (+1)

@@ -30,6 +30,7 @@ body > div {
 main {
   padding: calc(1ex + 2vw);
   flex: 1;
+  container-type: inline-size;
 }
 button {
   color: var(--text-color);

site/svelte.config.js (+1 -1)

@@ -35,7 +35,7 @@ export default {

   preprocess: [
     {
-      // preprocess markdown citations @auth_1stwordtitle_yyyy into superscript
+      // preprocess markdown citations @auth_1st-word-title_yyyy into superscript
       // links to bibliography items, href must match References.svelte
       markup: (file) => {
         if (file.filename.endsWith(`paper/+page.svx`)) {
