Skip to content

Commit 86f85f3

Browse files
committed
rename plot func precision_recall_vs_calc_count() to cumulative_clf_metric() that plots single metric at a time
1 parent 4936079 commit 86f85f3

File tree

3 files changed

+110
-124
lines changed

3 files changed

+110
-124
lines changed

matbench_discovery/plot_scripts/precision_recall_vs_calc_count.py matbench_discovery/plot_scripts/precision_recall.py

+33-34
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
# %%
22
from datetime import datetime
33

4+
import matplotlib.pyplot as plt
45
import pandas as pd
56
from sklearn.metrics import f1_score
67

78
from matbench_discovery import ROOT
89
from matbench_discovery.plot_scripts import df_wbm
9-
from matbench_discovery.plots import StabilityCriterion, precision_recall_vs_calc_count
10+
from matbench_discovery.plots import StabilityCriterion, cumulative_clf_metric
1011

1112
__author__ = "Rhys Goodall, Janosh Riebesell"
1213

@@ -46,13 +47,6 @@
4647
F1s: dict[str, float] = {}
4748

4849
for model_name, df in dfs.items():
49-
# from pymatgen.core import Composition
50-
# rare = "no-lanthanides"
51-
# df["contains_rare_earths"] = df.composition.map(
52-
# lambda x: any(el.is_rare_earth_metal for el in Composition(x))
53-
# )
54-
# df = df.query("~contains_rare_earths")
55-
5650
if "std" in stability_crit:
5751
# TODO column names to compute standard deviation from are currently hardcoded
5852
# needs to be updated when adding non-aviary models with uncertainty estimation
@@ -91,42 +85,47 @@
9185

9286

9387
# %%
88+
fig, (ax_prec, ax_recall) = plt.subplots(1, 2, figsize=(15, 7), sharey=True)
89+
9490
for (model_name, F1), color in zip(sorted(F1s.items(), key=lambda x: x[1]), colors):
9591
df = dfs[model_name]
92+
e_above_hull_error = df.e_above_hull_pred + df.e_above_hull_mp
93+
e_above_hull_true = df.e_above_hull_mp
94+
cumulative_clf_metric(
95+
e_above_hull_error,
96+
e_above_hull_true,
97+
color=color,
98+
label=f"{model_name}\n{F1=:.2}",
99+
project_end_point="xy",
100+
stability_crit=stability_crit,
101+
ax=ax_prec,
102+
metric="precision",
103+
)
96104

