Skip to content

Commit 76647e3

Browse files
committed
add data/mp/get_mp_traj.py to save a snapshot of all MP ionic steps on 2023-03-15 to be released as MBD canonical training set
add site/src/routes/models/per-element/+page.svelte
1 parent a6bfa74 commit 76647e3

29 files changed

+287
-64
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]
77

88
repos:
99
- repo: https://github.com/charliermarsh/ruff-pre-commit
10-
rev: v0.0.255
10+
rev: v0.0.257
1111
hooks:
1212
- id: ruff
1313
args: [--fix]

data/mp/get_mp_traj.py

+102
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""Download all MP ionic steps on 2023-03-15."""
2+
3+
4+
# %%
5+
import os
6+
7+
import pandas as pd
8+
from emmet.core.tasks import TaskDoc
9+
from pymongo import MongoClient
10+
from pymongo.database import Database
11+
from tqdm import trange
12+
13+
from matbench_discovery import ROOT, today
14+
15+
__author__ = "Janosh Riebesell"
16+
__date__ = "2023-03-15"
17+
18+
module_dir = os.path.dirname(__file__)
19+
20+
21+
# %% access mp_core database directly through pymongo instead of API for speed
22+
host = "knowhere.lbl.gov"
23+
db_name = "mp_core"
24+
25+
with open(f"{ROOT}/site/.env") as file:
26+
text = file.read()
27+
user = text.split("user=")[1].split("\n")[0]
28+
password = text.split("password=")[1].split("\n")[0]
29+
30+
uri = f"mongodb://{user}:{password}@{host}/?authSource={db_name}"
31+
db: Database[TaskDoc] = MongoClient(uri)[db_name]
32+
33+
34+
# %%
35+
ids_path = f"{module_dir}/2023-03-15-mp-task-ids.csv.bz2"
36+
fields = "task_id formula_pretty run_type nsites task_type tags completed_at".split()
37+
38+
if os.path.isfile(ids_path):
39+
print(f"Found existing list of task IDs to query at {ids_path=}")
40+
df_tasks = pd.read_csv(ids_path).set_index("task_id")
41+
else:
42+
print(f"Querying all task docs from {db_name}\n{fields=}.\nThis takes a while...")
43+
task_docs = sorted(
44+
db["tasks"].find({}, fields), key=lambda doc: int(doc["task_id"].split("-")[1])
45+
)
46+
47+
print(f"{today}: {len(task_docs) = :,}")
48+
49+
df_tasks = pd.DataFrame(task_docs).drop(columns=["_id"]).set_index("task_id")
50+
df_tasks.task_type.value_counts(dropna=False).plot.pie()
51+
52+
df_tasks.to_csv(f"{module_dir}/{today}-mp-task-ids.csv.bz2")
53+
54+
55+
# %% inspect schema of a single task doc
56+
doc = db.tasks.find_one({"task_id": "mp-288"})
57+
# the most relevant task data is found in the 1st calc's ionic steps which are
58+
# the relaxation trajectory frames with the highest rate of change
59+
# docs[0]["calcs_reversed"][-1]["output"]["ionic_steps"]
60+
61+
62+
# %%
63+
batch_size = 10_000
64+
task_ids = df_tasks.index.tolist()
65+
66+
os.makedirs(f"{module_dir}/mp-tasks", exist_ok=True)
67+
# Iterate over task_ids in batches
68+
desc = "Loading MP task docs"
69+
pbar = trange(0, len(task_ids), batch_size, desc=desc, unit_scale=batch_size)
70+
for start_idx in pbar:
71+
# Define start and end indices for batch
72+
end_idx = min(start_idx + batch_size, len(task_ids))
73+
start_id = task_ids[start_idx]
74+
end_id = task_ids[end_idx - 1]
75+
batch_ids = task_ids[start_idx:end_idx]
76+
pbar.set_postfix_str(f"{start_id} to {end_id}")
77+
78+
out_path = f"{module_dir}/mp-tasks/{start_id}__{end_id}.json.gz"
79+
80+
# Check if output file for batch already exists
81+
if os.path.isfile(out_path):
82+
continue
83+
84+
# query batch of task docs
85+
batch_docs = list(
86+
db["tasks"].find(
87+
{"task_id": {"$in": batch_ids}},
88+
[*fields, "calcs_reversed.output.ionic_steps"],
89+
)
90+
)
91+
92+
# Convert documents to DataFrame and save to file
93+
df_batch = pd.DataFrame(batch_docs).set_index("task_id").drop(columns=["_id"])
94+
# handler=str needed since MongoDB ObjectId is not JSON serializable
95+
df_batch.reset_index().to_json(out_path, default_handler=str)
96+
# don't store df_batch to save memory
97+
98+
99+
# %% inspect saved task docs for expected data
100+
df_10k = pd.read_json(
101+
f"{module_dir}/mp-tasks/mp-1708653__mp-1735769.json.gz"
102+
).set_index("task_id")

matbench_discovery/plots.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
model_labels = dict(
5252
bowsr_megnet="BOWSR + MEGNet",
5353
chgnet="CHGNet",
54-
chgnet_megnet="CHGNet + MEGNet",
54+
# chgnet_megnet="CHGNet + MEGNet",
5555
cgcnn_p="CGCNN+P",
5656
cgcnn="CGCNN",
5757
m3gnet_megnet="M3GNet + MEGNet",

matbench_discovery/preds.py

+1
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def load_df_wbm_with_preds(
111111
return df_out
112112

113113

114+
# load WBM summary dataframe with all models' formation energy predictions (eV/atom)
114115
df_preds = load_df_wbm_with_preds().round(3)
115116
for combo in [["CHGNet", "M3GNet"]]:
116117
df_preds[" + ".join(combo)] = df_preds[combo].mean(axis=1)

matbench_discovery/slurm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def slurm_submit(
9393
slurm_vars = {
9494
f"slurm_{key}": val
9595
for key in SLURM_KEYS
96-
if (val := os.environ.get(f"SLURM_{key}".upper()))
96+
if (val := os.getenv(f"SLURM_{key}".upper()))
9797
}
9898
slurm_vars["slurm_timelimit"] = time
9999
if slurm_flags:

models/bowsr/test_bowsr.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
slurm_max_parallel = 100
3737
energy_model = "megnet"
3838
job_name = f"bowsr-{energy_model}-wbm-{task_type}{'-debug' if DEBUG else ''}"
39-
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
39+
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
4040

4141
data_path = {
4242
"IS2RE": DATA_FILES.wbm_initial_structures,
@@ -62,7 +62,7 @@
6262

6363

6464
# %%
65-
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
65+
slurm_array_task_id = int(os.getenv("SLURM_ARRAY_TASK_ID", 0))
6666
out_path = f"{out_dir}/bowsr-preds-{slurm_array_task_id}.json.gz"
6767

6868
if os.path.isfile(out_path):

models/cgcnn/test_cgcnn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
debug = "slurm-submit" in sys.argv
3333
job_name = f"test-cgcnn-wbm-{task_type}{'-debug' if DEBUG else ''}"
3434
module_dir = os.path.dirname(__file__)
35-
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
35+
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
3636

3737
slurm_vars = slurm_submit(
3838
job_name=job_name,
@@ -116,7 +116,7 @@
116116
data_loader=data_loader,
117117
)
118118

119-
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
119+
slurm_job_id = os.getenv("SLURM_JOB_ID", "debug")
120120
df.round(4).to_csv(f"{out_dir}/{job_name}-preds-{slurm_job_id}.csv")
121121
pred_col = f"{e_form_col}_pred_ens"
122122
assert pred_col in df, f"{pred_col=} not in {list(df)}"

models/cgcnn/train_cgcnn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
robust = "robust" in job_name.lower()
3737
ensemble_size = 10
3838
module_dir = os.path.dirname(__file__)
39-
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
39+
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
4040

4141
slurm_vars = slurm_submit(
4242
job_name=job_name,
@@ -54,7 +54,7 @@
5454
learning_rate = 3e-4
5555
batch_size = 128
5656
swa_start = None
57-
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
57+
slurm_array_task_id = int(os.getenv("SLURM_ARRAY_TASK_ID", 0))
5858
task_type: TaskType = "regression"
5959

6060

models/chgnet/analyze_chgnet.py

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""Compare CHGNet long vs short relaxations."""
2+
3+
4+
# %%
5+
import os
6+
7+
import matplotlib.pyplot as plt
8+
import pandas as pd
9+
from pymatgen.core import Structure
10+
from pymatviz import density_scatter, plot_structure_2d, ptable_heatmap_plotly
11+
12+
from matbench_discovery import plots
13+
from matbench_discovery.data import DATA_FILES, df_wbm
14+
from matbench_discovery.preds import PRED_FILES
15+
16+
__author__ = "Janosh Riebesell"
17+
__date__ = "2023-03-06"
18+
19+
module_dir = os.path.dirname(__file__)
20+
del plots # https://github.com/PyCQA/pyflakes/issues/366
21+
22+
23+
# %%
24+
df_chgnet = pd.read_csv(PRED_FILES.__dict__["CHGNet"])
25+
df_chgnet = df_chgnet.set_index("material_id").add_suffix("_2000")
26+
df_chgnet_500 = pd.read_csv(PRED_FILES.__dict__["CHGNet"].replace("-06", "-04"))
27+
df_chgnet_500 = df_chgnet_500.set_index("material_id").add_suffix("_500")
28+
df_chgnet[list(df_chgnet_500)] = df_chgnet_500
29+
df_chgnet["formula"] = df_wbm.formula
30+
31+
e_form_2000 = "e_form_per_atom_chgnet_2000"
32+
e_form_500 = "e_form_per_atom_chgnet_500"
33+
34+
min_e_diff = 0.35
35+
df_bad = df_chgnet.query(f"{e_form_2000} - {e_form_500} > {min_e_diff}")
36+
37+
38+
# %%
39+
density_scatter(df=df_chgnet, x=e_form_2000, y=e_form_500)
40+
41+
42+
# %%
43+
fig = ptable_heatmap_plotly(df_bad.formula)
44+
title = "structures with larger error after longer relaxation"
45+
fig.layout.title.update(text=f"{len(df_bad)} {title}")
46+
47+
48+
# %%
49+
df_cse = pd.read_json(DATA_FILES.wbm_initial_structures).set_index("material_id")
50+
51+
52+
# %%
53+
n_rows, n_cols = 3, 4
54+
fig, axs = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 4 * n_rows))
55+
n_struct = min(n_rows * n_cols, len(df_bad))
56+
struct_col = "initial_structure"
57+
58+
fig.suptitle(f"{n_struct} {struct_col} {title}", fontsize=16, fontweight="bold", y=1.05)
59+
for idx, (ax, row) in enumerate(
60+
zip(axs.flat, df_cse.loc[df_bad.index].itertuples()), 1
61+
):
62+
struct = Structure.from_dict(getattr(row, struct_col))
63+
plot_structure_2d(struct, ax=ax)
64+
_, spg_num = struct.get_space_group_info()
65+
formula = struct.composition.reduced_formula
66+
id = row.Index
67+
ax.set_title(f"{idx}. {formula} (spg={spg_num})\n{id}", fontweight="bold")
68+
69+
# fig.savefig(f"{ROOT}/tmp/figures/chgnet-bad-relax-structures.webp", dpi=300)

