Skip to content

Commit c855e51

Browse files
committed
add test_hist_classified_stable_as_func_of_hull_dist()
refactor plot funcs to use e_above_hull_pred and e_above_hull_true as main inputs; rewrite plot scripts to match new signature
1 parent 17df9d0 commit c855e51

10 files changed

+238
-193
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ pretrained/
1515

1616
# Weights and Biases logs
1717
wandb/
18+
job-logs/
1819

1920
# slurm logs
2021
slurm-*out

mb_discovery/m3gnet/join_and_plot_m3gnet_relax_results.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import gzip
55
import io
66
import pickle
7+
import urllib.request
78
from datetime import datetime
89
from glob import glob
9-
from urllib.request import urlopen
1010

1111
import pandas as pd
1212
from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram, PDEntry
@@ -15,7 +15,7 @@
1515

1616
from mb_discovery import ROOT, as_dict_handler
1717
from mb_discovery.plot_scripts.plot_funcs import (
18-
hist_classify_stable_as_func_of_hull_dist,
18+
hist_classified_stable_as_func_of_hull_dist,
1919
)
2020

2121

@@ -25,7 +25,7 @@
2525
# %%
2626
task_type = "RS3RE"
2727
date = "2022-08-19"
28-
glob_pattern = f"{date}-m3gnet-relax-wbm-{task_type}/*.json.gz"
28+
glob_pattern = f"{date}-m3gnet-wbm-relax-{task_type}/*.json.gz"
2929
file_paths = sorted(glob(f"{ROOT}/data/{glob_pattern}"))
3030
print(f"Found {len(file_paths):,} files for {glob_pattern = }")
3131

@@ -68,7 +68,7 @@
6868
# %%
6969
# 2022-01-25-ppd-mp+wbm.pkl.gz (235 MB)
7070
ppd_pickle_url = "https://figshare.com/ndownloader/files/36669624"
71-
zipped_file = urlopen(ppd_pickle_url)
71+
zipped_file = urllib.request.urlopen(ppd_pickle_url)
7272