97-
ax = precision_recall_vs_calc_count(
98-
e_above_hull_error=df.e_above_hull_pred + df.e_above_hull_mp,
99-
e_above_hull_true=df.e_above_hull_mp,
105+
cumulative_clf_metric(
106+
e_above_hull_error,
107+
e_above_hull_true,
100108
color=color,
101-
label=f"{model_name} {F1=:.2}",
102-
intersect_lines="recall_xy", # or "precision_xy", None, 'all'
109+
label=f"{model_name}\n{F1=:.2}",
110+
project_end_point="xy",
103111
stability_crit=stability_crit,
104-
std_pred=std_total,
112+
ax=ax_recall,
113+
metric="recall",
105114
)
106115

107-
# optimal recall line finds all stable materials without any false positives
108-
# can be included to confirm all models start out of with near optimal recall
109-
# and to see how much each model overshoots total n_stable
110-
n_below_hull = sum(df_wbm.e_above_hull_mp2020_corrected_ppd_mp < 0)
111-
ax.plot(
112-
[0, n_below_hull],
113-
[0, 100],
114-
color="green",
115-
linestyle="dashed",
116-
linewidth=1,
117-
label="Optimal Recall",
118-
)
119-
120-
ax.figure.set_size_inches(10, 9)
121-
ax.set(xlim=(0, None))
122-
# keep this outside loop so all model names appear in legend
123-
ax.legend(frameon=False, loc="lower right")
116+
117+
for ax in (ax_prec, ax_recall):
118+
ax.set(xlim=(0, None))
119+
124120

125121
img_name = f"{today}-precision-recall-vs-calc-count-{rare=}"
126-
ax.set(title=img_name.replace("-", "/", 2).replace("-", " ").title())
127122
# x-ticks every 10k materials
128-
ax.set(xticks=range(0, int(ax.get_xlim()[1]), 10_000))
123+
# ax.set(xticks=range(0, int(ax.get_xlim()[1]), 10_000))
124+
125+
fig.suptitle(f"{today} ")
126+
xlabel_cumulative = "Materials predicted stable sorted by hull distance"
127+
fig.text(0.5, -0.08, xlabel_cumulative, ha="center")
129128

130129

131130
# %%
132-
ax.figure.savefig(f"{ROOT}/figures/{img_name}.pdf")
131+
fig.savefig(f"{ROOT}/figures/{img_name}.pdf")

matbench_discovery/plots.py

+59-71
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from collections.abc import Sequence
43
from typing import Any, Literal, get_args
54

65
import matplotlib.pyplot as plt
@@ -17,6 +16,7 @@
1716

1817
StabilityCriterion = Literal["energy", "energy+std", "energy-std"]
1918
WhichEnergy = Literal["true", "pred"]
19+
AxLine = Literal["x", "y", "xy", ""]
2020

2121

2222
# --- define global plot settings
@@ -53,6 +53,7 @@
5353

5454

5555
plt.rc("font", size=14)
56+
plt.rc("legend", fontsize=16)
5657
plt.rc("savefig", bbox="tight", dpi=200)
5758
plt.rc("figure", dpi=200, titlesize=16)
5859
plt.rcParams["figure.constrained_layout.use"] = True
@@ -282,16 +283,18 @@ def rolling_mae_vs_hull_dist(
282283
return ax
283284

284285

285-
def precision_recall_vs_calc_count(
286+
def cumulative_clf_metric(
286287
e_above_hull_error: pd.Series,
287288
e_above_hull_true: pd.Series,
289+
metric: Literal["precision", "recall"],
288290
std_pred: pd.Series = None,
289291
stability_crit: StabilityCriterion = "energy",
290292
stability_threshold: float = 0, # set stability threshold as distance to convex
291293
# hull in eV / atom, usually 0 or 0.1 eV
292294
ax: plt.Axes = None,
293295
label: str = None,
294-
intersect_lines: str | Sequence[str] = (),
296+
project_end_point: AxLine = "xy",
297+
show_optimal: bool = False,
295298
**kwargs: Any,
296299
) -> plt.Axes:
297300
"""Precision and recall as a function of the number of included materials sorted
@@ -305,26 +308,27 @@ def precision_recall_vs_calc_count(
305308
predictions, i.e. residual = pred - target. Defaults to "residual".
306309
e_above_hull_true (str, optional): Column name with convex hull distance values.
307310
Defaults to "e_above_hull".
311+
metric ('precision' | 'recall', optional): Metric to plot.
308312
stability_crit ('energy' | 'energy+std' | 'energy-std', optional): Whether to
309313
use energy+/-std as stability stability_crit where std is the model
310314
predicted uncertainty for the energy it stipulated. Defaults to "energy".
311315
stability_threshold (float, optional): Max distance from convex hull before
312316
material is considered unstable. Defaults to 0.
313317
label (str, optional): Model name used to identify its liens in the legend.
314318
Defaults to None.
315-
intersect_lines (Sequence[str], optional): precision_{x,y,xy} and/or
316-
recall_{x,y,xy}. Defaults to (), i.e. no intersect lines.
319+
project_end_point ('x' | 'y' | 'xy' | '', optional): Defaults to '', i.e. no
320+
axis projection lines.
321+
show_optimal (bool, optional): Whether to plot the optimal precision/recall
322+
line. Defaults to False.
317323
318324
Returns:
319325
plt.Axes: The matplotlib axes object.
320326
"""
321327
ax = ax or plt.gca()
322328

323-
# for series in (e_above_hull_error, e_above_hull_true):
324-
# n_nans = series.isna().sum()
325-
# assert n_nans == 0, f"{n_nans:,} NaNs in {series.name}"
326-
327-
is_fresh_ax = len(ax.lines) == 0
329+
for series in (e_above_hull_error, e_above_hull_true):
330+
n_nans = series.isna().sum()
331+
assert n_nans == 0, f"{n_nans:,} NaNs in {series.name}"
328332

329333
e_above_hull_error = e_above_hull_error.sort_values()
330334
e_above_hull_true = e_above_hull_true.loc[e_above_hull_error.index]
@@ -338,10 +342,6 @@ def precision_recall_vs_calc_count(
338342
elif stability_crit == "energy-std":
339343
e_above_hull_error -= std_pred
340344

341-
# stability_threshold = 0.02
342-
stability_threshold = 0
343-
# stability_threshold = 0.10
344-
345345
true_pos_mask = (e_above_hull_true <= stability_threshold) & (
346346
e_above_hull_error <= stability_threshold
347347
)
@@ -362,68 +362,56 @@ def precision_recall_vs_calc_count(
362362
true_pos_rate = true_pos_cumsum / n_total_pos * 100
363363

364364
end = int(np.argmax(true_pos_rate))
365-
366365
xs = np.arange(end)
367366

368-
precision_curve = scipy.interpolate.interp1d(xs, precision[:end], kind="cubic")
369-
rolling_recall_curve = scipy.interpolate.interp1d(
370-
xs, true_pos_rate[:end], kind="cubic"
371-
)
367+
ys_raw = dict(precision=precision, recall=true_pos_rate)[metric]
368+
y_interp = scipy.interpolate.interp1d(xs, ys_raw[:end], kind="cubic")
369+
ys = y_interp(xs)
372370

373371
line_kwargs = dict(
374-
linewidth=4,
375-
markevery=[-1],
376-
marker="x",
377-
markersize=14,
378-
markeredgewidth=2.5,
379-
**kwargs,
380-
)
381-
ax.plot(xs, precision_curve(xs), linestyle="-", **line_kwargs)
382-
ax.plot(xs, rolling_recall_curve(xs), linestyle=":", **line_kwargs)
383-
ax.plot((0, 0), (0, 0), label=label, **line_kwargs)
384-
385-
if intersect_lines == "all":
386-
intersect_lines = ("precision_xy", "recall_xy")
387-
if isinstance(intersect_lines, str):
388-
intersect_lines = [intersect_lines]
389-
for line_name in intersect_lines:
390-
try:
391-
line_name_map = dict(precision=precision_curve, recall=rolling_recall_curve)
392-
y_func = line_name_map[line_name.split("_")[0]]
393-
except KeyError:
394-
raise ValueError(
395-
f"Invalid {intersect_lines=}, must be one of {list(line_name_map)}"
396-
)
397-
intersect_kwargs = dict(
398-
linestyle=":", alpha=0.4, color=kwargs.get("color", "gray")
399-
)
400-
# Add some visual guidelines
401-
if "x" in line_name:
402-
ax.plot((0, xs[-1]), (y_func(xs[-1]), y_func(xs[-1])), **intersect_kwargs)
403-
if "y" in line_name:
404-
ax.plot((xs[-1], xs[-1]), (0, y_func(xs[-1])), **intersect_kwargs)
405-
406-
if not is_fresh_ax:
407-
# return earlier if all plot objects besides the line were already drawn by a
408-
# previous call
409-
return ax
410-
411-
xlabel = "Number of compounds sorted by model-predicted hull distance"
412-
ylabel = "Precision and Recall (%)"
413-
ax.set(ylim=(0, 100), xlabel=xlabel, ylabel=ylabel)
414-
415-
[precision] = ax.plot(
416-
(0, 0), (0, 0), "black", linestyle="-", linewidth=line_kwargs["linewidth"]
417-
)
418-
[recall] = ax.plot(
419-
(0, 0), (0, 0), "black", linestyle=":", linewidth=line_kwargs["linewidth"]
372+
linewidth=2, markevery=[-1], marker="x", markersize=14, markeredgewidth=2.5
420373
)
421-
legend = ax.legend(
422-
[precision, recall],
423-
("Precision", "Recall"),
424-
frameon=False,
425-
loc="upper right",
374+
ax.plot(xs, ys, **line_kwargs | kwargs)
375+
ax.text(
376+
xs[-1],
377+
ys[-1],
378+
label,
379+
color=kwargs.get("color"),
380+
verticalalignment="bottom",
381+
rotation=30,
382+
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
426383
)
427-
ax.add_artist(legend)
384+
385+
# add some visual guidelines
386+
intersect_kwargs = dict(linestyle=":", alpha=0.4, color=kwargs.get("color"))
387+
if "x" in project_end_point:
388+
ax.plot((0, xs[-1]), (ys[-1], ys[-1]), **intersect_kwargs)
389+
if "y" in project_end_point:
390+
ax.plot((xs[-1], xs[-1]), (0, ys[-1]), **intersect_kwargs)
391+
392+
ax.set(ylim=(0, 100), ylabel=f"{metric.title()} (%)")
393+
394+
# optimal recall line finds all stable materials without any false positives
395+
# can be included to confirm all models start out of with near optimal recall
396+
# and to see how much each model overshoots total n_stable
397+
n_below_hull = sum(e_above_hull_true < 0)
398+
if show_optimal:
399+
ax.plot(
400+
[0, n_below_hull],
401+
[0, 100],
402+
color="green",
403+
linestyle="dashed",
404+
linewidth=1,
405+
label=f"Optimal {metric.title()}",
406+
)
407+
ax.text(
408+
n_below_hull,
409+
100,
410+
label,
411+
color=kwargs.get("color"),
412+
verticalalignment="top",
413+
rotation=-30,
414+
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
415+
)
428416

429417
return ax

0 commit comments

Comments
 (0)