Skip to content

Commit e54a8c6

Browse files
committed
add test_precision_recall_vs_calc_count()
setup.py specify python_requires=">=3.8" raise flake8 max-complexity = 16
1 parent 39fa65a commit e54a8c6

5 files changed

+157
-58
lines changed

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
assert all(nan_counts == 0), f"df should not have missing values: {nan_counts}"
5656

5757
target_col = "e_form_target"
58-
criterion: Literal["energy", "std", "neg_std"] = "energy"
58+
stability_crit: Literal["energy", "energy+std", "energy-std"] = "energy"
5959
energy_type: Literal["true", "pred"] = "true"
6060

6161

@@ -69,14 +69,13 @@
6969
pred_cols,
7070
e_above_hull_col="e_above_mp_hull",
7171
energy_type=energy_type,
72-
criterion=criterion,
72+
stability_crit=stability_crit,
7373
)
7474

7575
ax.figure.set_size_inches(10, 9)
7676

7777
ax.legend(loc="upper left", frameon=False)
7878

79-
img_path = (
80-
f"{ROOT}/figures/{today}-wren-wbm-hull-dist-hist-{energy_type=}-{criterion=}.pdf"
81-
)
79+
fig_name = f"wren-wbm-hull-dist-hist-{energy_type=}-{stability_crit=}"
80+
img_path = f"{ROOT}/figures/{today}-{fig_name}.pdf"
8281
# plt.savefig(img_path)

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist_batches.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,8 @@
5050

5151

5252
# %%
53-
assert df.e_above_mp_hull.isna().sum() == 0
54-
5553
energy_type = "true"
56-
criterion = "energy"
54+
stability_crit = "energy"
5755
df["wbm_batch"] = df.index.str.split("-").str[2]
5856
fig, axs = plt.subplots(2, 3, figsize=(18, 9))
5957

@@ -65,7 +63,7 @@
6563
target_col="e_form_target",
6664
pred_cols=pred_cols,
6765
energy_type=energy_type,
68-
criterion=criterion,
66+
stability_crit=stability_crit,
6967
e_above_hull_col="e_above_mp_hull",
7068
)
7169

@@ -81,5 +79,5 @@
8179
axs.flat[-1].set(title=f"Combined {batch_idx} ({len(df):,})")
8280
axs.flat[0].legend(frameon=False, loc="upper left")
8381

84-
img_name = f"{today}-wren-wbm-hull-dist-hist-{energy_type=}-{criterion=}.pdf"
82+
img_name = f"{today}-wren-wbm-hull-dist-hist-{energy_type=}-{stability_crit=}.pdf"
8583
# plt.savefig(f"{ROOT}/figures/{img_name}")

mb_discovery/plot_scripts/plot_funcs.py

+63-46
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any, Literal, Sequence
3+
from typing import Any, Literal, Sequence, get_args
44

55
import matplotlib.pyplot as plt
66
import numpy as np
@@ -13,6 +13,7 @@
1313
__author__ = "Janosh Riebesell"
1414
__date__ = "2022-08-05"
1515

16+
StabilityCriterion = Literal["energy", "energy+std", "energy-std"]
1617

