Skip to content

Commit 4f6d853

Browse files
committed
add mb_discovery/energy/slurm_e_above_hull.py
mv mb_discovery/compute_formation_energy.py -> mb_discovery/energy/__init__.py
1 parent 2517855 commit 4f6d853

9 files changed

+76
-44
lines changed

mb_discovery/build_phase_diagram.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
from pymatgen.ext.matproj import MPRester
1414

1515
from mb_discovery import ROOT
16-
from mb_discovery.compute_formation_energy import (
17-
get_elemental_ref_entries,
18-
get_form_energy_per_atom,
19-
)
16+
from mb_discovery.energy import get_elemental_ref_entries, get_form_energy_per_atom
2017

2118
today = f"{datetime.now():%Y-%m-%d}"
2219
module_dir = os.path.dirname(__file__)

mb_discovery/compute_formation_energy.py mb_discovery/energy/__init__.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import itertools
22

33
import pandas as pd
4-
from pymatgen.analysis.phase_diagram import Entry
5-
from pymatgen.entries.computed_entries import ComputedEntry
4+
from pymatgen.analysis.phase_diagram import Entry, PDEntry
65
from tqdm import tqdm
76

87
from mb_discovery import ROOT
@@ -47,9 +46,7 @@ def get_elemental_ref_entries(
4746
mp_elem_refs_path = f"{ROOT}/data/2022-09-19-mp-elemental-reference-entries.json"
4847
try:
4948
mp_elem_reference_entries = (
50-
pd.read_json(mp_elem_refs_path, typ="series")
51-
.map(ComputedEntry.from_dict)
52-
.to_dict()
49+
pd.read_json(mp_elem_refs_path, typ="series").map(PDEntry.from_dict).to_dict()
5350
)
5451
except FileNotFoundError:
5552
mp_elem_reference_entries = None

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist_batches.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
dfs["m3gnet"] = pd.read_json(
3838
f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-IS2RE.json.gz"
3939
).set_index("material_id")
40-
dfs["Wrenformer"] = pd.read_csv(
40+
dfs["wrenformer"] = pd.read_csv(
4141
f"{ROOT}/models/wrenformer/mp/"
4242
"2022-09-20-wrenformer-e_form-ensemble-1-preds-e_form_per_atom.csv"
4343
).set_index("material_id")

mb_discovery/plot_scripts/precision_recall_vs_calc_count.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from datetime import datetime
33

44
import pandas as pd
5+
from sklearn.metrics import f1_score
56

67
from mb_discovery import ROOT
78
from mb_discovery.plots import StabilityCriterion, precision_recall_vs_calc_count
@@ -15,6 +16,7 @@
1516
# %%
1617
DATA_DIR = f"{ROOT}/data/2022-06-11-from-rhys"
1718
df_hull = pd.read_csv(f"{DATA_DIR}/wbm-e-above-mp-hull.csv").set_index("material_id")
19+
rare = "all"
1820

1921
dfs: dict[str, pd.DataFrame] = {}
2022
for model_name in ("wren", "cgcnn", "voronoi"):
@@ -47,10 +49,9 @@
4749
# %%
4850
stability_crit: StabilityCriterion = "energy"
4951
colors = "tab:blue tab:orange teal tab:pink black red turquoise tab:purple".split()
52+
F1s: dict[str, float] = {}
5053

51-
for (model_name, df), color in zip(dfs.items(), colors):
52-
rare = "all"
53-
54+
for model_name, df in dfs.items():
5455
# from pymatgen.core import Composition
5556
# rare = "no-lanthanides"
5657
# df["contains_rare_earths"] = df.composition.map(
@@ -91,23 +92,46 @@
9192
assert n_nans < 10, f"{model_name=} has {n_nans=}"
9293
df = df.dropna()
9394

95+
F1 = f1_score(df.e_above_mp_hull < 0, df.e_above_hull_pred < 0)
96+
F1s[model_name] = F1
97+
98+
99+
# %%
100+
for (model_name, F1), color in zip(sorted(F1s.items(), key=lambda x: x[1]), colors):
101+
df = dfs[model_name]
102+
94103
ax = precision_recall_vs_calc_count(
95104
e_above_hull_error=df.e_above_hull_pred + df.e_above_mp_hull,
96105
e_above_hull_true=df.e_above_mp_hull,
97106
color=color,
98-
label=model_name,
107+
label=f"{model_name} {F1=:.2}",
99108
intersect_lines="recall_xy", # or "precision_xy", None, 'all'
100109
stability_crit=stability_crit,
101110
std_pred=std_total,
102111
)
103112

113+
# optimal recall line finds all stable materials without any false positives
114+
# can be included to confirm all models start out with near optimal recall
115+
# and to see how much each model overshoots total n_stable
116+
n_below_hull = sum(df_hull.e_above_mp_hull < 0)
117+
ax.plot(
118+
[0, n_below_hull],
119+
[0, 100],
120+
color="green",
121+
linestyle="dashed",
122+
linewidth=1,
123+
label="Optimal Recall",
124+
)
125+
104126
ax.figure.set_size_inches(10, 9)
105127
ax.set(xlim=(0, None))
106128
# keep this outside loop so all model names appear in legend
107129
ax.legend(frameon=False, loc="lower right")
108130

109131
img_name = f"{today}-precision-recall-vs-calc-count-{rare=}"
110132
ax.set(title=img_name.replace("-", "/", 2).replace("-", " ").title())
133+
# x-ticks every 10k materials
134+
ax.set(xticks=range(0, int(ax.get_xlim()[1]), 10_000))
111135

112136

113137
# %%

mb_discovery/plot_scripts/rolling_mae_vs_hull_dist_wbm_batches.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,26 @@
1919
f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
2020
).set_index("material_id")
2121

22+
df_wrenformer = pd.read_csv(
23+
f"{ROOT}/models/wrenformer/mp/"
24+
"2022-09-20-wrenformer-e_form-ensemble-1-preds-e_form_per_atom.csv"
25+
).set_index("material_id")
26+
2227
df_hull = pd.read_csv(
2328
f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
2429
).set_index("material_id")
2530

26-
df_wren["e_above_mp_hull"] = df_hull.e_above_mp_hull
27-
assert df_wren.e_above_mp_hull.isna().sum() == 0
31+
df_wrenformer["e_above_mp_hull"] = df_hull.e_above_mp_hull
32+
assert df_wrenformer.e_above_mp_hull.isna().sum() == 0
2833

29-
target_col = "e_form_target"
34+
target_col = "e_form_per_atom"
35+
# target_col = "e_form_target"
3036

3137
# make sure we average the expected number of ensemble member predictions
32-
assert df_wren.filter(regex=r"_pred_\d").shape[1] == 10
38+
assert df_wrenformer.filter(regex=r"_pred_\d").shape[1] == 10
3339

34-
df_wren["e_above_hull_pred"] = (
35-
df_wren.filter(regex=r"_pred_\d").mean(axis=1) - df_wren[target_col]
40+
df_wrenformer["e_above_hull_pred"] = (
41+
df_wrenformer.filter(regex=r"_pred_\d").mean(axis=1) - df_wrenformer[target_col]
3642
)
3743

3844

@@ -42,7 +48,7 @@
4248
assert len(markers) == 5 # number of WBM rounds of element substitution
4349

4450
for idx, marker in enumerate(markers, 1):
45-
df = df_wren[df_wren.index.str.startswith(f"wbm-step-{idx}")]
51+
df = df_wrenformer[df_wrenformer.index.str.startswith(f"wbm-step-{idx}")]
4652
title = f"Batch {idx} ({len(df.filter(like='e_').dropna()):,})"
4753
assert 1e4 < len(df) < 1e5, print(f"{len(df) = :,}")
4854

@@ -62,4 +68,4 @@
6268

6369

6470
img_path = f"{ROOT}/figures/{today}-rolling-mae-vs-hull-dist-wbm-batches-{rare=}.pdf"
65-
# plt.savefig(img_path)
71+
# fig.savefig(img_path)

mb_discovery/plots.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -320,9 +320,9 @@ def precision_recall_vs_calc_count(
320320
"""
321321
ax = ax or plt.gca()
322322

323-
for series in (e_above_hull_error, e_above_hull_true):
324-
n_nans = series.isna().sum()
325-
assert n_nans == 0, f"{n_nans:,} NaNs in {series.name}"
323+
# for series in (e_above_hull_error, e_above_hull_true):
324+
# n_nans = series.isna().sum()
325+
# assert n_nans == 0, f"{n_nans:,} NaNs in {series.name}"
326326

327327
is_fresh_ax = len(ax.lines) == 0
328328

@@ -412,8 +412,12 @@ def precision_recall_vs_calc_count(
412412
ylabel = "Precision and Recall (%)"
413413
ax.set(ylim=(0, 100), xlabel=xlabel, ylabel=ylabel)
414414

415-
[precision] = ax.plot((0, 0), (0, 0), "black", linestyle="-")
416-
[recall] = ax.plot((0, 0), (0, 0), "black", linestyle=":")
415+
[precision] = ax.plot(
416+
(0, 0), (0, 0), "black", linestyle="-", linewidth=line_kwargs["linewidth"]
417+
)
418+
[recall] = ax.plot(
419+
(0, 0), (0, 0), "black", linestyle=":", linewidth=line_kwargs["linewidth"]
420+
)
417421
legend = ax.legend(
418422
[precision, recall],
419423
("Precision", "Recall"),

models/bowsr/slurm_array_bowsr_wbm.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020
To slurm submit this file, use
2121
2222
```sh
23-
# slurm will not create logdir automatically and fail if missing
24-
mkdir -p models/bowsr/slurm_logs
23+
log_dir=models/bowsr/$(date +"%Y-%m-%d")-bowsr-megnet-wbm
24+
job_name=bowsr-megnet-wbm-IS2RE
25+
mkdir -p $log_dir # slurm fails if log_dir is missing
26+
2527
sbatch --partition icelake-himem --account LEE-SL3-CPU --array 1-500 \
26-
--time 12:0:0 --job-name bowsr-megnet-wbm-IS2RE --mem 12000 \
27-
--output models/bowsr/slurm_logs/slurm-%A-%a.out \
28+
--time 12:0:0 --job-name $job_name --mem 12000 \
29+
--output $log_dir/slurm-%A-%a.out \
2830
--wrap "TF_CPP_MIN_LOG_LEVEL=2 python models/bowsr/slurm_array_bowsr_wbm.py"
2931
```
3032
@@ -50,7 +52,7 @@
5052
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
5153
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
5254
# set large fallback job array size for fast testing/debugging
53-
job_array_size = int(os.environ.get("SLURM_ARRAY_TASK_COUNT", 10_000))
55+
slurm_array_task_count = int(os.environ.get("SLURM_ARRAY_TASK_COUNT", 10_000))
5456

5557
print(f"Job started running {datetime.now():%Y-%m-%d@%H-%M}")
5658
print(f"{slurm_job_id = }")
@@ -60,7 +62,6 @@
6062

6163
today = f"{datetime.now():%Y-%m-%d}"
6264
out_dir = f"{module_dir}/{today}-bowsr-megnet-wbm-{task_type}"
63-
os.makedirs(out_dir, exist_ok=True)
6465
json_out_path = f"{out_dir}/{slurm_array_task_id}.json.gz"
6566

6667
if os.path.isfile(json_out_path):
@@ -81,6 +82,7 @@
8182
maml_version=version("maml"),
8283
slurm_job_id=slurm_job_id,
8384
slurm_array_task_id=slurm_array_task_id,
85+
slurm_array_task_count=slurm_array_task_count,
8486
data_path=data_path,
8587
bayes_optim_kwargs=bayes_optim_kwargs,
8688
optimize_kwargs=optimize_kwargs,
@@ -100,10 +102,10 @@
100102

101103

102104
# %%
103-
print(f"Loading from {data_path=}")
105+
print(f"Loading from {data_path = }")
104106
df_wbm = pd.read_json(data_path).set_index("material_id")
105107

106-
df_this_job = np.array_split(df_wbm, job_array_size + 1)[slurm_array_task_id]
108+
df_this_job = np.array_split(df_wbm, slurm_array_task_count)[slurm_array_task_id - 1]
107109

108110

109111
# %%

models/m3gnet/join_m3gnet_relax_results.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from tqdm import tqdm
1212

1313
from mb_discovery import ROOT, as_dict_handler
14-
from mb_discovery.compute_formation_energy import get_form_energy_per_atom
14+
from mb_discovery.energy import get_form_energy_per_atom
1515

1616
__author__ = "Janosh Riebesell"
1717
__date__ = "2022-08-16"

models/m3gnet/slurm_array_m3gnet_relax_wbm.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
To slurm submit this file, use
1919
2020
```sh
21-
# slurm will not create logdir automatically and fail if missing
22-
mkdir -p models/m3gnet/slurm_logs
21+
job_name=m3gnet-wbm-relax-IS2RE
22+
log_dir=models/m3gnet/$(date +"%Y-%m-%d")-$job_name
23+
mkdir -p $log_dir # slurm fails if log_dir is missing
24+
2325
sbatch --partition icelake-himem --account LEE-SL3-CPU --array 1-100 \
24-
--time 3:0:0 --job-name m3gnet-wbm-relax-IS2RE --mem 12000 \
25-
--output models/m3gnet/slurm_logs/slurm-%A-%a.out \
26+
--time 3:0:0 --job-name $job_name --mem 12000 \
27+
--output $log_dir/slurm-%A-%a.out \
2628
--wrap "TF_CPP_MIN_LOG_LEVEL=2 python models/m3gnet/slurm_array_m3gnet_relax_wbm.py"
2729
```
2830
@@ -43,16 +45,15 @@
4345
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
4446
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
4547
# set large fallback job array size for fast testing/debugging
46-
job_array_size = int(os.environ.get("SLURM_ARRAY_TASK_COUNT", 10_000))
48+
slurm_array_task_count = int(os.environ.get("SLURM_ARRAY_TASK_COUNT", 10_000))
4749

4850
print(f"Job started running {datetime.now():%Y-%m-%d@%H-%M}")
4951
print(f"{slurm_job_id = }")
5052
print(f"{slurm_array_task_id = }")
5153
print(f"{version('m3gnet') = }")
5254

5355
today = f"{datetime.now():%Y-%m-%d}"
54-
out_dir = f"{ROOT}/data/{today}-m3gnet-wbm-relax-{task_type}"
55-
os.makedirs(out_dir, exist_ok=True)
56+
out_dir = f"{ROOT}/data/{today}-m3gnet-wbm-{task_type}"
5657
json_out_path = f"{out_dir}/{slurm_array_task_id}.json.gz"
5758

5859
if os.path.isfile(json_out_path):
@@ -67,12 +68,13 @@
6768
print(f"Loading from {data_path=}")
6869
df_wbm = pd.read_json(data_path).set_index("material_id")
6970

70-
df_this_job = np.array_split(df_wbm, job_array_size)[slurm_array_task_id]
71+
df_this_job = np.array_split(df_wbm, slurm_array_task_count)[slurm_array_task_id - 1]
7172

7273
run_params = dict(
7374
m3gnet_version=version("m3gnet"),
7475
slurm_job_id=slurm_job_id,
7576
slurm_array_task_id=slurm_array_task_id,
77+
slurm_array_task_count=slurm_array_task_count,
7678
data_path=data_path,
7779
task_type=task_type,
7880
)

0 commit comments

Comments
 (0)