Skip to content

Commit 3d42214

Browse files
committed
add mb_discovery/wrenformer/{mp,m3gnet_train_set}
1 parent f1cc667 commit 3d42214

File tree

5 files changed

+103
-2
lines changed

5 files changed

+103
-2
lines changed

.gitignore

+1
Original file line number | Diff line number | Diff line change
@@ -19,3 +19,4 @@ job-logs/
1919

2020
# slurm logs
2121
slurm-*out
22+
mb_discovery/**/*.csv

mb_discovery/plot_scripts/precision_recall_vs_calc_count.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@
3333
# ).set_index("material_id")
3434

3535
dfs["Wrenformer"] = pd.read_csv(
36-
f"{ROOT}/data/2022-08-16-wrenformer-ensemble-predictions.csv.bz2"
36+
f"{ROOT}/data/2022-08-16-wrenformer-preds.csv.bz2"
3737
).set_index("material_id")
3838

3939

mb_discovery/plot_scripts/rolling_mae_vs_hull_dist.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919

2020
data_path = (
2121
f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
22-
# f"{ROOT}/data/2022-08-16-wrenformer-ensemble-predictions.csv.bz2"
22+
# f"{ROOT}/data/2022-08-16-wrenformer-preds.csv.bz2"
2323
)
2424
df = pd.read_csv(data_path).set_index("material_id")
2525
legend_label = "Wren"
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,56 @@
1+
# %%
2+
from datetime import datetime
3+
4+
import pandas as pd
5+
from aviary import ROOT
6+
from aviary.utils import as_dict_handler
7+
from aviary.wren.utils import get_aflow_label_from_spglib
8+
from mp_api.client import MPRester
9+
10+
11+
"""
12+
Download all MP formation and above hull energies on 2022-08-13 for training a
13+
Wrenformer ensemble.
14+
15+
Related EDA of MP formation energies:
16+
https://github.com/janosh/pymatviz/blob/main/examples/mp_bimodal_e_form.ipynb
17+
"""
18+
19+
__author__ = "Janosh Riebesell"
20+
__date__ = "2022-08-13"
21+
22+
today = f"{datetime.now():%Y-%m-%d}"
23+
24+
25+
# %% query all MP formation energies on 2022-08-13
26+
fields = [
27+
"material_id",
28+
"task_ids",
29+
"formula_pretty",
30+
"formation_energy_per_atom",
31+
"energy_per_atom",
32+
"structure",
33+
"symmetry",
34+
"energy_above_hull",
35+
]
36+
with MPRester() as mpr:
37+
docs = mpr.summary.search(fields=fields)
38+
39+
print(f"{today}: {len(docs) = :,}")
40+
# 2022-08-13: len(docs) = 146,323
41+
42+
43+
# %%
44+
df = pd.DataFrame(
45+
[{key: getattr(doc, key, None) for key in fields} for doc in docs]
46+
).set_index("material_id")
47+
48+
df["spacegroup_number"] = df.pop("symmetry").map(lambda x: x.number)
49+
50+
df["wyckoff"] = df.structure.map(get_aflow_label_from_spglib)
51+
52+
df.to_json(
53+
f"{ROOT}/datasets/{today}-mp-all-energies.json.gz", default_handler=as_dict_handler
54+
)
55+
56+
# df = pd.read_json(f"{ROOT}/datasets/2022-08-13-mp-all-energies.json.gz")
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,44 @@
1+
# %%
2+
from __future__ import annotations
3+
4+
import os
5+
from datetime import datetime
6+
7+
import pandas as pd
8+
import wandb
9+
from aviary.wrenformer.deploy import deploy_wandb_checkpoints
10+
11+
12+
__author__ = "Janosh Riebesell"
13+
__date__ = "2022-08-15"
14+
15+
"""
16+
Script that downloads checkpoints for an ensemble of Wrenformer models trained on the MP
17+
formation energies, then makes predictions on some dataset, prints ensemble metrics and
18+
stores predictions to CSV.
19+
"""
20+
21+
MODULE_DIR = os.path.dirname(__file__)
22+
23+
24+
# %%
25+
today = f"{datetime.now():%Y-%m-%d}"
26+
# download wbm-steps-summary.csv (23.31 MB)
27+
data_path = "https://figshare.com/files/36714216?private_link=ff0ad14505f9624f0c05"
28+
df = pd.read_csv(data_path).set_index("material_id")
29+
30+
31+
target_col = "e_form_per_atom"
32+
df[target_col] = df.e_form / df.n_sites
33+
34+
wandb.login()
35+
wandb_api = wandb.Api()
36+
runs = wandb_api.runs(
37+
"aviary/mp", filters={"tags": {"$in": ["wrenformer-e_form-ensemble-1"]}}
38+
)
39+
40+
df, ensemble_metrics = deploy_wandb_checkpoints(
41+
runs, df, input_col="wyckoff", target_col=target_col
42+
)
43+
44+
df.to_csv(f"{MODULE_DIR}/{today}-wrenformer-preds-{target_col}.csv")

0 commit comments

Comments
 (0)