1718
plt.rc("savefig", bbox="tight", dpi=200)
1819
plt.rcParams["figure.constrained_layout.use"] = True
@@ -27,10 +28,10 @@ def hist_classified_stable_as_func_of_hull_dist(
2728
e_above_hull_col: str,
2829
ax: plt.Axes = None,
2930
energy_type: Literal["true", "pred"] = "true",
30-
criterion: Literal["energy", "std", "neg_std"] = "energy",
31+
stability_crit: StabilityCriterion = "energy",
3132
show_mae: bool = False,
32-
stability_thresh: float = 0, # set stability threshold as distance to convex hull
33-
# in eV / atom, usually 0 or 0.1 eV
33+
stability_threshold: float = 0, # set stability threshold as distance to convex
34+
# hull in eV / atom, usually 0 or 0.1 eV
3435
x_lim: tuple[float, float] = (-0.4, 0.4),
3536
) -> plt.Axes:
3637
"""
@@ -52,28 +53,28 @@ def hist_classified_stable_as_func_of_hull_dist(
5253

5354
error = df[pred_cols].mean(axis=1) - df[target_col]
5455
e_above_hull_vals = df[e_above_hull_col]
55-
mean = error + e_above_hull_vals
56+
residuals = error + e_above_hull_vals
5657

57-
if criterion == "energy":
58-
test = mean
59-
elif "std" in criterion:
58+
if stability_crit == "energy":
59+
test = residuals
60+
elif "std" in stability_crit:
6061
# TODO column names to compute standard deviation from are currently hardcoded
6162
# needs to be updated when adding non-aviary models with uncertainty estimation
6263
var_aleatoric = (df.filter(like="_ale_") ** 2).mean(axis=1)
6364
var_epistemic = df.filter(regex=r"_pred_\d").var(axis=1, ddof=0)
6465
std_total = (var_epistemic + var_aleatoric) ** 0.5
6566

66-
if criterion == "std":
67+
if stability_crit == "energy+std":
6768
test += std_total
68-
elif criterion == "neg_std":
69+
elif stability_crit == "energy-std":
6970
test -= std_total
7071

7172
# --- histogram by DFT-computed distance to convex hull
7273
if energy_type == "true":
73-
actual_pos = e_above_hull_vals <= stability_thresh
74-
actual_neg = e_above_hull_vals > stability_thresh
75-
model_pos = test <= stability_thresh
76-
model_neg = test > stability_thresh
74+
actual_pos = e_above_hull_vals <= stability_threshold
75+
actual_neg = e_above_hull_vals > stability_threshold
76+
model_pos = test <= stability_threshold
77+
model_neg = test > stability_threshold
7778

7879
n_true_pos = len(e_above_hull_vals[actual_pos & model_pos])
7980
n_false_neg = len(e_above_hull_vals[actual_pos & model_neg])
@@ -89,10 +90,10 @@ def hist_classified_stable_as_func_of_hull_dist(
8990

9091
# --- histogram by model-predicted distance to convex hull
9192
if energy_type == "pred":
92-
true_pos = mean[actual_pos & model_pos]
93-
false_neg = mean[actual_pos & model_neg]
94-
false_pos = mean[actual_neg & model_pos]
95-
true_neg = mean[actual_neg & model_neg]
93+
true_pos = residuals[actual_pos & model_pos]
94+
false_neg = residuals[actual_pos & model_neg]
95+
false_pos = residuals[actual_neg & model_pos]
96+
true_neg = residuals[actual_neg & model_neg]
9697
xlabel = r"$\Delta E_{Hull-Pred}$ / eV per atom"
9798

9899
ax.hist(
@@ -163,6 +164,10 @@ def rolling_mae_vs_hull_dist(
163164
if ax is None:
164165
ax = plt.gca()
165166

167+
for col in (residual_col, e_above_hull_col):
168+
n_nans = df[col].isna().sum()
169+
assert n_nans == 0, f"{n_nans} NaNs in {col} column"
170+
166171
is_fresh_ax = len(ax.lines) == 0
167172

168173
bins = np.arange(*x_lim, increment)
@@ -256,9 +261,9 @@ def precision_recall_vs_calc_count(
256261
df: pd.DataFrame,
257262
residual_col: str = "residual",
258263
e_above_hull_col: str = "e_above_hull",
259-
criterion: Literal["energy", "std", "neg_std"] = "energy",
260-
stability_thresh: float = 0, # set stability threshold as distance to convex hull
261-
# in eV / atom, usually 0 or 0.1 eV
264+
stability_crit: StabilityCriterion = "energy",
265+
stability_threshold: float = 0, # set stability threshold as distance to convex
266+
# hull in eV / atom, usually 0 or 0.1 eV
262267
ax: plt.Axes = None,
263268
label: str = None,
264269
intersect_lines: str | Sequence[str] = (),
@@ -272,10 +277,10 @@ def precision_recall_vs_calc_count(
272277
i.e. residual = pred - target. Defaults to "residual".
273278
e_above_hull_col (str, optional): Column name with convex hull distance values.
274279
Defaults to "e_above_hull".
275-
criterion (Literal['energy', 'std', 'neg_std'], optional): Whether to use
276-
energy, energy+model_std, or energy-model_std as stability criterion.
277-
Defaults to "energy".
278-
stability_thresh (float, optional): Max distance from convex hull before
280+
stability_crit ('energy' | 'energy+std' | 'energy-std', optional): Whether to
281+
use energy+/-std as stability stability_crit where std is the model
282+
predicted uncertainty for the energy it stipulated. Defaults to "energy".
283+
stability_threshold (float, optional): Max distance from convex hull before
279284
material is considered unstable. Defaults to 0.
280285
label (str, optional): Model name used to identify its liens in the legend.
281286
Defaults to None.
@@ -288,36 +293,43 @@ def precision_recall_vs_calc_count(
288293
if ax is None:
289294
ax = plt.gca()
290295

296+
for col in (residual_col, e_above_hull_col):
297+
n_nans = df[col].isna().sum()
298+
assert n_nans == 0, f"{n_nans} NaNs in {col} column"
299+
291300
is_fresh_ax = len(ax.lines) == 0
292301

293302
df = df.sort_values(by="residual")
303+
residuals = df[residual_col]
294304

295-
if criterion == "energy":
296-
test = df[residual_col]
297-
elif "std" in criterion:
305+
if stability_crit not in get_args(StabilityCriterion):
306+
raise ValueError(
307+
f"Invalid {stability_crit=} must be one of {get_args(StabilityCriterion)}"
308+
)
309+
if "std" in stability_crit:
298310
# TODO column names to compute standard deviation from are currently hardcoded
299311
# needs to be updated when adding non-aviary models with uncertainty estimation
300312
var_aleatoric = (df.filter(like="_ale_") ** 2).mean(axis=1)
301313
var_epistemic = df.filter(regex=r"_pred_\d").var(axis=1, ddof=0)
302314
std_total = (var_epistemic + var_aleatoric) ** 0.5
303315

304-
if criterion == "std":
305-
test += std_total
306-
elif criterion == "neg_std":
307-
test -= std_total
316+
if stability_crit == "energy+std":
317+
residuals += std_total
318+
elif stability_crit == "energy-std":
319+
residuals -= std_total
308320

309-
# stability_thresh = 0.02
310-
stability_thresh = 0
311-
# stability_thresh = 0.10
321+
# stability_threshold = 0.02
322+
stability_threshold = 0
323+
# stability_threshold = 0.10
312324

313-
true_pos_mask = (df[e_above_hull_col] <= stability_thresh) & (
314-
df.residual <= stability_thresh
325+
true_pos_mask = (df[e_above_hull_col] <= stability_threshold) & (
326+
df.residual <= stability_threshold
315327
)
316-
false_neg_mask = (df[e_above_hull_col] <= stability_thresh) & (
317-
df.residual > stability_thresh
328+
false_neg_mask = (df[e_above_hull_col] <= stability_threshold) & (
329+
df.residual > stability_threshold
318330
)
319-
false_pos_mask = (df[e_above_hull_col] > stability_thresh) & (
320-
df.residual <= stability_thresh
331+
false_pos_mask = (df[e_above_hull_col] > stability_threshold) & (
332+
df.residual <= stability_threshold
321333
)
322334

323335
true_pos_cumsum = true_pos_mask.cumsum()
@@ -349,11 +361,16 @@ def precision_recall_vs_calc_count(
349361

350362
if intersect_lines == "all":
351363
intersect_lines = ("precision_xy", "recall_xy")
364+
if isinstance(intersect_lines, str):
365+
intersect_lines = [intersect_lines]
352366
for line_name in intersect_lines:
353-
y_func = dict(
354-
precision=precision_curve,
355-
recall=rolling_recall_curve,
356-
)[line_name.split("_")[0]]
367+
try:
368+
line_name_map = dict(precision=precision_curve, recall=rolling_recall_curve)
369+
y_func = line_name_map[line_name.split("_")[0]]
370+
except KeyError:
371+
raise ValueError(
372+
f"Invalid {intersect_lines=}, must be one of {list(line_name_map)}"
373+
)
357374
intersect_kwargs = dict(
358375
linestyle=":", alpha=0.4, color=kwargs.get("color", "gray")
359376
)
@@ -370,7 +387,7 @@ def precision_recall_vs_calc_count(
370387

371388
ax.set(xlabel="Number of Calculations", ylabel="Percentage")
372389

373-
ax.set(xlim=(0, 8e4), ylim=(0, 100))
390+
ax.set(ylim=(0, 100))
374391

375392
[precision] = ax.plot((0, 0), (0, 0), "black", linestyle="-")
376393
[recall] = ax.plot((0, 0), (0, 0), "black", linestyle=":")

mb_discovery/plot_scripts/precision_recall_vs_calc_count.py

-2
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@
5252
):
5353
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
5454

55-
assert df.e_above_mp_hull.isna().sum() == 0
56-
5755
target_col = "e_form_target"
5856
rare = "all"
5957

tests/test_plot_funcs.py

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Sequence
4+
5+
import pandas as pd
6+
import pytest
7+
8+
from mb_discovery import ROOT
9+
from mb_discovery.plot_scripts.plot_funcs import precision_recall_vs_calc_count
10+
11+
12+
df_hull = pd.read_csv(
13+
f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
14+
).set_index("material_id")
15+
16+
test_dfs: dict[str, pd.DataFrame] = {}
17+
for model_name in ("Wren", "CGCNN", "Voronoi"):
18+
df = pd.read_csv(
19+
f"{ROOT}/data/2022-06-11-from-rhys/{model_name}-mp-initial-structures.csv",
20+
nrows=100,
21+
).set_index("material_id")
22+
23+
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
24+
25+
test_dfs[model_name] = df
26+
27+
28+
@pytest.mark.parametrize(
29+
"intersect_lines, stability_crit, stability_threshold, expected_line_count",
30+
[
31+
((), "energy", 0, 11),
32+
("precision_x", "energy+std", 0, 23),
33+
(["recall_y"], "energy", -0.1, 35),
34+
("all", "energy-std", 0.1, 56),
35+
],
36+
)
37+
def test_precision_recall_vs_calc_count(
38+
intersect_lines: str | Sequence[str],
39+
stability_crit: str,
40+
stability_threshold: float,
41+
expected_line_count: int,
42+
) -> None:
43+
ax = None
44+
45+
for (model_name, df), color in zip(
46+
test_dfs.items(), ("tab:blue", "tab:orange", "tab:pink")
47+
):
48+
model_preds = df.filter(like=r"_pred").mean(axis=1)
49+
targets = df.e_form_target
50+
51+
df["residual"] = model_preds - targets + df.e_above_mp_hull
52+
53+
ax = precision_recall_vs_calc_count(
54+
df,
55+
residual_col="residual",
56+
e_above_hull_col="e_above_mp_hull",
57+
color=color,
58+
label=model_name,
59+
intersect_lines=intersect_lines,
60+
stability_crit=stability_crit, # type: ignore[arg-type]
61+
stability_threshold=stability_threshold,
62+
ax=ax,
63+
)
64+
65+
assert ax is not None
66+
assert len(ax.lines) == expected_line_count
67+
assert ax.get_ylim() == (0, 100)
68+
assert ax.get_xlim() == pytest.approx((-1.4, 29.4))
69+
70+
71+
@pytest.mark.parametrize(
72+
"kwargs, expected_exc, match_pat",
73+
[
74+
(dict(intersect_lines="INVALID"), ValueError, "Invalid intersect_lines="),
75+
(dict(stability_crit="INVALID"), ValueError, "Invalid stability_crit="),
76+
],
77+
)
78+
def test_precision_recall_vs_calc_count_raises(
79+
kwargs: dict[str, Any], expected_exc: type[Exception], match_pat: str
80+
) -> None:
81+
with pytest.raises(expected_exc, match=match_pat):
82+
precision_recall_vs_calc_count(
83+
test_dfs["Wren"],
84+
residual_col="residual",
85+
e_above_hull_col="e_above_mp_hull",
86+
**kwargs,
87+
)

0 commit comments

Comments
 (0)