Skip to content

Commit 0473994

Browse files
committed
rename job_id->slurm_job_id in wandb run.config, rename 2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz to 2022-08-16-m3gnet-wbm-IS2RE.json.gz
1 parent c487910 commit 0473994

12 files changed

+64
-49
lines changed

.pre-commit-config.yaml

+6-6
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,18 @@ repos:
1212
- id: isort
1313

1414
- repo: https://github.com/psf/black
15-
rev: 22.6.0
15+
rev: 22.8.0
1616
hooks:
1717
- id: black
1818

1919
- repo: https://github.com/pycqa/flake8
20-
rev: 4.0.1
20+
rev: 5.0.4
2121
hooks:
2222
- id: flake8
2323
additional_dependencies: [flake8-bugbear]
2424

2525
- repo: https://github.com/asottile/pyupgrade
26-
rev: v2.34.0
26+
rev: v2.38.2
2727
hooks:
2828
- id: pyupgrade
2929
args: [--py39-plus]
@@ -45,19 +45,19 @@ repos:
4545
- id: trailing-whitespace
4646

4747
- repo: https://github.com/pre-commit/mirrors-mypy
48-
rev: v0.961
48+
rev: v0.981
4949
hooks:
5050
- id: mypy
5151
additional_dependencies: [types-pyyaml]
5252

5353
- repo: https://github.com/codespell-project/codespell
54-
rev: v2.1.0
54+
rev: v2.2.1
5555
hooks:
5656
- id: codespell
5757
stages: [commit, commit-msg]
5858
exclude_types: [csv, html, json]
5959

6060
- repo: https://github.com/PyCQA/autoflake
61-
rev: v1.4
61+
rev: v1.6.1
6262
hooks:
6363
- id: autoflake

mb_discovery/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

33
import os
4-
from typing import Any, Generator, Sequence
4+
from collections.abc import Generator, Sequence
5+
from typing import Any
56

67
PKG_DIR = os.path.dirname(__file__)
78
ROOT = os.path.dirname(PKG_DIR)

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist_batches.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,15 @@
3535
f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
3636
).set_index("material_id")
3737
dfs["m3gnet"] = pd.read_json(
38-
f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
38+
f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-IS2RE.json.gz"
3939
).set_index("material_id")
4040
dfs["Wrenformer"] = pd.read_csv(
4141
f"{ROOT}/models/wrenformer/mp/"
4242
"2022-09-20-wrenformer-e_form-ensemble-1-preds-e_form_per_atom.csv"
4343
).set_index("material_id")
44+
dfs["bowsr_megnet"] = pd.read_json(
45+
f"{ROOT}/models/bowsr/2022-09-22-bowsr-wbm-megnet-IS2RE.json.gz"
46+
).set_index("material_id")
4447

4548

4649
df_hull = pd.read_csv(
@@ -53,9 +56,6 @@
5356
).set_index("material_id")
5457

5558

56-
dfs["m3gnet"] = dfs.pop("M3Gnet")
57-
58-
5959
# %%
6060
if "wren" in dfs:
6161
df = dfs["wren"]
@@ -66,14 +66,17 @@
6666
if "m3gnet" in dfs:
6767
df = dfs["m3gnet"]
6868
df["e_form_per_atom_pred"] = df.e_form_ppd_2022_01_25
69+
if "bowsr_megnet" in dfs:
70+
df = dfs["bowsr_megnet"]
71+
df["e_form_per_atom_pred"] = df.e_form_per_atom_bowsr
6972

7073

7174
# %%
7275
which_energy: WhichEnergy = "true"
7376
stability_crit: StabilityCriterion = "energy"
7477
fig, axs = plt.subplots(2, 3, figsize=(18, 9))
7578

76-
df = dfs[(model_name := "wren")]
79+
df = dfs[(model_name := "bowsr_megnet")]
7780

7881
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
7982
df["e_form_per_atom"] = df_wbm.e_form_per_atom
@@ -91,7 +94,7 @@
9194
ax=ax,
9295
)
9396

