Skip to content

Commit 4972c01

Browse files
committed
refactor cumulative_clf_metric() -> cumulative_precision_recall() to plot both precision and recall for all models
support both plotly and matplotlib as backend
1 parent 0fb7550 commit 4972c01

File tree

6 files changed

+210
-170
lines changed

6 files changed

+210
-170
lines changed

matbench_discovery/energy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def classify_stable(
138138
negative values close to 0 make sense.
139139
140140
Returns:
141-
tuple[pd.Series, pd.Series, pd.Series, pd.Series]: Indices for true positives,
141+
tuple[TP, FN, FP, TN]: Indices as pd.Series for true positives,
142142
false negatives, false positives and true negatives (in this order).
143143
"""
144144
actual_pos = e_above_hull_true <= stability_threshold

matbench_discovery/plots.py

+139-103
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import math
34
from typing import Any, Literal
45

56
import matplotlib.pyplot as plt
@@ -80,11 +81,11 @@ def hist_classified_stable_vs_hull_dist(
8081
ax: plt.Axes = None,
8182
which_energy: WhichEnergy = "true",
8283
stability_threshold: float = 0,
83-
show_threshold: bool = True,
8484
x_lim: tuple[float | None, float | None] = (-0.4, 0.4),
8585
rolling_accuracy: float | None = 0.02,
8686
backend: Backend = "plotly",
8787
ylabel: str = "Number of materials",
88+
**kwargs: Any,
8889
) -> tuple[plt.Axes | go.Figure, dict[str, float]]:
8990
"""
9091
Histogram of the energy difference (either according to DFT ground truth [default]
@@ -108,13 +109,13 @@ def hist_classified_stable_vs_hull_dist(
108109
distance or the model's predicted hull distance for the histogram.
109110
stability_threshold (float, optional): set stability threshold as distance to
110111
convex hull in eV/atom, usually 0 or 0.1 eV.
111-
show_threshold (bool, optional): Whether to plot stability threshold as dashed
112-
vertical line.
113112
x_lim (tuple[float | None, float | None]): x-axis limits.
114113
rolling_accuracy (float): Rolling accuracy window size in eV / atom. Set to None
115114
or 0 to disable. Defaults to 0.02, meaning 20 meV / atom.
116115
backend ('matplotlib' | 'plotly', optional): Which plotting backend to use.
117116
Changes the return type.
117+
kwargs: Additional keyword arguments passed to ax.hist() or px.histogram(),
118+
depending on backend.
118119
119120
Returns:
120121
tuple[plt.Axes, dict[str, float]]: plot axes and classification metrics
@@ -159,15 +160,17 @@ def hist_classified_stable_vs_hull_dist(
159160
color=["tab:green", "tab:orange", "tab:red", "tab:blue"],
160161
label=labels,
161162
stacked=True,
163+
**kwargs,
162164
)
163165
ax.set(xlabel=xlabel, ylabel=ylabel, xlim=x_lim)
164166

165-
ax.axvline(
166-
stability_threshold,
167-
color="black",
168-
linestyle="--",
169-
label="Stability Threshold",
170-
)
167+
if stability_threshold is not None:
168+
ax.axvline(
169+
stability_threshold,
170+
color="black",
171+
linestyle="--",
172+
label="Stability Threshold",
173+
)
171174

172175
if rolling_accuracy:
173176
# add moving average of the accuracy computed within given window
@@ -203,28 +206,35 @@ def hist_classified_stable_vs_hull_dist(
203206
# )
204207

205208
if backend == "plotly":
206-
clf = (true_pos * 1 + false_neg * 2 + false_pos * 3 + true_neg * 4).map(
209+
clf = (true_pos + false_neg * 2 + false_pos * 3 + true_neg * 4).map(
207210
dict(zip(range(1, 5), labels))
208211
)
209212
df = pd.DataFrame(dict(e_above_hull=e_above_hull, clf=clf))
210213

211214
ax = px.histogram(
212-
df, x="e_above_hull", color="clf", nbins=20000, range_x=x_lim, opacity=0.9
215+
df,
216+
x="e_above_hull",
217+
color="clf",
218+
nbins=20000,
219+
range_x=x_lim,
220+
opacity=0.9,
221+
**kwargs,
213222
)
214223
ax.update_layout(
215224
dict(xaxis_title=xlabel, yaxis_title=ylabel),
216225
legend=dict(title=None, yanchor="top", y=1, xanchor="right", x=1),
217226
)
218227

219-
ax.add_vline(stability_threshold, line=dict(dash="dash", width=1))
220-
ax.add_annotation(
221-
text="Stability threshold",
222-
x=stability_threshold,
223-
y=1.1,
224-
yref="paper",
225-
font=dict(size=14, color="gray"),
226-
showarrow=False,
227-
)
228+
if stability_threshold is not None:
229+
ax.add_vline(stability_threshold, line=dict(dash="dash", width=1))
230+
ax.add_annotation(
231+
text="Stability threshold",
232+
x=stability_threshold,
233+
y=1.1,
234+
yref="paper",
235+
font=dict(size=14, color="gray"),
236+
showarrow=False,
237+
)
228238

229239
recall = n_true_pos / n_total_pos
230240
return ax, dict(
@@ -341,115 +351,141 @@ def rolling_mae_vs_hull_dist(
341351
return ax
342352

343353

344-
def cumulative_clf_metric(
354+
def cumulative_precision_recall(
345355
e_above_hull_true: pd.Series,
346-
e_above_hull_pred: pd.Series,
347-
metric: Literal["precision", "recall"],
356+
df_preds: pd.DataFrame,
348357
stability_threshold: float = 0, # set stability threshold as distance to convex
349358
# hull in eV / atom, usually 0 or 0.1 eV
350-
ax: plt.Axes = None,
351-
label: str = None,
352359
project_end_point: AxLine = "xy",
353360
show_optimal: bool = False,
361+
backend: Backend = "plotly",
354362
**kwargs: Any,
355-
) -> plt.Axes:
356-
"""Precision and recall as a function of the number of included materials sorted
357-
by model-predicted distance to the convex hull, i.e. materials predicted most stable
358-
enter the precision and recall calculation first. The curves end when all materials
359-
predicted stable are included.
363+
) -> tuple[plt.Figure | go.Figure, pd.DataFrame]:
364+
"""Create 2 subplots side-by-side with cumulative precision and recall curves for
365+
all models starting with materials predicted most stable, adding the next material,
366+
recomputing the cumulative metrics, adding the next most stable material and so on
367+
until each model no longer predicts the material to be stable. Again, materials
368+
predicted more stable enter the precision and recall calculation sooner. Different
369+
models predict different numbers of materials to be stable. Hence the curves end at
370+
different points.
360371
361372
Args:
362373
e_above_hull_true (pd.Series): Distance to convex hull according to DFT
363374
ground truth (in eV / atom).
364-
e_above_hull_pred (pd.Series): Distance to convex hull predicted by model
365-
(in eV / atom). Same as true energy to convex hull plus predicted minus true
366-
formation energy.
367-
metric ('precision' | 'recall', optional): Metric to plot.
368-
stability_threshold (float, optional): Max distance from convex hull before
375+
df_preds (pd.DataFrame): Distance to convex hull predicted by models, one column
376+
per model (in eV / atom). Same as true energy to convex hull plus predicted
377+
minus true formation energy.
378+
stability_threshold (float, optional): Max distance above convex hull before
369379
material is considered unstable. Defaults to 0.
370-
label (str, optional): Model name used to identify its liens in the legend.
371-
Defaults to None.
372-
project_end_point ('x' | 'y' | 'xy' | '', optional): Defaults to '', i.e. no
380+
project_end_point ('x' | 'y' | 'xy' | '', optional): Whether to project end
381+
points of precision and recall curves to the x/y axis. Defaults to 'xy'. Use '' for no
373382
axis projection lines.
374-
show_optimal (bool, optional): Whether to plot the optimal precision/recall
375-
line. Defaults to False.
376-
**kwargs: Keyword arguments passed to ax.plot().
383+
show_optimal (bool, optional): Whether to plot the optimal recall line. Defaults
384+
to False.
385+
backend ('plotly' | 'matplotlib', optional): Defaults to 'plotly'.
386+
**kwargs: Keyword arguments passed to df.plot().
377387
378388
Returns:
379-
plt.Axes: The matplotlib axes object.
389+
tuple[plt.Figure | go.Figure, pd.DataFrame]: The matplotlib/plotly figure and
390+
dataframe of cumulative metrics for each model.
380391
"""
381-
ax = ax or plt.gca()
392+
fact = lambda: pd.DataFrame(index=range(len(e_above_hull_true)))
393+
dfs = dict(Precision=fact(), Recall=fact())
382394

383-
e_above_hull_pred = e_above_hull_pred.sort_values()
384-
e_above_hull_true = e_above_hull_true.loc[e_above_hull_pred.index]
395+
for model_name in df_preds:
396+
model_preds = df_preds[model_name].sort_values()
397+
e_above_hull_true = e_above_hull_true.loc[model_preds.index]
385398

386-
true_pos, false_neg, false_pos, _true_neg = classify_stable(
387-
e_above_hull_true, e_above_hull_pred, stability_threshold
388-
)
399+
true_pos, false_neg, false_pos, _true_neg = classify_stable(
400+
e_above_hull_true, model_preds, stability_threshold
401+
)
389402

390-
true_pos_cumsum = true_pos.cumsum()
403+
true_pos_cumsum = true_pos.cumsum()
404+
# precision aka positive predictive value (PPV)
405+
precision = true_pos_cumsum / (true_pos_cumsum + false_pos.cumsum())
406+
n_total_pos = sum(true_pos) + sum(false_neg)
407+
recall = true_pos_cumsum / n_total_pos # aka true_pos_rate aka sensitivity
391408

392-
# precision aka positive predictive value (PPV)
393-
precision = true_pos_cumsum / (true_pos_cumsum + false_pos.cumsum()) * 100
394-
n_true_pos = sum(true_pos)
395-
n_false_neg = sum(false_neg)
396-
n_total_pos = n_true_pos + n_false_neg
397-
true_pos_rate = true_pos_cumsum / n_total_pos * 100
409+
end = int(np.argmax(recall))
410+
xs = np.arange(end)
398411

399-
end = int(np.argmax(true_pos_rate))
400-
xs = np.arange(end)
412+
prec_interp = scipy.interpolate.interp1d(xs, precision[:end], kind="cubic")
413+
recall_interp = scipy.interpolate.interp1d(xs, recall[:end], kind="cubic")
414+
dfs["Precision"][model_name] = pd.Series(prec_interp(xs))
415+
dfs["Recall"][model_name] = pd.Series(recall_interp(xs))
401416

402-
ys_raw = dict(precision=precision, recall=true_pos_rate)[metric]
403-
y_interp = scipy.interpolate.interp1d(xs, ys_raw[:end], kind="cubic")
404-
ys = y_interp(xs)
417+
for key, df in dfs.items():
418+
# drop all-NaN rows so plotly plot x-axis only extends to largest number of
419+
# predicted materials by any model
420+
df.dropna(how="all", inplace=True)
421+
df["metric"] = key
405422

406-
line_kwargs = dict(
407-
linewidth=2, markevery=[-1], marker="x", markersize=14, markeredgewidth=2.5
408-
)
409-
ax.plot(xs, ys, **line_kwargs | kwargs)
410-
ax.text(
411-
xs[-1],
412-
ys[-1],
413-
label,
414-
color=kwargs.get("color"),
415-
verticalalignment="bottom",
416-
rotation=30,
417-
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
418-
)
423+
df_cum = pd.concat(dfs.values())
419424

420-
# add some visual guidelines
421-
intersect_kwargs = dict(linestyle=":", alpha=0.4, color=kwargs.get("color"))
422-
if "x" in project_end_point:
423-
ax.plot((0, xs[-1]), (ys[-1], ys[-1]), **intersect_kwargs)
424-
if "y" in project_end_point:
425-
ax.plot((xs[-1], xs[-1]), (0, ys[-1]), **intersect_kwargs)
426-
427-
ax.set(ylim=(0, 100), ylabel=f"{metric.title()} (%)")
428-
429-
# optimal recall line finds all stable materials without any false positives
430-
# can be included to confirm all models start out of with near optimal recall
431-
# and to see how much each model overshoots total n_stable
432-
n_below_hull = sum(e_above_hull_true < 0)
433-
if show_optimal:
434-
ax.plot(
435-
[0, n_below_hull],
436-
[0, 100],
437-
color="green",
438-
linestyle="dashed",
439-
linewidth=1,
440-
label=f"Optimal {metric.title()}",
425+
if backend == "matplotlib":
426+
fig, axs = plt.subplots(1, 2, figsize=(15, 7), sharey=True)
427+
line_kwargs = dict(
428+
linewidth=3, markevery=[-1], marker="x", markersize=14, markeredgewidth=2.5
441429
)
442-
ax.text(
443-
n_below_hull,
444-
100,
445-
label,
446-
color=kwargs.get("color"),
447-
verticalalignment="top",
448-
rotation=-30,
449-
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
430+
for (key, df), ax in zip(dfs.items(), axs):
431+
# select every n-th row of df so that 1000 rows are left for increased
432+
# plotting speed and reduced file size
433+
# falls back on every row if df has fewer than 1000 rows
434+
435+
df.iloc[:: len(df) // 1000 or 1].plot(
436+
ax=ax, legend=False, backend=backend, **line_kwargs | kwargs, ylabel=key
437+
)
438+
439+
# add some visual guidelines
440+
intersect_kwargs = dict(linestyle=":", alpha=0.4, linewidth=2)
441+
bbox = dict(facecolor="white", alpha=0.5, edgecolor="none")
442+
assert len(axs) == len(dfs), f"{len(axs)} != {len(dfs)}"
443+
444+
for ax, df in zip(axs, dfs.values()):
445+
ax.set(ylim=(0, 1), xlim=(0, None), ylabel=key)
446+
for model in df_preds:
447+
x_end = df[model].dropna().index[-1]
448+
y_end = df[model].dropna().iloc[-1]
449+
# place model name at the end of every line
450+
ax.text(x_end, y_end, model, va="bottom", rotation=30, bbox=bbox)
451+
if "x" in project_end_point:
452+
ax.plot((x_end, x_end), (0, y_end), **intersect_kwargs)
453+
if "y" in project_end_point:
454+
ax.plot((0, x_end), (y_end, y_end), **intersect_kwargs)
455+
456+
# optimal recall line finds all stable materials without any false positives
457+
# can be included to confirm all models start out with near optimal recall
458+
# and to see how much each model overshoots total n_stable
459+
n_below_hull = sum(e_above_hull_true < 0)
460+
if show_optimal:
461+
opt_label = "Optimal Recall"
462+
axs[1].plot([0, n_below_hull], [0, 1], color="green", linestyle="--")
463+
axs[1].text(
464+
*[n_below_hull, 0.81],
465+
opt_label,
466+
color="green",
467+
va="bottom",
468+
ha="right",
469+
rotation=math.degrees(math.cos(math.atan(1 / n_below_hull))),
470+
bbox=bbox,
471+
)
472+
473+
elif backend == "plotly":
474+
fig = df_cum.iloc[:: len(df_cum) // 1000 or 1].plot(
475+
backend=backend, facet_col="metric", **kwargs
450476
)
477+
fig.update_traces(line=dict(width=4))
478+
for idx in range(1, 3):
479+
fig.update_xaxes(
480+
title_text="Number of materials predicted stable", row=1, col=idx
481+
)
482+
fig.update_yaxes(title="Precision", col=1)
483+
fig.update_yaxes(title="Recall", col=2)
484+
fig.for_each_annotation(lambda a: a.update(text=""))
485+
fig.update_layout(legend=dict(title=""))
486+
fig.update_layout(showlegend=False)
451487

452-
return ax
488+
return fig, df_cum
453489

454490

455491
def wandb_scatter(table: wandb.Table, fields: dict[str, str], **kwargs: Any) -> None:

scripts/hist_classified_stable_vs_hull_dist.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@
6060

6161

6262
# %%
63-
fig_name = f"{ROOT}/figures/{today}-wren-wbm-hull-dist-hist-{which_energy=}.pdf"
63+
img_path = f"{ROOT}/figures/{today}-wren-wbm-hull-dist-hist-{which_energy=}.pdf"
6464
if hasattr(ax, "write_image"):
65-
# fig.write_image(fig_name)
66-
ax.write_html(fig_name.replace(".pdf", ".html"))
65+
# fig.write_image(img_path)
66+
ax.write_html(img_path.replace(".pdf", ".html"))
6767
else:
68-
ax.figure.savefig(fig_name)
68+
ax.figure.savefig(img_path)

scripts/hist_classified_stable_vs_hull_dist_models.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737

3838
backend: Backend = "matplotlib"
3939
if backend == "matplotlib":
40-
fig, axs = plt.subplots(3, 3, figsize=(18, 12))
40+
fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(18, 12))
4141
else:
4242
fig = make_subplots(rows=3, cols=3)
4343

@@ -75,9 +75,13 @@
7575
frameon=False,
7676
)
7777

78-
fig
78+
fig.show()
7979

8080

8181
# %%
82-
img_path = f"{ROOT}/figures/{today}-wbm-hull-dist-hist-models.pdf"
83-
ax.figure.savefig(img_path)
82+
img_path = f"{ROOT}/figures/{today}-wbm-hull-dist-hist-models"
83+
# if hasattr(fig, "write_image"):
84+
# fig.write_image(f"{img_path}.pdf")
85+
# fig.write_html(f"{img_path}.html", include_plotlyjs="cdn")
86+
# else:
87+
# fig.savefig(f"{img_path}.pdf")

0 commit comments

Comments
 (0)