Skip to content

Commit 17df9d0

Browse files
committed
add test_rolling_mae_vs_hull_dist() to test_plot_funcs.py
1 parent c369d18 commit 17df9d0

File tree

2 files changed

+52
-11
lines changed

2 files changed

+52
-11
lines changed

mb_discovery/plot_scripts/plot_funcs.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,15 @@ def hist_classified_stable_as_func_of_hull_dist(
9090
false_neg = e_above_hull_vals[actual_pos & model_neg]
9191
false_pos = e_above_hull_vals[actual_neg & model_pos]
9292
true_neg = e_above_hull_vals[actual_neg & model_neg]
93-
xlabel = r"$\Delta E_{Hull-MP}$ / eV per atom"
93+
xlabel = r"$\Delta E_{Hull-MP}$ (eV / atom)"
9494

9595
# --- histogram by model-predicted distance to convex hull
9696
if energy_type == "pred":
9797
true_pos = residuals[actual_pos & model_pos]
9898
false_neg = residuals[actual_pos & model_neg]
9999
false_pos = residuals[actual_neg & model_pos]
100100
true_neg = residuals[actual_neg & model_neg]
101-
xlabel = r"$\Delta E_{Hull-Pred}$ / eV per atom"
101+
xlabel = r"$\Delta E_{Hull-Pred}$ (eV / atom)"
102102

103103
ax.hist(
104104
[true_pos, false_neg, false_pos, true_neg],
@@ -153,7 +153,7 @@ def rolling_mae_vs_hull_dist(
153153
e_above_hull_col: str,
154154
residual_col: str = "residual",
155155
half_window: float = 0.02,
156-
increment: float = 0.002,
156+
bin_width: float = 0.002,
157157
x_lim: tuple[float, float] = (-0.2, 0.3),
158158
ax: plt.Axes = None,
159159
**kwargs: Any,
@@ -174,7 +174,7 @@ def rolling_mae_vs_hull_dist(
174174

175175
is_fresh_ax = len(ax.lines) == 0
176176

177-
bins = np.arange(*x_lim, increment)
177+
bins = np.arange(*x_lim, bin_width)
178178

179179
rolling_maes = np.zeros_like(bins)
180180
rolling_stds = np.zeros_like(bins)
@@ -254,7 +254,7 @@ def rolling_mae_vs_hull_dist(
254254

255255
ax.text(0, 0.13, r"$|\Delta E_{Hull-MP}| > $MAE", horizontalalignment="center")
256256

257-
ax.set(xlabel=r"$\Delta E_{Hull-MP}$ / eV per atom", ylabel="MAE / eV per atom")
257+
ax.set(xlabel=r"$\Delta E_{Hull-MP}$ (eV / atom)", ylabel="MAE (eV / atom)")
258258

259259
ax.set(xlim=x_lim, ylim=(0.0, 0.14))
260260

@@ -389,7 +389,7 @@ def precision_recall_vs_calc_count(
389389
# previous call
390390
return ax
391391

392-
ax.set(xlabel="Number of Calculations", ylabel="Percentage")
392+
ax.set(xlabel="Number of Calculations", ylabel="Precision and Recall (%)")
393393

394394
ax.set(ylim=(0, 100))
395395

tests/test_plot_funcs.py

+46-5
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22

33
from typing import Any, Sequence
44

5+
import matplotlib.pyplot as plt
56
import pandas as pd
67
import pytest
78

89
from mb_discovery import ROOT
9-
from mb_discovery.plot_scripts.plot_funcs import precision_recall_vs_calc_count
10+
from mb_discovery.plot_scripts.plot_funcs import (
11+
precision_recall_vs_calc_count,
12+
rolling_mae_vs_hull_dist,
13+
)
1014

1115

1216
DATA_DIR = f"{ROOT}/data/2022-06-11-from-rhys"
@@ -28,9 +32,9 @@
2832
"intersect_lines, stability_crit, stability_threshold, expected_line_count",
2933
[
3034
((), "energy", 0, 11),
31-
("precision_x", "energy+std", 0, 23),
32-
(["recall_y"], "energy", -0.1, 35),
33-
("all", "energy-std", 0.1, 56),
35+
("precision_x", "energy+std", 0, 14),
36+
(["recall_y"], "energy", -0.1, 14),
37+
("all", "energy-std", 0.1, 23),
3438
],
3539
)
3640
def test_precision_recall_vs_calc_count(
@@ -39,7 +43,7 @@ def test_precision_recall_vs_calc_count(
3943
stability_threshold: float,
4044
expected_line_count: int,
4145
) -> None:
42-
ax = None
46+
ax = plt.figure().gca() # ensure test functions use different axes
4347

4448
for (model_name, df), color in zip(
4549
test_dfs.items(), ("tab:blue", "tab:orange", "tab:pink")
@@ -66,6 +70,9 @@ def test_precision_recall_vs_calc_count(
6670
assert ax.get_ylim() == (0, 100)
6771
assert ax.get_xlim() == pytest.approx((-1.4, 29.4))
6872

73+
assert ax.get_xlabel() == "Number of Calculations"
74+
assert ax.get_ylabel() == "Precision and Recall (%)"
75+
6976

7077
@pytest.mark.parametrize(
7178
"kwargs, expected_exc, match_pat",
@@ -84,3 +91,37 @@ def test_precision_recall_vs_calc_count_raises(
8491
e_above_hull_col="e_above_mp_hull",
8592
**kwargs,
8693
)
94+
95+
96+
@pytest.mark.parametrize("half_window", (0.02, 0.002))
97+
@pytest.mark.parametrize("bin_width", (0.1, 0.001))
98+
@pytest.mark.parametrize("x_lim", ((0, 0.6), (-0.2, 0.8)))
99+
def test_rolling_mae_vs_hull_dist(
100+
half_window: float, bin_width: float, x_lim: tuple[float, float]
101+
) -> None:
102+
ax = plt.figure().gca() # ensure test functions use different axes
103+
104+
for (model_name, df), color in zip(
105+
test_dfs.items(), ("tab:blue", "tab:orange", "tab:pink")
106+
):
107+
model_preds = df.filter(like=r"_pred").mean(axis=1)
108+
targets = df.e_form_target
109+
110+
df["residual"] = model_preds - targets + df.e_above_mp_hull
111+
112+
ax = rolling_mae_vs_hull_dist(
113+
df,
114+
residual_col="residual",
115+
e_above_hull_col="e_above_mp_hull",
116+
color=color,
117+
label=model_name,
118+
ax=ax,
119+
x_lim=x_lim,
120+
half_window=half_window,
121+
bin_width=bin_width,
122+
)
123+
124+
assert ax is not None
125+
assert ax.get_ylim() == pytest.approx((0, 0.14))
126+
assert ax.get_ylabel() == "MAE (eV / atom)"
127+
assert ax.get_xlabel() == r"$\Delta E_{Hull-MP}$ (eV / atom)"

0 commit comments

Comments
 (0)