94-
title = f"Batch {batch_idx} ({len(batch_df):,})"
97+
title = f"Batch {batch_idx} ({len(batch_df.filter(like='e_').dropna()):,})"
9598
ax.set(title=title)
9699

97100

@@ -103,13 +106,17 @@
103106
ax=axs.flat[-1],
104107
)
105108

106-
axs.flat[-1].set(title=f"Combined {batch_idx} ({len(df):,})")
109+
axs.flat[-1].set(title=f"Combined ({len(df.filter(like='e_').dropna()):,})")
107110
axs.flat[0].legend(frameon=False, loc="upper left")
108111

109112
img_name = (
110113
f"{today}-{model_name}-wbm-hull-dist-hist-{which_energy=}-{stability_crit=}.pdf"
111114
)
112-
# plt.savefig(f"{ROOT}/figures/{img_name}")
115+
fig.suptitle(img_name.replace("-", "/", 2).replace("-", " "), y=1.07, fontsize=16)
116+
117+
118+
# %%
119+
ax.figure.savefig(f"{ROOT}/figures/{img_name}")
113120

114121

115122
# %%

mb_discovery/plot_scripts/precision_recall_vs_calc_count.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pandas as pd
55

66
from mb_discovery import ROOT
7-
from mb_discovery.plots import StabilityCriterion, plt, precision_recall_vs_calc_count
7+
from mb_discovery.plots import StabilityCriterion, precision_recall_vs_calc_count
88

99
__author__ = "Rhys Goodall, Janosh Riebesell"
1010
__date__ = "2022-06-18"
@@ -24,14 +24,18 @@
2424
dfs[model_name] = df
2525

2626
dfs["M3GNet"] = pd.read_json(
27-
f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
27+
f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-IS2RE.json.gz"
2828
).set_index("material_id")
2929

3030
dfs["Wrenformer"] = pd.read_csv(
3131
f"{ROOT}/models/wrenformer/mp/"
3232
"2022-09-20-wrenformer-e_form-ensemble-1-preds-e_form_per_atom.csv"
3333
).set_index("material_id")
3434

35+
dfs["BOWSR Megnet"] = pd.read_json(
36+
f"{ROOT}/models/bowsr/2022-09-22-bowsr-wbm-megnet-IS2RE.json.gz"
37+
).set_index("material_id")
38+
3539
print(f"loaded models: {list(dfs)}")
3640

3741

@@ -43,11 +47,9 @@
4347

4448
# %%
4549
stability_crit: StabilityCriterion = "energy"
50+
colors = "tab:blue tab:orange teal tab:pink black red turquoise tab:purple".split()
4651

47-
for (model_name, df), color in zip(
48-
dfs.items(),
49-
("tab:blue", "tab:orange", "teal", "tab:pink", "black", "red", "turquoise"),
50-
):
52+
for (model_name, df), color in zip(dfs.items(), colors):
5153
rare = "all"
5254

5355
# from pymatgen.core import Composition
@@ -76,6 +78,8 @@
7678
# other cases are unexpected
7779
assert len(pred_cols) in (1, 10), f"{model_name=} has {len(pred_cols)=}"
7880
model_preds = df[pred_cols].mean(axis=1)
81+
elif "BOWSR" in model_name:
82+
model_preds = df.e_form_per_atom_bowsr
7983
else:
8084
raise ValueError(f"Unhandled {model_name = }")
8185
except AttributeError as exc:
@@ -103,6 +107,7 @@
103107
# keep this outside loop so all model names appear in legend
104108
ax.legend(frameon=False, loc="lower right")
105109

110+
111+
# %%
106112
img_path = f"{ROOT}/figures/{today}-precision-recall-vs-calc-count-{rare=}.pdf"
107-
if False:
108-
plt.savefig(img_path)
113+
ax.figure.savefig(img_path)

mb_discovery/plot_scripts/rolling_mae_vs_hull_dist_wbm_batches.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@
4242
assert len(markers) == 5 # number of WBM rounds of element substitution
4343

4444
for idx, marker in enumerate(markers, 1):
45-
title = f"Batch {idx}"
4645
df = df_wbm[df_wbm.index.str.startswith(f"wbm-step-{idx}")]
46+
title = f"Batch {idx} ({len(df.filter(like='e_').dropna()):,})"
4747
assert 1e4 < len(df) < 1e5, print(f"{len(df) = :,}")
4848

