Skip to content

Commit 0fb7550

Browse files
committed
add backend=plotly | matplotlib to hist_classified_stable_vs_hull_dist()
1 parent 6dd4398 commit 0fb7550

5 files changed

+160
-102
lines changed

matbench_discovery/plots.py

+99-78
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import pandas as pd
88
import plotly.express as px
9+
import plotly.graph_objs as go
910
import plotly.io as pio
1011
import scipy.interpolate
1112
import scipy.stats
@@ -19,7 +20,7 @@
1920

2021
WhichEnergy = Literal["true", "pred"]
2122
AxLine = Literal["x", "y", "xy", ""]
22-
23+
Backend = Literal["matplotlib", "plotly"]
2324

2425
# --- start global plot settings
2526
quantity_labels = dict(
@@ -53,8 +54,11 @@
5354
dft="DFT",
5455
)
5556
px.defaults.labels = quantity_labels | model_labels
56-
57-
pio.templates.default = "plotly_white"
57+
pastel_layout = dict(
58+
colorway=px.colors.qualitative.Pastel, margin=dict(l=40, r=30, t=60, b=30)
59+
)
60+
pio.templates["pastel"] = dict(layout=pastel_layout)
61+
pio.templates.default = "plotly_white+pastel"
5862

5963
# https://github.com/plotly/Kaleido/issues/122#issuecomment-994906924
6064
# when seeing MathJax "loading" message in exported PDFs, try:
@@ -79,7 +83,9 @@ def hist_classified_stable_vs_hull_dist(
7983
show_threshold: bool = True,
8084
x_lim: tuple[float | None, float | None] = (-0.4, 0.4),
8185
rolling_accuracy: float | None = 0.02,
82-
) -> tuple[plt.Axes, dict[str, float]]:
86+
backend: Backend = "plotly",
87+
ylabel: str = "Number of materials",
88+
) -> tuple[plt.Axes | go.Figure, dict[str, float]]:
8389
"""
8490
Histogram of the energy difference (either according to DFT ground truth [default]
8591
or model predicted energy) to the convex hull for materials in the WBM data set. The
@@ -106,16 +112,16 @@ def hist_classified_stable_vs_hull_dist(
106112
vertical line.
107113
x_lim (tuple[float | None, float | None]): x-axis limits.
108114
rolling_accuracy (float): Rolling accuracy window size in eV / atom. Set to None
109-
or 0 to disable. Defaults to 0.01.
115+
or 0 to disable. Defaults to 0.02, meaning 20 meV / atom.
116+
backend ('matplotlib' | 'plotly'], optional): Which plotting backend to use.
117+
Changes the return type.
110118
111119
Returns:
112120
tuple[plt.Axes, dict[str, float]]: plot axes and classification metrics
113121
114122
NOTE this figure plots hist bars separately which causes aliasing in pdf. Can be
115123
fixed in Inkscape or similar by merging regions by color.
116124
"""
117-
ax = ax or plt.gca()
118-
119125
true_pos, false_neg, false_pos, true_neg = classify_stable(
120126
e_above_hull_true, e_above_hull_pred, stability_threshold
121127
)
@@ -131,90 +137,105 @@ def hist_classified_stable_vs_hull_dist(
131137
eah_false_neg = e_above_hull[false_neg]
132138
eah_false_pos = e_above_hull[false_pos]
133139
eah_true_neg = e_above_hull[true_neg]
134-
xlabel = dict(
135-
true="$E_\\mathrm{above\\ hull}$ (eV / atom)",
136-
pred="$E_\\mathrm{above\\ hull\\ pred}$ (eV / atom)",
137-
)[which_energy]
138-
139-
ax.hist(
140-
[eah_true_pos, eah_false_neg, eah_false_pos, eah_true_neg],
141-
bins=200,
142-
range=x_lim,
143-
alpha=0.5,
144-
color=["tab:green", "tab:orange", "tab:red", "tab:blue"],
145-
label=[
146-
"True Positives",
147-
"False Negatives",
148-
"False Positives",
149-
"True Negatives",
150-
],
151-
stacked=True,
152-
)
153-
154140
n_true_pos, n_false_pos, n_true_neg, n_false_neg = map(
155-
len, (eah_true_pos, eah_false_pos, eah_true_neg, eah_false_neg)
141+
sum, (true_pos, false_pos, true_neg, false_neg)
156142
)
157143
# null = (tp + fn) / (tp + tn + fp + fn)
158144
precision = n_true_pos / (n_true_pos + n_false_pos)
159145

160-
# assert (n_all := n_true_pos + n_false_pos + n_true_neg + n_false_neg) == len(
161-
# e_above_hull_true
162-
# ), f"{n_all} != {len(e_above_hull_true)}"
163-
164-
ax.set(xlabel=xlabel, ylabel="Number of compounds", xlim=x_lim)
165-
166-
if rolling_accuracy:
167-
# add moving average of the accuracy (computed within 20 meV/atom intervals) as
168-
# a function of ΔHd,MP is shown as a blue line (right axis)
169-
ax_acc = ax.twinx()
170-
ax_acc.set_ylabel("Accuracy", color="darkblue")
171-
ax_acc.tick_params(labelcolor="darkblue")
172-
ax_acc.set(ylim=(0, 1))
173-
174-
# --- moving average of the accuracy
175-
# compute accuracy within 20 meV/atom intervals
176-
bins = np.arange(x_lim[0], x_lim[1], rolling_accuracy)
177-
bin_counts = np.histogram(e_above_hull_true, bins)[0]
178-
bin_true_pos = np.histogram(eah_true_pos, bins)[0]
179-
bin_true_neg = np.histogram(eah_true_neg, bins)[0]
180-
181-
# compute accuracy
182-
bin_accuracies = (bin_true_pos + bin_true_neg) / bin_counts
183-
# plot accuracy
184-
ax_acc.plot(
185-
bins[:-1],
186-
bin_accuracies,
187-
color="tab:blue",
188-
label="Accuracy",
189-
linewidth=3,
146+
xlabel = dict(
147+
true=r"$E_\mathrm{above\ hull}\;\mathrm{(eV / atom)}$",
148+
pred=r"$E_\mathrm{above\ hull\ pred}\;\mathrm{(eV / atom)}$",
149+
)[which_energy]
150+
labels = ["True Positives", "False Negatives", "False Positives", "True Negatives"]
151+
152+
if backend == "matplotlib":
153+
ax = ax or plt.gca()
154+
ax.hist(
155+
[eah_true_pos, eah_false_neg, eah_false_pos, eah_true_neg],
156+
bins=200,
157+
range=x_lim,
158+
alpha=0.5,
159+
color=["tab:green", "tab:orange", "tab:red", "tab:blue"],
160+
label=labels,
161+
stacked=True,
190162
)
191-
# ax2.fill_between(
192-
# bin_centers,
193-
# bin_accuracy - bin_accuracy_std,
194-
# bin_accuracy + bin_accuracy_std,
195-
# color="tab:blue",
196-
# alpha=0.2,
197-
# )
198-
199-
if show_threshold:
163+
ax.set(xlabel=xlabel, ylabel=ylabel, xlim=x_lim)
164+
200165
ax.axvline(
201166
stability_threshold,
202-
color="k",
167+
color="black",
203168
linestyle="--",
204169
label="Stability Threshold",
205170
)
206171

207-
recall = n_true_pos / n_total_pos
172+
if rolling_accuracy:
173+
# add moving average of the accuracy computed within given window
174+
# as a function of e_above_hull shown as blue line (right axis)
175+
ax_acc = ax.twinx()
176+
ax_acc.set_ylabel("Accuracy", color="darkblue")
177+
ax_acc.tick_params(labelcolor="darkblue")
178+
ax_acc.set(ylim=(0, 1))
179+
180+
# --- moving average of the accuracy
181+
# compute accuracy within 20 meV/atom intervals
182+
bins = np.arange(x_lim[0], x_lim[1], rolling_accuracy)
183+
bin_counts = np.histogram(e_above_hull_true, bins)[0]
184+
bin_true_pos = np.histogram(eah_true_pos, bins)[0]
185+
bin_true_neg = np.histogram(eah_true_neg, bins)[0]
186+
187+
# compute accuracy
188+
bin_accuracies = (bin_true_pos + bin_true_neg) / bin_counts
189+
# plot accuracy
190+
ax_acc.plot(
191+
bins[:-1],
192+
bin_accuracies,
193+
color="tab:blue",
194+
label="Accuracy",
195+
linewidth=3,
196+
)
197+
# ax2.fill_between(
198+
# bin_centers,
199+
# bin_accuracy - bin_accuracy_std,
200+
# bin_accuracy + bin_accuracy_std,
201+
# color="tab:blue",
202+
# alpha=0.2,
203+
# )
204+
205+
if backend == "plotly":
206+
clf = (true_pos * 1 + false_neg * 2 + false_pos * 3 + true_neg * 4).map(
207+
dict(zip(range(1, 5), labels))
208+
)
209+
df = pd.DataFrame(dict(e_above_hull=e_above_hull, clf=clf))
208210

209-
return ax, {
210-
"enrichment": precision / null,
211-
"precision": precision,
212-
"recall": recall,
213-
"prevalence": null,
214-
"accuracy": (n_true_pos + n_true_neg)
211+
ax = px.histogram(
212+
df, x="e_above_hull", color="clf", nbins=20000, range_x=x_lim, opacity=0.9
213+
)
214+
ax.update_layout(
215+
dict(xaxis_title=xlabel, yaxis_title=ylabel),
216+
legend=dict(title=None, yanchor="top", y=1, xanchor="right", x=1),
217+
)
218+
219+
ax.add_vline(stability_threshold, line=dict(dash="dash", width=1))
220+
ax.add_annotation(
221+
text="Stability threshold",
222+
x=stability_threshold,
223+
y=1.1,
224+
yref="paper",
225+
font=dict(size=14, color="gray"),
226+
showarrow=False,
227+
)
228+
229+
recall = n_true_pos / n_total_pos
230+
return ax, dict(
231+
enrichment=precision / null,
232+
precision=precision,
233+
recall=recall,
234+
prevalence=null,
235+
accuracy=(n_true_pos + n_true_neg)
215236
/ (n_true_pos + n_true_neg + n_false_pos + n_false_neg),
216-
"f1": 2 * (precision * recall) / (precision + recall),
217-
}
237+
f1=2 * (precision * recall) / (precision + recall),
238+
)
218239

219240

220241
def rolling_mae_vs_hull_dist(
@@ -432,7 +453,7 @@ def cumulative_clf_metric(
432453

433454

434455
def wandb_scatter(table: wandb.Table, fields: dict[str, str], **kwargs: Any) -> None:
435-
"""Log a parity scatter plot using custom vega spec to WandB.
456+
"""Log a parity scatter plot using custom Vega spec to WandB.
436457
437458
Args:
438459
table (wandb.Table): WandB data table.

scripts/hist_classified_stable_vs_hull_dist.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# %%
2-
from matbench_discovery import today
2+
from matbench_discovery import ROOT, today
33
from matbench_discovery.load_preds import load_df_wbm_with_preds
44
from matbench_discovery.plots import WhichEnergy, hist_classified_stable_vs_hull_dist
55

@@ -50,15 +50,19 @@
5050
which_energy=which_energy,
5151
# stability_threshold=-0.05,
5252
rolling_accuracy=0,
53+
# backend="matplotlib",
5354
)
55+
if hasattr(ax, "legend"):
56+
legend_title = f"Enrichment Factor = {metrics['enrichment']:.3}"
57+
ax.legend(loc="upper left", frameon=False, title=legend_title)
5458

55-
fig = ax.figure
56-
fig.set_size_inches(10, 9)
57-
58-
legend_title = f"Enrichment Factor = {metrics['enrichment']:.3}"
59-
ax.legend(loc="center left", frameon=False, title=legend_title)
59+
ax
6060

6161

6262
# %%
63-
fig_name = f"{today}-wren-wbm-hull-dist-hist-{which_energy=}"
64-
# fig.savefig(f"{ROOT}/figures/{fig_name}.pdf")
63+
fig_name = f"{ROOT}/figures/{today}-wren-wbm-hull-dist-hist-{which_energy=}.pdf"
64+
if hasattr(ax, "write_image"):
65+
# fig.write_image(fig_name)
66+
ax.write_html(fig_name.replace(".pdf", ".html"))
67+
else:
68+
ax.figure.savefig(fig_name)

scripts/hist_classified_stable_vs_hull_dist_batches.py

+1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
+ (df_wbm[model_name] - df_wbm[target_col]),
5757
which_energy=which_energy,
5858
ax=axs.flat[-1],
59+
backend="matplotlib",
5960
)
6061

6162
text = f"Enrichment\nFactor = {metrics['enrichment']:.3}"

scripts/hist_classified_stable_vs_hull_dist_models.py

+34-12
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# %%
2+
from plotly.subplots import make_subplots
3+
24
from matbench_discovery import ROOT, today
35
from matbench_discovery.load_preds import load_df_wbm_with_preds
46
from matbench_discovery.plots import (
7+
Backend,
58
WhichEnergy,
69
hist_classified_stable_vs_hull_dist,
710
plt,
@@ -30,30 +33,49 @@
3033

3134
# %%
3235
which_energy: WhichEnergy = "true"
33-
fig, axs = plt.subplots(3, 3, figsize=(18, 12))
34-
3536
model_name = "Wrenformer"
3637

37-
for model_name, ax in zip(models, axs.flat, strict=True):
38+
backend: Backend = "matplotlib"
39+
if backend == "matplotlib":
40+
fig, axs = plt.subplots(3, 3, figsize=(18, 12))
41+
else:
42+
fig = make_subplots(rows=3, cols=3)
43+
3844

45+
for idx, model_name in enumerate(models):
3946
ax, metrics = hist_classified_stable_vs_hull_dist(
4047
e_above_hull_true=df_wbm[e_above_hull_col],
4148
e_above_hull_pred=df_wbm[e_above_hull_col]
4249
+ (df_wbm[model_name] - df_wbm[target_col]),
4350
which_energy=which_energy,
44-
ax=ax,
51+
ax=axs.flat[idx],
52+
backend=backend,
4553
)
46-
47-
text = f"Enrichment\nFactor = {metrics['enrichment']:.3}"
48-
ax.text(0.02, 0.25, text, fontsize=16, transform=ax.transAxes)
49-
5054
title = f"{model_name} ({len(df_wbm[model_name].dropna()):,})"
51-
ax.set(title=title)
52-
55+
text = f"Enrichment\nFactor = {metrics['enrichment']:.3}"
5356

54-
# axs.flat[0].legend(frameon=False, loc="upper left")
57+
if backend == "matplotlib":
58+
ax.text(0.02, 0.25, text, fontsize=16, transform=ax.transAxes)
59+
ax.set(title=title)
60+
61+
else:
62+
ax.add_annotation(text=text, x=0.5, y=0.5, showarrow=False)
63+
ax.update_xaxes(title_text=title)
64+
65+
for trace in ax.data:
66+
fig.append_trace(trace, row=idx % 3 + 1, col=idx // 3 + 1)
67+
68+
if backend == "matplotlib":
69+
fig.suptitle(f"{today} {which_energy=}", y=1.07, fontsize=16)
70+
plt.figlegend(
71+
*ax.get_legend_handles_labels(),
72+
ncol=10,
73+
loc="lower center",
74+
bbox_to_anchor=(0.5, -0.05),
75+
frameon=False,
76+
)
5577

56-
fig.suptitle(f"{today} {which_energy=}", y=1.07, fontsize=16)
78+
fig
5779

5880

5981
# %%

0 commit comments

Comments
 (0)