Skip to content

Commit f25d91f

Browse files
committed
rename e_above_mp_hull->e_above_hull_mp
1 parent a19ca5d commit f25d91f

9 files changed

+26
-25
lines changed

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
3737
).set_index("material_id")
3838

39-
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
39+
df["e_above_hull_mp"] = df_hull.e_above_hull_mp
4040

4141

4242
# %%
@@ -62,7 +62,7 @@
6262

6363
ax = hist_classified_stable_as_func_of_hull_dist(
6464
e_above_hull_pred=df[pred_cols].mean(axis=1) - df[target_col],
65-
e_above_hull_true=df.e_above_mp_hull,
65+
e_above_hull_true=df.e_above_hull_mp,
6666
which_energy=which_energy,
6767
stability_crit=stability_crit,
6868
std_pred=std_total,

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist_batches.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
model_name = "m3gnet"
8080
df = dfs[model_name]
8181

82-
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
82+
df["e_above_hull_mp"] = df_hull.e_above_hull_mp
8383
df["e_form_per_atom"] = df_wbm.e_form_per_atom
8484

8585

@@ -89,7 +89,7 @@
8989

9090
hist_classified_stable_as_func_of_hull_dist(
9191
e_above_hull_pred=batch_df.e_form_per_atom_pred - batch_df.e_form_per_atom,
92-
e_above_hull_true=batch_df.e_above_mp_hull,
92+
e_above_hull_true=batch_df.e_above_hull_mp,
9393
which_energy=which_energy,
9494
stability_crit=stability_crit,
9595
ax=ax,
@@ -101,7 +101,7 @@
101101

102102
hist_classified_stable_as_func_of_hull_dist(
103103
e_above_hull_pred=df.e_form_per_atom_pred - df.e_form_per_atom,
104-
e_above_hull_true=df.e_above_mp_hull,
104+
e_above_hull_true=df.e_above_hull_mp,
105105
which_energy=which_energy,
106106
stability_crit=stability_crit,
107107
ax=axs.flat[-1],

mb_discovery/plot_scripts/precision_recall_vs_calc_count.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,14 @@
8585
except AttributeError as exc:
8686
raise KeyError(f"{model_name = }") from exc
8787

88-
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
88+
df["e_above_hull_mp"] = df_hull.e_above_hull_mp
8989
df["e_form_per_atom"] = df_wbm.e_form_per_atom
9090
df["e_above_hull_pred"] = model_preds - df.e_form_per_atom
9191
if n_nans := df.isna().values.sum() > 0:
9292
assert n_nans < 10, f"{model_name=} has {n_nans=}"
9393
df = df.dropna()
9494

95-
F1 = f1_score(df.e_above_mp_hull < 0, df.e_above_hull_pred < 0)
95+
F1 = f1_score(df.e_above_hull_mp < 0, df.e_above_hull_pred < 0)
9696
F1s[model_name] = F1
9797

9898

@@ -101,8 +101,8 @@
101101
df = dfs[model_name]
102102

103103
ax = precision_recall_vs_calc_count(
104-
e_above_hull_error=df.e_above_hull_pred + df.e_above_mp_hull,
105-
e_above_hull_true=df.e_above_mp_hull,
104+
e_above_hull_error=df.e_above_hull_pred + df.e_above_hull_mp,
105+
e_above_hull_true=df.e_above_hull_mp,
106106
color=color,
107107
label=f"{model_name} {F1=:.2}",
108108
intersect_lines="recall_xy", # or "precision_xy", None, 'all'
@@ -113,7 +113,7 @@
113113
# optimal recall line finds all stable materials without any false positives
114114
# can be included to confirm all models start out of with near optimal recall
115115
# and to see how much each model overshoots total n_stable
116-
n_below_hull = sum(df_hull.e_above_mp_hull < 0)
116+
n_below_hull = sum(df_hull.e_above_hull_mp < 0)
117117
ax.plot(
118118
[0, n_below_hull],
119119
[0, 100],

mb_discovery/plot_scripts/rolling_mae_vs_hull_dist.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
4040
).set_index("material_id")
4141

42-
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
42+
df["e_above_hull_mp"] = df_hull.e_above_hull_mp
4343

4444
assert all(n_nans := df.isna().sum() == 0), f"Found {n_nans} NaNs"
4545

@@ -59,7 +59,7 @@
5959
# %%
6060
ax = rolling_mae_vs_hull_dist(
6161
e_above_hull_pred=df.e_above_hull_pred,
62-
e_above_hull_true=df.e_above_mp_hull,
62+
e_above_hull_true=df.e_above_hull_mp,
6363
label=legend_label,
6464
)
6565

mb_discovery/plot_scripts/rolling_mae_vs_hull_dist_wbm_batches.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
2929
).set_index("material_id")
3030

31-
df_wrenformer["e_above_mp_hull"] = df_hull.e_above_mp_hull
32-
assert df_wrenformer.e_above_mp_hull.isna().sum() == 0
31+
df_wrenformer["e_above_hull_mp"] = df_hull.e_above_hull_mp
32+
assert df_wrenformer.e_above_hull_mp.isna().sum() == 0
3333

3434
target_col = "e_form_per_atom"
3535
# target_col = "e_form_target"
@@ -54,7 +54,7 @@
5454

5555
rolling_mae_vs_hull_dist(
5656
e_above_hull_pred=df.e_above_hull_pred,
57-
e_above_hull_true=df.e_above_mp_hull,
57+
e_above_hull_true=df.e_above_hull_mp,
5858
ax=ax,
5959
label=title,
6060
marker=marker,

mb_discovery/plots.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
e_form="Formation energy (eV/atom)",
3232
e_above_hull="Energy above convex hull (eV/atom)",
3333
e_above_hull_pred="Predicted energy above convex hull (eV/atom)",
34-
e_above_mp_hull="Energy above MP convex hull (eV/atom)",
34+
e_above_hull_mp="Energy above MP convex hull (eV/atom)",
3535
e_above_hull_error="Error in energy above convex hull (eV/atom)",
3636
)
3737
model_labels = dict(

mb_discovery/slurm.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@ def slurm_submit_python(
2828
array: str = "",
2929
pre_cmd: str = "",
3030
) -> None:
31-
"""Slurm submit a python script using sbatch --wrap 'python path/to/file.py' by
32-
calling this function in the script and invoking the script with
33-
`python path/to/file.py slurm-submit`.
31+
"""Slurm submits a python script using `sbatch --wrap 'python path/to/file.py'`.
32+
Usage: Call this function at the top of the script (before doing any real work) and
33+
then submit a job with `python path/to/file.py slurm-submit`. The slurm job will run
34+
the whole script.
3435
3536
Args:
3637
job_name (str): Slurm job name.

models/m3gnet/join_m3gnet_relax_results.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
7979
).set_index("material_id")
8080

81-
df_m3gnet["e_above_mp_hull"] = df_hull.e_above_mp_hull
81+
df_m3gnet["e_above_hull_mp"] = df_hull.e_above_hull_mp
8282

8383

8484
# %%

tests/test_plots.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
f"{DATA_DIR}/{model_name.lower()}-mp-initial-structures.csv", nrows=100
2626
).set_index("material_id")
2727

28-
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
28+
df["e_above_hull_mp"] = df_hull.e_above_hull_mp
2929

3030
model_preds = df.filter(like=r"_pred").mean(axis=1)
3131

@@ -62,7 +62,7 @@ def test_precision_recall_vs_calc_count(
6262

6363
ax = precision_recall_vs_calc_count(
6464
e_above_hull_error=df.e_above_hull_pred,
65-
e_above_hull_true=df.e_above_mp_hull,
65+
e_above_hull_true=df.e_above_hull_mp,
6666
color=color,
6767
label=model_name,
6868
intersect_lines=intersect_lines,
@@ -96,7 +96,7 @@ def test_precision_recall_vs_calc_count_raises(
9696
with pytest.raises(expected_exc, match=match_pat):
9797
precision_recall_vs_calc_count(
9898
e_above_hull_error=test_dfs["Wren"].e_above_hull_pred,
99-
e_above_hull_true=test_dfs["Wren"].e_above_mp_hull,
99+
e_above_hull_true=test_dfs["Wren"].e_above_hull_mp,
100100
**kwargs,
101101
)
102102

@@ -114,7 +114,7 @@ def test_rolling_mae_vs_hull_dist(
114114
):
115115
ax = rolling_mae_vs_hull_dist(
116116
e_above_hull_pred=df.e_above_hull_pred,
117-
e_above_hull_true=df.e_above_mp_hull,
117+
e_above_hull_true=df.e_above_hull_mp,
118118
color=color,
119119
label=model_name,
120120
ax=ax,
@@ -150,7 +150,7 @@ def test_hist_classified_stable_as_func_of_hull_dist(
150150

151151
ax = hist_classified_stable_as_func_of_hull_dist(
152152
e_above_hull_pred=df.e_above_hull_pred,
153-
e_above_hull_true=df.e_above_mp_hull,
153+
e_above_hull_true=df.e_above_hull_mp,
154154
ax=ax,
155155
stability_threshold=stability_threshold,
156156
stability_crit=stability_crit,

0 commit comments

Comments
 (0)