models/chgnet/test_chgnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
# set large job array size for smaller data splits and faster testing/debugging
3434
slurm_array_task_count = 100
3535
job_name = f"chgnet-wbm-{task_type}{'-debug' if DEBUG else ''}"
36-
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
36+
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
3737

3838
slurm_vars = slurm_submit(
3939
job_name=job_name,
@@ -47,7 +47,7 @@
4747

4848

4949
# %%
50-
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
50+
slurm_array_task_id = int(os.getenv("SLURM_ARRAY_TASK_ID", 0))
5151
out_path = f"{out_dir}/chgnet-preds-{slurm_array_task_id}.json.gz"
5252

5353
if os.path.isfile(out_path):

models/m3gnet/test_m3gnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
# set large job array size for smaller data splits and faster testing/debugging
3333
slurm_array_task_count = 100
3434
job_name = f"m3gnet-wbm-{task_type}{'-debug' if DEBUG else ''}"
35-
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
35+
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
3636

3737
slurm_vars = slurm_submit(
3838
job_name=job_name,
@@ -49,7 +49,7 @@
4949

5050

5151
# %%
52-
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 3))
52+
slurm_array_task_id = int(os.getenv("SLURM_ARRAY_TASK_ID", 3))
5353
out_path = f"{out_dir}/m3gnet-preds-{slurm_array_task_id}.json.gz"
5454

