Skip to content

Commit c5d3496

Browse files
committed
remove kwarg stability_crit from hist_classified_stable_vs_hull_dist() and cumulative_clf_metric()
1 parent 20ef518 commit c5d3496

6 files changed

+41
-142
lines changed

matbench_discovery/plots.py

+4-36
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any, Literal, get_args
3+
from typing import Any, Literal
44

55
import matplotlib.pyplot as plt
66
import numpy as np
@@ -15,7 +15,6 @@
1515
__author__ = "Janosh Riebesell"
1616
__date__ = "2022-08-05"
1717

18-
StabilityCriterion = Literal["energy", "energy+std", "energy-std"]
1918
WhichEnergy = Literal["true", "pred"]
2019
AxLine = Literal["x", "y", "xy", ""]
2120

@@ -72,14 +71,12 @@
7271
def hist_classified_stable_vs_hull_dist(
7372
e_above_hull_pred: pd.Series,
7473
e_above_hull_true: pd.Series,
75-
std_pred: pd.Series = None,
7674
ax: plt.Axes = None,
7775
which_energy: WhichEnergy = "true",
78-
stability_crit: StabilityCriterion = "energy",
7976
stability_threshold: float = 0,
8077
show_threshold: bool = True,
8178
x_lim: tuple[float | None, float | None] = (-0.4, 0.4),
82-
rolling_accuracy: float = 0.02,
79+
rolling_accuracy: float | None = 0.02,
8380
) -> tuple[plt.Axes, dict[str, float]]:
8481
"""
8582
Histogram of the energy difference (either according to DFT ground truth [default]
@@ -98,21 +95,16 @@ def hist_classified_stable_vs_hull_dist(
9895
energy.
9996
e_above_hull_true (pd.Series): energy diff to convex hull according to DFT
10097
ground truth.
101-
std_pred (pd.Series, optional): standard deviation of the model's predicted
102-
formation energy.
10398
ax (plt.Axes, optional): matplotlib axes to plot on.
10499
which_energy (WhichEnergy, optional): Whether to use the true formation energy
105100
or the model's predicted formation energy for the histogram.
106-
stability_crit (StabilityCriterion, optional): Whether to add/subtract the
107-
model's predicted uncertainty from its energy prediction when measuring
108-
predicted stability.
109101
stability_threshold (float, optional): set stability threshold as distance to
110102
convex hull in eV/atom, usually 0 or 0.1 eV.
111103
show_threshold (bool, optional): Whether to plot stability threshold as dashed
112104
vertical line.
113105
x_lim (tuple[float | None, float | None]): x-axis limits.
114-
rolling_accuracy (float): Rolling accuracy window size in eV / atom. Set to 0 to
115-
disable. Defaults to 0.01.
106+
rolling_accuracy (float | None): Rolling accuracy window size in eV / atom. Set to None
107+
or 0 to disable. Defaults to 0.02.
116108
117109
Returns:
118110
tuple[plt.Axes, dict[str, float]]: plot axes and classification metrics
@@ -122,17 +114,7 @@ def hist_classified_stable_vs_hull_dist(
122114
"""
123115
ax = ax or plt.gca()
124116