7373
ppd_mp_wbm: PatchedPhaseDiagram = pickle.load(
7474
io.BytesIO(gzip.decompress(zipped_file.read()))
@@ -114,15 +114,21 @@
114114

115115

116116
# %%
117-
out_path = f"{ROOT}/data/{today}-m3gnet-relax-wbm-{task_type}.json.gz"
117+
out_path = f"{ROOT}/data/{today}-m3gnet-wbm-relax-{task_type}.json.gz"
118118
df_m3gnet.reset_index().to_json(out_path, default_handler=as_dict_handler)
119119

120+
out_path = f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
121+
df_m3gnet = pd.read_json(out_path)
122+
120123

121124
# %%
122-
ax_hull_dist_hist = hist_classify_stable_as_func_of_hull_dist(
123-
formation_energy_targets=df_m3gnet.e_form_ppd_2022_01_25,
124-
formation_energy_preds=df_m3gnet.e_form_m3gnet_from_ppd,
125-
e_above_hull_vals=df_m3gnet.e_above_mp_hull,
125+
df_m3gnet["e_above_hull_pred"] = ( # TODO fix this incorrect e_above_hull_pred
126+
df_m3gnet["e_form_m3gnet_from_ppd"] - df_m3gnet["e_above_mp_hull"]
127+
)
128+
129+
ax_hull_dist_hist = hist_classified_stable_as_func_of_hull_dist(
130+
e_above_hull_pred=df_m3gnet.e_above_hull_pred,
131+
e_above_hull_true=df_m3gnet.e_above_mp_hull,
126132
)
127133

128134
# ax_hull_dist_hist.figure.savefig(f"{ROOT}/plots/{today}-m3gnet-wbm-hull-dist-hist.pdf")

mb_discovery/m3gnet/slurm_array_m3gnet_relax_wbm.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
import m3gnet
1010
import numpy as np
1111
import pandas as pd
12+
import wandb
1213
from m3gnet.models import Relaxer
1314

14-
import wandb
1515
from mb_discovery import ROOT, as_dict_handler
1616

1717

@@ -20,7 +20,7 @@
2020
2121
```sh
2222
sbatch --partition icelake-himem --account LEE-SL3-CPU --array 1-101 \
23-
--time 3:0:0 --job-name m3gnet-relax-wbm-RS2RE --mem 12000 \
23+
--time 3:0:0 --job-name m3gnet-wbm-relax-RS2RE --mem 12000 \
2424
--output mb_discovery/m3gnet/slurm_logs/slurm-%A-%a.out \
2525
--wrap "python mb_discovery/m3gnet/slurm_array_m3gnet_relax_wbm.py"
2626
```
@@ -48,7 +48,7 @@
4848
print(f"{job_array_id=}")
4949

5050
today = f"{datetime.now():%Y-%m-%d}"
51-
out_dir = f"{ROOT}/data/{today}-m3gnet-relax-wbm-{task_type}"
51+
out_dir = f"{ROOT}/data/{today}-m3gnet-wbm-relax-{task_type}"
5252
os.makedirs(out_dir, exist_ok=True)
5353
json_out_path = f"{out_dir}/{job_array_id}.json.gz"
5454

@@ -77,7 +77,7 @@
7777
wandb.login()
7878
wandb.init(
7979
project="m3gnet", # run will be added to this project
80-
name=f"m3gnet-relax-wbm-{task_type}-{job_id}-{job_array_id}",
80+
name=f"m3gnet-wbm-relax-{task_type}-{job_id}-{job_array_id}",
8181
config=run_params,
8282
)
8383

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
# %%
22
from datetime import datetime
3-
from typing import Literal
43

54
import matplotlib.pyplot as plt
65
import pandas as pd
76

87
from mb_discovery import ROOT
98
from mb_discovery.plot_scripts.plot_funcs import (
9+
StabilityCriterion,
10+
WhichEnergy,
1011
hist_classified_stable_as_func_of_hull_dist,
1112
)
1213

@@ -55,27 +56,34 @@
5556
assert all(nan_counts == 0), f"df should not have missing values: {nan_counts}"
5657

5758
target_col = "e_form_target"
58-
stability_crit: Literal["energy", "energy+std", "energy-std"] = "energy"
59-
energy_type: Literal["true", "pred"] = "true"
60-
59+
stability_crit: StabilityCriterion = "energy"
60+
which_energy: WhichEnergy = "true"
61+
62+
if "std" in stability_crit:
63+
# TODO column names to compute standard deviation from are currently hardcoded
64+
# needs to be updated when adding non-aviary models with uncertainty estimation
65+
var_aleatoric = (df.filter(like="_ale_") ** 2).mean(axis=1)
66+
var_epistemic = df.filter(regex=r"_pred_\d").var(axis=1, ddof=0)
67+
std_total = (var_epistemic + var_aleatoric) ** 0.5
68+
else:
69+
std_total = None
6170

6271
# make sure we average the expected number of ensemble member predictions
6372
pred_cols = df.filter(regex=r"_pred_\d").columns
6473
assert len(pred_cols) == 10
6574

6675
ax = hist_classified_stable_as_func_of_hull_dist(
67-
df,
68-
target_col,
69-
pred_cols,
70-
e_above_hull_col="e_above_mp_hull",
71-
energy_type=energy_type,
76+
e_above_hull_pred=df[pred_cols].mean(axis=1) - df[target_col],
77+
e_above_hull_true=df.e_above_mp_hull,
78+
which_energy=which_energy,
7279
stability_crit=stability_crit,
80+
std_pred=std_total,
7381
)
7482

7583
ax.figure.set_size_inches(10, 9)
7684

7785
ax.legend(loc="upper left", frameon=False)
7886

79-
fig_name = f"wren-wbm-hull-dist-hist-{energy_type=}-{stability_crit=}"
87+
fig_name = f"wren-wbm-hull-dist-hist-{which_energy=}-{stability_crit=}"
8088
img_path = f"{ROOT}/figures/{today}-{fig_name}.pdf"
8189
# plt.savefig(img_path)

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist_batches.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from mb_discovery import ROOT
88
from mb_discovery.plot_scripts.plot_funcs import (
9+
StabilityCriterion,
10+
WhichEnergy,
911
hist_classified_stable_as_func_of_hull_dist,
1012
)
1113

@@ -50,34 +52,39 @@
5052

5153

5254
# %%
53-
energy_type = "true"
54-
stability_crit = "energy"
55+
which_energy: WhichEnergy = "true"
56+
stability_crit: StabilityCriterion = "energy"
5557
df["wbm_batch"] = df.index.str.split("-").str[2]
5658
fig, axs = plt.subplots(2, 3, figsize=(18, 9))
5759

5860
# make sure we average the expected number of ensemble member predictions
5961
pred_cols = df.filter(regex=r"_pred_\d").columns
6062
assert len(pred_cols) == 10
6163

62-
common_kwargs = dict(
63-
target_col="e_form_target",
64-
pred_cols=pred_cols,
65-
energy_type=energy_type,
66-
stability_crit=stability_crit,
67-
e_above_hull_col="e_above_mp_hull",
68-
)
6964

7065
for (batch_idx, batch_df), ax in zip(df.groupby("wbm_batch"), axs.flat):
71-
hist_classified_stable_as_func_of_hull_dist(batch_df, ax=ax, **common_kwargs)
66+
hist_classified_stable_as_func_of_hull_dist(
67+
e_above_hull_pred=batch_df[pred_cols].mean(axis=1) - batch_df.e_form_target,
68+
e_above_hull_true=batch_df.e_above_mp_hull,
69+
which_energy=which_energy,
70+
stability_crit=stability_crit,
71+
ax=ax,
72+
)
7273

7374
title = f"Batch {batch_idx} ({len(df):,})"
7475
ax.set(title=title)
7576

7677

77-
hist_classified_stable_as_func_of_hull_dist(df, ax=axs.flat[-1], **common_kwargs)
78+
hist_classified_stable_as_func_of_hull_dist(
79+
e_above_hull_pred=df[pred_cols].mean(axis=1),
80+
e_above_hull_true=df.e_above_mp_hull,
81+
which_energy=which_energy,
82+
stability_crit=stability_crit,
83+
ax=axs.flat[-1],
84+
)
7885

7986
axs.flat[-1].set(title=f"Combined {batch_idx} ({len(df):,})")
8087
axs.flat[0].legend(frameon=False, loc="upper left")
8188

82-
img_name = f"{today}-wren-wbm-hull-dist-hist-{energy_type=}-{stability_crit=}.pdf"
89+
img_name = f"{today}-wren-wbm-hull-dist-hist-{which_energy=}-{stability_crit=}.pdf"
8390
# plt.savefig(f"{ROOT}/figures/{img_name}")

0 commit comments

Comments
 (0)