5555
if os.path.isfile(out_path):

models/megnet/test_megnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
task_type = "chgnet_structure"
3232
module_dir = os.path.dirname(__file__)
3333
job_name = f"megnet-wbm-{task_type}{'-debug' if DEBUG else ''}"
34-
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
34+
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
3535
slurm_array_task_count = 20
3636

3737
slurm_vars = slurm_submit(
@@ -49,7 +49,7 @@
4949

5050

5151
# %%
52-
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
52+
slurm_array_task_id = int(os.getenv("SLURM_ARRAY_TASK_ID", 0))
5353
out_path = f"{out_dir}/megnet-e-form-preds.csv"
5454
if os.path.isfile(out_path):
5555
raise SystemExit(f"{out_path = } already exists, exciting early")

models/voronoi/voronoi_featurize_dataset.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
debug = "slurm-submit" in sys.argv
3636
job_name = f"voronoi-features-{data_name}{'-debug' if DEBUG else ''}"
3737
module_dir = os.path.dirname(__file__)
38-
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
38+
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
3939
slurm_array_task_count = 50
4040

4141

@@ -51,7 +51,7 @@
5151

5252

5353
# %%
54-
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
54+
slurm_array_task_id = int(os.getenv("SLURM_ARRAY_TASK_ID", 0))
5555
run_name = f"{job_name}-{slurm_array_task_id}"
5656
out_path = f"{out_dir}/{run_name}.csv.bz2"
5757

models/wrenformer/test_wrenformer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
debug = "slurm-submit" in sys.argv
3232
job_name = f"test-wrenformer-wbm-{task_type}{'-debug' if DEBUG else ''}"
3333
module_dir = os.path.dirname(__file__)
34-
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
34+
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
3535

3636
slurm_vars = slurm_submit(
3737
job_name=job_name,
@@ -103,7 +103,7 @@
103103
)
104104
df = df.round(4)
105105

106-
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
106+
slurm_job_id = os.getenv("SLURM_JOB_ID", "debug")
107107
df.to_csv(f"{out_dir}/{job_name}-preds-{slurm_job_id}.csv")
108108

109109

models/wrenformer/train_wrenformer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
ensemble_size = 10
2828
dataset = "mp"
2929
module_dir = os.path.dirname(__file__)
30-
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
30+
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
3131

3232

3333
slurm_vars = slurm_submit(
@@ -44,7 +44,7 @@
4444
# %%
4545
learning_rate = 3e-4
4646
batch_size = 128
47-
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
47+
slurm_array_task_id = int(os.getenv("SLURM_ARRAY_TASK_ID", 0))
4848
input_col = "wyckoff_spglib"
4949

5050
print(f"\nJob started running {timestamp}")

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ ignore = [
8686
"D100", # Missing docstring in public module
8787
"D205", # 1 blank line required between summary line and description
8888
"E731", # Do not assign a lambda expression, use a def
89+
"PLW1508", # Invalid type for environment variable default
8990
"PLW2901", # Outer for loop variable overwritten by inner assignment target
9091
]
9192
pydocstyle.convention = "google"

scripts/compile_metrics.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@
9292
"Slurm Jobs": n_runs,
9393
}
9494

95-
test_stats["M3GNet + MEGNet"] = test_stats["M3GNet"].copy()
96-
test_stats["M3GNet + MEGNet"][time_col] = (
97-
test_stats["MEGNet"][time_col] + test_stats["M3GNet"][time_col]
98-
)
95+
# test_stats["M3GNet + MEGNet"] = test_stats["M3GNet"].copy()
96+
# test_stats["M3GNet + MEGNet"][time_col] = (
97+
# test_stats["MEGNet"][time_col] + test_stats["M3GNet"][time_col]
98+
# )
9999
test_stats["CGCNN+P"] = {}
100100

101101

0 commit comments

Comments
 (0)