125-
if stability_crit not in get_args(StabilityCriterion):
126-
raise ValueError(
127-
f"Invalid {stability_crit=} must be one of {get_args(StabilityCriterion)}"
128-
)
129-
130117
test = e_above_hull_pred + e_above_hull_true
131-
if stability_crit == "energy+std":
132-
test += std_pred
133-
elif stability_crit == "energy-std":
134-
test -= std_pred
135-
136118
# --- histogram of DFT-computed distance to convex hull
137119
if which_energy == "true":
138120
actual_pos = e_above_hull_true <= stability_threshold
@@ -348,8 +330,6 @@ def cumulative_clf_metric(
348330
e_above_hull_error: pd.Series,
349331
e_above_hull_true: pd.Series,
350332
metric: Literal["precision", "recall"],
351-
std_pred: pd.Series = None,
352-
stability_crit: StabilityCriterion = "energy",
353333
stability_threshold: float = 0, # set stability threshold as distance to convex
354334
# hull in eV / atom, usually 0 or 0.1 eV
355335
ax: plt.Axes = None,
@@ -370,9 +350,6 @@ def cumulative_clf_metric(
370350
e_above_hull_true (pd.Series): Distance to convex hull according to DFT
371351
ground truth.
372352
metric ('precision' | 'recall', optional): Metric to plot.
373-
stability_crit ('energy' | 'energy+std' | 'energy-std', optional): Whether to
374-
use energy+/-std as stability stability_crit where std is the model
375-
predicted uncertainty for the energy it stipulated. Defaults to "energy".
376353
stability_threshold (float, optional): Max distance from convex hull before
377354
material is considered unstable. Defaults to 0.
378355
label (str, optional): Model name used to identify its lines in the legend.
@@ -391,15 +368,6 @@ def cumulative_clf_metric(
391368
e_above_hull_error = e_above_hull_error.sort_values()
392369
e_above_hull_true = e_above_hull_true.loc[e_above_hull_error.index]
393370

394-
if stability_crit not in get_args(StabilityCriterion):
395-
raise ValueError(
396-
f"Invalid {stability_crit=} must be one of {get_args(StabilityCriterion)}"
397-
)
398-
if stability_crit == "energy+std":
399-
e_above_hull_error += std_pred
400-
elif stability_crit == "energy-std":
401-
e_above_hull_error -= std_pred
402-
403371
true_pos_mask = (e_above_hull_true <= stability_threshold) & (
404372
e_above_hull_error <= stability_threshold
405373
)
+26-45
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,7 @@
11
# %%
2-
import pandas as pd
3-
4-
from matbench_discovery import ROOT, today
5-
from matbench_discovery.load_preds import df_wbm
6-
from matbench_discovery.plots import (
7-
StabilityCriterion,
8-
WhichEnergy,
9-
hist_classified_stable_vs_hull_dist,
10-
)
2+
from matbench_discovery import today
3+
from matbench_discovery.load_preds import load_df_wbm_with_preds
4+
from matbench_discovery.plots import WhichEnergy, hist_classified_stable_vs_hull_dist
115

126
__author__ = "Rhys Goodall, Janosh Riebesell"
137
__date__ = "2022-06-18"
@@ -25,56 +19,43 @@
2519

2620

2721
# %%
28-
df = pd.read_csv(
29-
# f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
30-
f"{ROOT}/models/wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
31-
).set_index("material_id")
32-
33-
df["e_above_hull"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
22+
model_name = "Wrenformer"
23+
df_wbm = load_df_wbm_with_preds(models=[model_name]).round(3)
3424

3525

3626
# %%
37-
nan_counts = df.isna().sum()
38-
assert all(nan_counts == 0), f"df should not have missing values: {nan_counts}"
39-
40-
# target_col = "e_form_target"
41-
target_col = "e_form_per_atom"
42-
stability_crit: StabilityCriterion = "energy"
27+
target_col = "e_form_per_atom_mp2020_corrected"
4328
which_energy: WhichEnergy = "true"
44-
45-
if "std" in stability_crit:
46-
# TODO column names to compute standard deviation from are currently hardcoded
47-
# needs to be updated when adding non-aviary models with uncertainty estimation
48-
var_aleatoric = (df.filter(like="_ale_") ** 2).mean(axis=1)
49-
var_epistemic = df.filter(regex=r"_pred_\d").var(axis=1, ddof=0)
50-
std_total = (var_epistemic + var_aleatoric) ** 0.5
51-
else:
52-
std_total = None
53-
54-
# make sure we average the expected number of ensemble member predictions
55-
pred_cols = df.filter(regex=r"_pred_\d").columns
56-
assert len(pred_cols) == 10
29+
# std_factor=0,+/-1,+/-2,... changes the criterion for material stability to
30+
# energy+std_factor*std. energy+std means predicted energy plus the model's uncertainty
31+
# in the prediction has to be on or below the convex hull to be considered stable. This
32+
# reduces the false positive rate, but increases the false negative rate. Vice versa for
33+
# energy-std. energy+std should be used for cautious exploration, energy-std for
34+
# exhaustive exploration.
35+
std_factor = 0
36+
37+
# TODO column names to compute standard deviation from are currently hardcoded
38+
# needs to be updated when adding non-aviary models with uncertainty estimation
39+
var_aleatoric = (df_wbm.filter(like="_ale_") ** 2).mean(axis=1)
40+
var_epistemic = df_wbm.filter(regex=r"_pred_\d").var(axis=1, ddof=0)
41+
std_total = (var_epistemic + var_aleatoric) ** 0.5
42+
std_total = df_wbm[f"{model_name}_std"]
5743

5844
ax, metrics = hist_classified_stable_vs_hull_dist(
59-
e_above_hull_pred=df[pred_cols].mean(axis=1) - df[target_col],
60-
e_above_hull_true=df.e_above_hull,
45+
e_above_hull_pred=df_wbm[model_name] - std_factor * std_total - df_wbm[target_col],
46+
e_above_hull_true=df_wbm.e_above_hull_mp2020_corrected_ppd_mp,
6147
which_energy=which_energy,
62-
stability_crit=stability_crit,
63-
std_pred=std_total,
6448
# stability_threshold=-0.05,
65-
# rolling_accuracy=0,
49+
rolling_accuracy=0,
6650
)
6751

6852
fig = ax.figure
6953
fig.set_size_inches(10, 9)
7054

71-
ax.legend(
72-
loc="center left",
73-
frameon=False,
74-
title=f"Enrichment Factor = {metrics['enrichment']:.3}",
75-
)
55+
legend_title = f"Enrichment Factor = {metrics['enrichment']:.3}"
56+
ax.legend(loc="center left", frameon=False, title=legend_title)
7657

7758

7859
# %%
79-
fig_name = f"{today}-wren-wbm-hull-dist-hist-{which_energy=}-{stability_crit=}"
60+
fig_name = f"{today}-wren-wbm-hull-dist-hist-{which_energy=}"
8061
# fig.savefig(f"{ROOT}/figures/{fig_name}.pdf")

scripts/hist_classified_stable_vs_hull_dist_batches.py

-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from matbench_discovery import ROOT, today
33
from matbench_discovery.load_preds import load_df_wbm_with_preds
44
from matbench_discovery.plots import (
5-
StabilityCriterion,
65
WhichEnergy,
76
hist_classified_stable_vs_hull_dist,
87
plt,
@@ -31,7 +30,6 @@
3130

3231
# %%
3332
which_energy: WhichEnergy = "true"
34-
stability_crit: StabilityCriterion = "energy"
3533
fig, axs = plt.subplots(2, 3, figsize=(18, 9))
3634

3735
model_name = "Wrenformer"
@@ -44,7 +42,6 @@
4442
e_above_hull_pred=batch_df[model_name] - batch_df[target_col],
4543
e_above_hull_true=batch_df[e_above_hull_col],
4644
which_energy=which_energy,
47-
stability_crit=stability_crit,
4845
ax=ax,
4946
)
5047

@@ -59,7 +56,6 @@
5956
e_above_hull_pred=df_wbm[model_name] - df_wbm[target_col],
6057
e_above_hull_true=df_wbm[e_above_hull_col],
6158
which_energy=which_energy,
62-
stability_crit=stability_crit,
6359
ax=axs.flat[-1],
6460
)
6561

scripts/metrics_table.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@
135135
cmap="viridis",
136136
# gmap=np.log10(df_table) # for log scaled color map
137137
)
138-
df_styled
139138

140139

141140
# %%
@@ -145,4 +144,5 @@
145144
}
146145
df_styled.set_table_styles([dict(selector=sel, props=styles[sel]) for sel in styles])
147146

148-
df_styled.to_html(f"{ROOT}/figures/{today}-metrics-table.html")
147+
html_path = f"{ROOT}/figures/{today}-metrics-table.html"
148+
# df_styled.to_html(html_path)

scripts/precision_recall.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from matbench_discovery import ROOT, today
55
from matbench_discovery.load_preds import load_df_wbm_with_preds
6-
from matbench_discovery.plots import StabilityCriterion, cumulative_clf_metric, plt
6+
from matbench_discovery.plots import cumulative_clf_metric, plt
77

88
__author__ = "Rhys Goodall, Janosh Riebesell"
99

@@ -18,10 +18,6 @@
1818

1919
target_col = "e_form_per_atom_mp2020_corrected"
2020
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
21-
22-
23-
# %%
24-
stability_crit: StabilityCriterion = "energy"
2521
colors = "tab:blue tab:orange teal tab:pink black red turquoise tab:purple".split()
2622

2723

@@ -41,7 +37,6 @@
4137
color=color,
4238
label=f"{model_name}\n{F1=:.3}",
4339
project_end_point="xy",
44-
stability_crit=stability_crit,
4540
ax=ax_prec,
4641
metric="precision",
4742
)
@@ -52,7 +47,6 @@
5247
color=color,
5348
label=f"{model_name}\n{F1=:.3}",
5449
project_end_point="xy",
55-
stability_crit=stability_crit,
5650
ax=ax_recall,
5751
metric="recall",
5852
)
@@ -65,7 +59,7 @@
6559
# x-ticks every 10k materials
6660
# ax.set(xticks=range(0, int(ax.get_xlim()[1]), 10_000))
6761

68-
fig.suptitle(f"{today} {stability_crit=}")
62+
fig.suptitle(f"{today} {model_name}")
6963
xlabel_cumulative = "Materials predicted stable sorted by hull distance"
7064
fig.text(0.5, -0.08, xlabel_cumulative, ha="center")
7165

0 commit comments

Comments
 (0)