4949
rolling_mae_vs_hull_dist(

mb_discovery/plots.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

3-
from typing import Any, Literal, Sequence, get_args
3+
from collections.abc import Sequence
4+
from typing import Any, Literal, get_args
45

56
import matplotlib.pyplot as plt
67
import numpy as np

models/bowsr/slurm_array_bowsr_wbm.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -47,21 +47,21 @@
4747
data_path = f"{ROOT}/data/2022-06-26-wbm-cses-and-initial-structures.json.gz"
4848

4949
module_dir = os.path.dirname(__file__)
50-
job_id = os.environ.get("SLURM_JOB_ID", "debug")
51-
job_array_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
50+
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
51+
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
5252
# set large fallback job array size for fast testing/debugging
5353
job_array_size = int(os.environ.get("SLURM_ARRAY_TASK_COUNT", 10_000))
5454

5555
print(f"Job started running {datetime.now():%Y-%m-%d@%H-%M}")
56-
print(f"{job_id = }")
57-
print(f"{job_array_id = }")
56+
print(f"{slurm_job_id = }")
57+
print(f"{slurm_array_task_id = }")
5858
print(f"{version('maml') = }")
5959
print(f"{version('megnet') = }")
6060

6161
today = f"{datetime.now():%Y-%m-%d}"
6262
out_dir = f"{module_dir}/{today}-bowsr-megnet-wbm-{task_type}"
6363
os.makedirs(out_dir, exist_ok=True)
64-
json_out_path = f"{out_dir}/{job_array_id}.json.gz"
64+
json_out_path = f"{out_dir}/{slurm_array_task_id}.json.gz"
6565

6666
if os.path.isfile(json_out_path):
6767
raise SystemExit(f"{json_out_path = } already exists, exiting early")
@@ -79,8 +79,8 @@
7979
run_params = dict(
8080
megnet_version=version("megnet"),
8181
maml_version=version("maml"),
82-
job_id=job_id,
83-
job_array_id=job_array_id,
82+
slurm_job_id=slurm_job_id,
83+
slurm_array_task_id=slurm_array_task_id,
8484
data_path=data_path,
8585
bayes_optim_kwargs=bayes_optim_kwargs,
8686
optimize_kwargs=optimize_kwargs,
@@ -93,7 +93,7 @@
9393
wandb.init(
9494
entity="janosh",
9595
project="matbench-discovery",
96-
name=f"bowsr-megnet-wbm-{task_type}-{job_id}-{job_array_id}",
96+
name=f"bowsr-megnet-wbm-{task_type}-{slurm_job_id}-{slurm_array_task_id}",
9797
config=run_params,
9898
)
9999

@@ -102,7 +102,7 @@
102102
print(f"Loading from {data_path=}")
103103
df_wbm = pd.read_json(data_path).set_index("material_id")
104104

105-
df_this_job = np.array_split(df_wbm, job_array_size + 1)[job_array_id]
105+
df_this_job = np.array_split(df_wbm, job_array_size + 1)[slurm_array_task_id]
106106

107107

108108
# %%

models/m3gnet/eda_wbm_pre_vs_post_m3gnet_relaxation.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828

2929
# %%
3030
df_m3gnet_is2re = pd.read_json(
31-
f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
31+
f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-IS2RE.json.gz"
3232
).set_index("material_id")
3333
df_m3gnet_rs2re = pd.read_json(
34-
f"{ROOT}/models/m3gnet/2022-08-19-m3gnet-wbm-relax-results-RS2RE.json.gz"
34+
f"{ROOT}/models/m3gnet/2022-08-19-m3gnet-wbm-RS2RE.json.gz"
3535
).set_index("material_id")
3636

3737

@@ -226,5 +226,5 @@
226226
# %% write df back to compressed JSON
227227
# filter out columns containing 'rs2re'
228228
# df_m3gnet_is2re.reset_index().filter(regex="^((?!rs2re).)*$").to_json(
229-
# f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-relax-results-IS2RE-2.json.gz"
229+
# f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-IS2RE-2.json.gz"
230230
# ).set_index("material_id")

models/m3gnet/join_m3gnet_relax_results.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,5 +118,5 @@
118118
out_path = f"{ROOT}/models/m3gnet/{today}-m3gnet-wbm-relax-{task_type}.json.gz"
119119
df_m3gnet.reset_index().to_json(out_path, default_handler=as_dict_handler)
120120

121-
# out_path = f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
121+
# out_path = f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-IS2RE.json.gz"
122122
# df_m3gnet = pd.read_json(out_path).set_index("material_id")

models/m3gnet/slurm_array_m3gnet_relax_wbm.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,20 @@
4040
task_type = "IS2RE"
4141
# task_type = "RS2RE"
4242

43-
job_id = os.environ.get("SLURM_JOB_ID", "debug")
44-
job_array_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
43+
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
44+
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
4545
# set large fallback job array size for fast testing/debugging
4646
job_array_size = int(os.environ.get("SLURM_ARRAY_TASK_COUNT", 10_000))
4747

4848
print(f"Job started running {datetime.now():%Y-%m-%d@%H-%M}")
49-
print(f"{job_id = }")
50-
print(f"{job_array_id = }")
49+
print(f"{slurm_job_id = }")
50+
print(f"{slurm_array_task_id = }")
5151
print(f"{version('m3gnet') = }")
5252

5353
today = f"{datetime.now():%Y-%m-%d}"
5454
out_dir = f"{ROOT}/data/{today}-m3gnet-wbm-relax-{task_type}"
5555
os.makedirs(out_dir, exist_ok=True)
56-
json_out_path = f"{out_dir}/{job_array_id}.json.gz"
56+
json_out_path = f"{out_dir}/{slurm_array_task_id}.json.gz"
5757

5858
if os.path.isfile(json_out_path):
5959
raise SystemExit(f"{json_out_path = } already exists, exiting early")
@@ -67,20 +67,20 @@
6767
print(f"Loading from {data_path=}")
6868
df_wbm = pd.read_json(data_path).set_index("material_id")
6969

70-
df_this_job = np.array_split(df_wbm, job_array_size)[job_array_id]
70+
df_this_job = np.array_split(df_wbm, job_array_size)[slurm_array_task_id]
7171

7272
run_params = dict(
7373
m3gnet_version=version("m3gnet"),
74-
job_id=job_id,
75-
job_array_id=job_array_id,
74+
slurm_job_id=slurm_job_id,
75+
slurm_array_task_id=slurm_array_task_id,
7676
data_path=data_path,
7777
)
7878
if wandb.run is None:
7979
wandb.login()
8080

8181
wandb.init(
8282
project="m3gnet",
83-
name=f"m3gnet-wbm-relax-{task_type}-{job_id}-{job_array_id}",
83+
name=f"m3gnet-wbm-relax-{task_type}-{slurm_job_id}-{slurm_array_task_id}",
8484
config=run_params,
8585
)
8686

readme.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
[![Link check](https://github.com/janosh/matbench-discovery/actions/workflows/test.yml/badge.svg)](https://github.com/janosh/matbench-discovery/actions/workflows/test.yml)
44
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/janosh/matbench-discovery/main.svg?badge_token=Qza33izjRxSbegTqeSyDvA)](https://results.pre-commit.ci/latest/github/janosh/matbench-discovery/main?badge_token=Qza33izjRxSbegTqeSyDvA)
5-
[![Requires Python 3.9+](https://img.shields.io/badge/Python-3.9+-blue.svg)](https://python.org/downloads)
5+
[![Requires Python 3.9+](https://img.shields.io/badge/Python-3.9+-blue.svg?logo=python)](https://python.org/downloads)
66

77
Several new energy models specifically designed to handle unrelaxed structures were published in 2021/22
88

tests/test_plots.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

3-
from typing import Any, Sequence
3+
from collections.abc import Sequence
4+
from typing import Any
45

56
import matplotlib.pyplot as plt
67
import pandas as pd

0 commit comments

Comments
 (0)