Skip to content

Commit 7bd2038

Browse files
committed
use different line styles for models in cumulative-precision-recall.svelte plot
update site/src/figs/hist-clf-true-hull-dist-models-5x2.svelte rm site/src/figs/hist-clf-true-hull-dist-models-2x4.svelte rm site/src/figs/hist-clf-true-hull-dist-models-3x3.svelte rm site/src/figs/hist-clf-true-hull-dist-models-4x2.svelte update <HistClfTrueHullDistModels /> fig caption delete <HistClfTrueHullDistModels /> from si/+page.md
1 parent e7f9fe8 commit 7bd2038

18 files changed

+128
-79
lines changed

citation.cff

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ authors:
4343
affiliations:
4444
- Cavendish Laboratory, University of Cambridge, UK
4545
- Lawrence Berkeley National Laboratory, Berkeley, USA
46+
- German Federal Institute of Materials Research and Testing (BAM)
4647
license: MIT
4748
license-url: https://github.com/janosh/matbench-discovery/blob/-/license
4849
repository-code: https://github.com/janosh/matbench-discovery

matbench_discovery/plots.py

+50-35
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import functools
56
import math
67
import os
78
import subprocess
@@ -21,6 +22,8 @@
2122
import wandb
2223
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
2324
from pandas.io.formats.style import Styler
25+
from plotly.validators.scatter.line import DashValidator
26+
from plotly.validators.scatter.marker import SymbolValidator
2427
from tqdm import tqdm
2528

2629
from matbench_discovery import STABILITY_THRESHOLD
@@ -31,15 +34,20 @@
3134

3235
Backend = Literal["matplotlib", "plotly"]
3336

37+
plotly_markers = SymbolValidator().values[2::3] # noqa: PD011
38+
plotly_line_styles = DashValidator().values[:-1] # noqa: PD011
39+
# repeat line styles as many times as needed to match number of markers
40+
plotly_line_styles *= len(plotly_markers) // len(plotly_line_styles)
3441

35-
def unit(text: str) -> str:
42+
43+
def plotly_unit(text: str) -> str:
3644
"""Wrap text in a span with decreased font size and weight to display units in
3745
plotly labels.
3846
"""
3947
return f"<span style='font-size: 0.8em; font-weight: lighter;'>({text})</span>"
4048

4149

42-
ev_per_atom = unit("eV/atom")
50+
ev_per_atom = plotly_unit("eV/atom")
4351

4452
# --- start global plot settings
4553
quantity_labels = dict(
@@ -51,11 +59,11 @@ def unit(text: str) -> str:
5159
n_sites="Lattice site count",
5260
energy_per_atom=f"Energy {ev_per_atom}",
5361
e_form=f"DFT E<sub>form</sub> {ev_per_atom}",
54-
e_above_hull=f"E<sub>above hull</sub> {ev_per_atom}",
55-
e_above_hull_mp2020_corrected_ppd_mp=f"DFT E<sub>above hull</sub> {ev_per_atom}",
56-
e_above_hull_pred=f"Predicted E<sub>above hull</sub> {ev_per_atom}",
62+
e_above_hull=f"E<sub>hull dist</sub> {ev_per_atom}",
63+
e_above_hull_mp2020_corrected_ppd_mp=f"DFT E<sub>hull dist</sub> {ev_per_atom}",
64+
e_above_hull_pred=f"Predicted E<sub>hull dist</sub> {ev_per_atom}",
5765
e_above_hull_mp=f"E<sub>above MP hull</sub> {ev_per_atom}",
58-
e_above_hull_error=f"Error in E<sub>above hull</sub> {ev_per_atom}",
66+
e_above_hull_error=f"Error in E<sub>hull dist</sub> {ev_per_atom}",
5967
vol_diff="Volume difference (A^3)",
6068
e_form_per_atom_mp2020_corrected=f"DFT E<sub>form</sub> {ev_per_atom}",
6169
e_form_per_atom_pred=f"Predicted E<sub>form</sub> {ev_per_atom}",
@@ -547,7 +555,7 @@ def rolling_mae_vs_hull_dist(
547555
scatter_kwds = dict(
548556
fill="toself", opacity=0.2, hoverinfo="skip", showlegend=False
549557
)
550-
triangle_anno = "MAE > |E<sub>above hull</sub>|"
558+
triangle_anno = "MAE > |E<sub>hull dist</sub>|"
551559
fig.add_scatter(
552560
x=(-1, -dft_acc, dft_acc, 1) if show_dft_acc else (-1, 0, 1),
553561
y=(1, dft_acc, dft_acc, 1) if show_dft_acc else (1, 0, 1),
@@ -632,6 +640,7 @@ def cumulative_metrics(
632640
optimal_recall: str | None = "Optimal Recall",
633641
show_n_stable: bool = True,
634642
backend: Backend = "plotly",
643+
n_points: int = 50,
635644
**kwargs: Any,
636645
) -> tuple[plt.Figure | go.Figure, pd.DataFrame]:
637646
"""Create 2 subplots side-by-side with cumulative precision and recall curves for
@@ -661,18 +670,24 @@ def cumulative_metrics(
661670
number of stable materials. Defaults to True.
662671
backend ('matplotlib' | 'plotly', optional): Which plotting engine to use.
663672
Changes the return type. Defaults to 'plotly'.
673+
n_points (int, optional): Number of points to use for interpolation of the
674+
metric curves. Defaults to 50.
664675
**kwargs: Keyword arguments passed to df.plot().
665676
666677
Returns:
667678
tuple[plt.Figure | go.Figure, pd.DataFrame]: The matplotlib/plotly figure and
668679
dataframe of cumulative metrics for each model.
669680
"""
670-
factory = lambda: pd.DataFrame(index=range(len(e_above_hull_true)))
671-
dfs: dict[str, pd.DataFrame] = defaultdict(factory)
672-
metrics_no_case = [*map(str.casefold, metrics)]
681+
dfs: dict[str, pd.DataFrame] = defaultdict(pd.DataFrame)
682+
683+
# largest number of materials predicted stable by any model, determines x-axis range
684+
n_max_pred_stable = (df_preds < stability_threshold).sum().max()
685+
longest_xs = np.linspace(0, n_max_pred_stable - 1, n_points)
686+
for metric in metrics:
687+
dfs[metric].index = longest_xs
673688

674-
valid_metrics = {"precision", "recall", "f1", "mae", "rmse"}
675-
if invalid_metrics := set(metrics_no_case) - valid_metrics:
689+
valid_metrics = {"Precision", "Recall", "F1", "MAE", "RMSE"}
690+
if invalid_metrics := set(metrics) - valid_metrics:
676691
raise ValueError(
677692
f"{invalid_metrics=}, should be case-insensitive subset of {valid_metrics=}"
678693
)
@@ -691,35 +706,36 @@ def cumulative_metrics(
691706
precision_cum = true_pos_cum / (true_pos_cum + false_pos_cum)
692707
recall_cum = true_pos_cum / n_total_pos # aka true_pos_rate aka sensitivity
693708

694-
end = int(np.argmax(recall_cum))
695-
xs = np.arange(end)
709+
n_pred_stable = sum(each_pred <= stability_threshold)
710+
model_range = np.arange(n_pred_stable) # xs for interpolation
711+
xs_model = longest_xs[longest_xs < n_pred_stable - 1] # xs for plotting
696712

697-
if "precision" in metrics_no_case:
698-
prec_interp = scipy.interpolate.interp1d(
699-
xs, precision_cum[:end], kind="cubic"
700-
)
701-
dfs["Precision"][model_name] = pd.Series(prec_interp(xs))
702-
if "recall" in metrics_no_case:
703-
recall_interp = scipy.interpolate.interp1d(
704-
xs, recall_cum[:end], kind="cubic"
705-
)
706-
dfs["Recall"][model_name] = pd.Series(recall_interp(xs))
707-
if "f1" in metrics_no_case:
713+
cubic_interpolate = functools.partial(scipy.interpolate.interp1d, kind="cubic")
714+
715+
if "Precision" in metrics:
716+
prec_interp = cubic_interpolate(model_range, precision_cum[:n_pred_stable])
717+
dfs["Precision"][model_name] = dict(zip(xs_model, prec_interp(xs_model)))
718+
719+
if "Recall" in metrics:
720+
recall_interp = cubic_interpolate(model_range, recall_cum[:n_pred_stable])
721+
dfs["Recall"][model_name] = dict(zip(xs_model, recall_interp(xs_model)))
722+
723+
if "F1" in metrics:
708724
f1_cum = 2 * (precision_cum * recall_cum) / (precision_cum + recall_cum)
709-
f1_interp = scipy.interpolate.interp1d(xs, f1_cum[:end], kind="cubic")
710-
dfs["F1"][model_name] = pd.Series(f1_interp(xs))
725+
f1_interp = cubic_interpolate(model_range, f1_cum[:n_pred_stable])
726+
dfs["F1"][model_name] = dict(zip(xs_model, f1_interp(xs_model)))
711727

712-
if "mae" in metrics_no_case:
728+
if "MAE" in metrics:
713729
cum_errors = (each_true - each_pred).abs().cumsum()
714730
cum_counts = np.arange(1, len(each_true) + 1)
715731
mae_cum = cum_errors / cum_counts
716-
mae_interp = scipy.interpolate.interp1d(xs, mae_cum[:end], kind="cubic")
717-
dfs["MAE"][model_name] = pd.Series(mae_interp(xs))
732+
mae_interp = cubic_interpolate(model_range, mae_cum[:n_pred_stable])
733+
dfs["MAE"][model_name] = dict(zip(xs_model, mae_interp(xs_model)))
718734

719-
if "rmse" in metrics_no_case:
735+
if "RMSE" in metrics:
720736
rmse_cum = (((each_true - each_pred) ** 2).cumsum() / cum_counts) ** 0.5
721-
rmse_interp = scipy.interpolate.interp1d(xs, rmse_cum[:end], kind="cubic")
722-
dfs["RMSE"][model_name] = pd.Series(rmse_interp(xs))
737+
rmse_interp = cubic_interpolate(model_range, rmse_cum[:n_pred_stable])
738+
dfs["RMSE"][model_name] = dict(zip(xs_model, rmse_interp(xs_model)))
723739

724740
for key in dfs:
725741
# drop all-NaN rows so plotly plot x-axis only extends to largest number of
@@ -730,7 +746,6 @@ def cumulative_metrics(
730746

731747
df_cum = pd.concat(dfs.values())
732748
# subselect rows for speed, plot has sufficient precision with 1k rows
733-
df_cum = df_cum.iloc[:: len(df_cum) // 1000 or 1]
734749
n_stable = sum(e_above_hull_true <= STABILITY_THRESHOLD)
735750

736751
if backend == "matplotlib":
@@ -751,7 +766,7 @@ def cumulative_metrics(
751766
# plotting speed and reduced file size
752767
# falls back on every row if df has less than 1000 rows
753768
df = dfs[metric]
754-
df.iloc[:: len(df) // 1000 or 1].plot(
769+
df.plot(
755770
ax=ax,
756771
legend=False,
757772
backend=backend,

matbench_discovery/preds.py

+15
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,21 @@ def load_df_wbm_with_preds(
169169
models = list(df_metrics.T.MAE.sort_values().index)
170170

171171

172+
# To avoid confusion for anyone reading this code, we calculate the formation energy MAE
173+
# here and report it as the MAE for the energy above the convex hull prediction. The
174+
# former is more easily calculated but the two quantities are the same. The formation
175+
# energy of a material is the difference in energy between a material and its
176+
# constituent elements in their standard states. The distance to the convex hull is
177+
# defined as the difference between a material's formation energy and the minimum
178+
# formation energy of all possible stable materials made from the same elements. Since
179+
# the formation energy of a material is used to calculate the distance to the convex
180+
# hull, the error of a formation energy prediction directly determines the error in the
181+
# distance to the convex hull prediction.
182+
183+
# A further point of clarification: whenever we say convex hull distance we mean
184+
# the signed distance that is positive for thermodynamically unstable materials above
185+
# the hull and negative for stable materials below it.
186+
172187
# dataframe of all models' energy above convex hull (EACH) predictions (eV/atom)
173188
df_each_pred = pd.DataFrame()
174189
for model in models:

scripts/model_figs/cumulative_metrics.py

+20-12
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313
from pymatviz.utils import save_fig
1414

1515
from matbench_discovery import PDF_FIGS, SITE_FIGS
16-
from matbench_discovery.plots import cumulative_metrics
16+
from matbench_discovery.plots import (
17+
cumulative_metrics,
18+
plotly_line_styles,
19+
plotly_markers,
20+
)
1721
from matbench_discovery.preds import (
1822
df_each_pred,
19-
df_metrics,
2023
df_preds,
2124
each_true_col,
2225
models,
@@ -43,16 +46,14 @@
4346
# facet_col_wrap=2,
4447
# increase facet col gap
4548
facet_col_spacing=0.05,
49+
# markers=True,
4650
)
4751

4852
x_label = "Number of screened WBM test set materials"
4953
if backend == "matplotlib":
5054
# fig.suptitle(title)
5155
fig.text(0.5, -0.08, x_label, ha="center", fontdict={"size": 16})
5256
if backend == "plotly":
53-
fig.layout.legend = dict(x=1, y=0, bgcolor="rgba(0,0,0,0)", xanchor="right")
54-
if "MAE" in metrics:
55-
fig.layout.legend.update(traceorder="reversed")
5657
fig.layout.margin.update(l=0, r=0, t=30, b=50)
5758
fig.add_annotation(
5859
x=0.5,
@@ -64,14 +65,21 @@
6465
font=dict(size=14),
6566
)
6667
fig.update_traces(line=dict(width=3))
67-
fig.layout.legend.update(
68-
orientation="h", yanchor="bottom", y=1.1, xanchor="center", x=0.5
69-
)
68+
fig.layout.legend.update(bgcolor="rgba(0,0,0,0)")
69+
# fig.layout.legend.update(
70+
# orientation="h", yanchor="bottom", y=1.1, xanchor="center", x=0.5
71+
# )
72+
# if "MAE" in metrics:
73+
# fig.layout.legend.update(traceorder="reversed")
74+
75+
for trace, ls, marker in zip(fig.data, plotly_line_styles, plotly_markers):
76+
trace.line.dash = ls
77+
trace.marker.symbol = marker
7078

71-
for trace in fig.data:
7279
# show only the N best models by default
73-
if trace.name in df_metrics.T.sort_values("F1").index[:-6]:
74-
trace.visible = "legendonly"
80+
# if trace.name in df_metrics.T.sort_values("F1").index[:-6]:
81+
# trace.visible = "legendonly"
82+
7583
last_idx = pd.Series(trace.y).last_valid_index()
7684
last_x = trace.x[last_idx]
7785
last_y = trace.y[last_idx]
@@ -113,4 +121,4 @@
113121
# %%
114122
img_name = f"cumulative-{'-'.join(metrics).lower()}"
115123
save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
116-
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=900, height=400)
124+
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=1000, height=400)

scripts/model_figs/hist_classified_stable_vs_hull_dist_models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -123,4 +123,4 @@
123123
fig.layout.height = n_rows * 180
124124
save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
125125
fig.layout.height = orig_height
126-
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=n_cols * 220, height=n_rows * 100)
126+
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=n_cols * 280, height=n_rows * 130)

scripts/model_figs/make_hull_dist_box_plot.py

+10
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,19 @@
6969
fig.add_trace(box_plot)
7070

7171
fig.layout.legend.update(orientation="h", y=1.15)
72+
# prevent x-labels from rotating
73+
fig.layout.xaxis.tickangle = 0
74+
# use line breaks to offset every other x-label
75+
x_labels_with_offset = [
76+
label if idx % 2 == 0 else f"<br>{label}" for idx, label in enumerate(models)
77+
]
78+
fig.layout.xaxis.update(tickvals=models, ticktext=x_labels_with_offset)
79+
7280
fig.show()
7381

7482

7583
# %%
7684
save_fig(fig, f"{SITE_FIGS}/box-hull-dist-errors.svelte")
85+
fig.layout.showlegend = False
7786
save_fig(fig, f"{PDF_FIGS}/box-hull-dist-errors.pdf")
87+
fig.layout.showlegend = True

scripts/model_figs/roc_prc_curves_models.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,7 @@
8080
for anno in fig.layout.annotations:
8181
anno.text = anno.text.split("=", 1)[1] # remove Model= from subplot titles
8282

83-
line_styles = "solid dash dot dashdot".split() * 3
84-
markers = "circle square triangle-up triangle-down diamond cross star x".split() * 2
85-
for trace, ls, marker in zip(fig.data, line_styles, markers):
83+
for trace, ls, marker in zip(fig.data, plots.plotly_line_styles, plots.plotly_markers):
8684
trace.line.dash = ls
8785
trace.marker.symbol = marker
8886

scripts/model_figs/rolling_mae_vs_hull_dist_models.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -51,23 +51,23 @@
5151
if model in df_metrics.T.sort_values("MAE").index[8:]:
5252
trace.visible = "legendonly" # show only top models by default
5353

54-
# increase line width
55-
fig.update_traces(line=dict(width=3))
54+
fig.update_traces(line=dict(width=3)) # increase line width
5655
fig.layout.legend.update(
5756
bgcolor="rgba(0,0,0,0)", title="", x=1.01, y=0, yanchor="bottom"
5857
)
59-
# increase legend handle size and reverse order
6058
fig.layout.margin.update(l=5, r=5, t=5, b=55)
6159

62-
# plot marginal histogram of true hull distances
60+
# plot marginal histogram of true hull distances along top of figure
61+
# fixes plot artifacts by adding noise to avoid piling up data in some bins
62+
# from rounded data
63+
noise = np.random.random(len(df_preds)) * 1e-12
6364
counts, bins = np.histogram(
64-
df_preds[each_true_col], bins=400, range=fig.layout.xaxis.range
65+
df_preds[each_true_col] + noise, bins=100, range=fig.layout.xaxis.range
6566
)
6667
marginal_trace = go.Scatter(
6768
x=bins, y=counts, name="Density", fill="tozeroy", showlegend=False, yaxis="y2"
6869
)
6970
marginal_trace.marker.color = "rgba(0, 150, 200, 1)"
70-
# add marginal trace to existing figure
7171
fig.add_trace(marginal_trace)
7272

7373
# update layout to include marginal plot

0 commit comments

Comments
 (0)