Skip to content

Commit 70b2b1d

Browse files
committed
add global STABILITY_THRESHOLD to consistently parametrize across the codebase when materials count as thermodynamically stable
add figs/box-hull-dist-errors.svelte (generated by scripts/make_hull_dist_box_plot.py) and display it on the /si page; add copy buttons to code blocks on the /si page; add option to largest-error-scatter-select.svelte to show all figs in a grid at once
1 parent 551050e commit 70b2b1d

17 files changed

+236
-30
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]
77

88
repos:
99
- repo: https://github.com/charliermarsh/ruff-pre-commit
10-
rev: v0.0.269
10+
rev: v0.0.270
1111
hooks:
1212
- id: ruff
1313
args: [--fix]

data/mp/get_mp_energies.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pymatviz.utils import annotate_metrics
99
from tqdm import tqdm
1010

11-
from matbench_discovery import today
11+
from matbench_discovery import STABILITY_THRESHOLD, today
1212
from matbench_discovery.data import DATA_FILES
1313

1414
"""
@@ -34,7 +34,6 @@
3434
"energy_above_hull",
3535
"decomposition_enthalpy",
3636
"energy_type",
37-
"symmetry",
3837
}
3938

4039
with MPRester(use_document_model=False) as mpr:
@@ -86,7 +85,9 @@
8685
alpha=0.1,
8786
xlim=[-5, 1],
8887
ylim=[-1, 1],
89-
color=(df.decomposition_enthalpy > 0).map({True: "red", False: "blue"}),
88+
color=(df.decomposition_enthalpy > STABILITY_THRESHOLD).map(
89+
{True: "red", False: "blue"}
90+
),
9091
title=f"{today} - {len(df):,} MP entries",
9192
)
9293

data/wbm/eda.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313
from pymatviz.utils import save_fig
1414

15-
from matbench_discovery import FIGS, PDF_FIGS, ROOT, today
15+
from matbench_discovery import FIGS, PDF_FIGS, ROOT, STABILITY_THRESHOLD, today
1616
from matbench_discovery import plots as plots
1717
from matbench_discovery.data import DATA_FILES, df_wbm
1818
from matbench_discovery.energy import mp_elem_reference_entries
@@ -124,8 +124,8 @@
124124
fig = df_hist.plot.area(x=x_label, y="count", backend="plotly", range_x=range_x)
125125

126126
if col.startswith("e_above_hull"):
127-
n_stable = sum(df_wbm[col] <= 0)
128-
n_unstable = sum(df_wbm[col] > 0)
127+
n_stable = sum(df_wbm[col] <= STABILITY_THRESHOLD)
128+
n_unstable = sum(df_wbm[col] > STABILITY_THRESHOLD)
129129
assert n_stable + n_unstable == len(df_wbm.dropna())
130130

131131
dummy_mae = (df_wbm[col] - df_wbm[col].mean()).abs().mean()

matbench_discovery/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,8 @@
2222
# wandb <entity>/<project name> to record new runs to
2323
WANDB_PATH = "janosh/matbench-discovery"
2424

25+
# threshold on hull distance for a material to be considered stable
26+
STABILITY_THRESHOLD = 0
27+
2528
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
2629
today = timestamp.split("@")[0]

matbench_discovery/metrics.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import pandas as pd
77
from sklearn.metrics import r2_score
88

9+
from matbench_discovery import STABILITY_THRESHOLD
10+
911
"""Functions to classify energy above convex hull predictions as true/false
1012
positive/negative and compute performance metrics.
1113
"""
@@ -53,7 +55,7 @@ def classify_stable(
5355
def stable_metrics(
5456
each_true: Sequence[float],
5557
each_pred: Sequence[float],
56-
stability_threshold: float = 0,
58+
stability_threshold: float = STABILITY_THRESHOLD,
5759
) -> dict[str, float]:
5860
"""Get a dictionary of stability prediction metrics. Mostly binary classification
5961
metrics, but also MAE, RMSE and R2.
@@ -64,9 +66,12 @@ def stable_metrics(
6466
stability_threshold (float): Where to place stability threshold relative to
6567
convex hull in eV/atom, usually 0 or 0.1. Defaults to STABILITY_THRESHOLD (0).
6668
67-
Note: Should give equivalent classification metrics to sklearn.metrics.
68-
classification_report(each_true > 0, each_pred > 0, output_dict=True) which
69-
takes binary labels.
69+
Note: Should give equivalent classification metrics to
70+
sklearn.metrics.classification_report(
71+
each_true > STABILITY_THRESHOLD,
72+
each_pred > STABILITY_THRESHOLD,
73+
output_dict=True,
74+
)
7075
7176
Returns:
7277
dict[str, float]: dictionary of classification metrics with keys DAF, Precision,

matbench_discovery/plots.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pandas.io.formats.style import Styler
2323
from tqdm import tqdm
2424

25+
from matbench_discovery import STABILITY_THRESHOLD
2526
from matbench_discovery.metrics import classify_stable
2627

2728
__author__ = "Janosh Riebesell"
@@ -674,7 +675,7 @@ def cumulative_precision_recall(
674675
df_cum = pd.concat(dfs.values())
675676
# subselect rows for speed, plot has sufficient precision with 1k rows
676677
df_cum = df_cum.iloc[:: len(df_cum) // 1000 or 1]
677-
n_stable = sum(e_above_hull_true <= 0)
678+
n_stable = sum(e_above_hull_true <= STABILITY_THRESHOLD)
678679

679680
if backend == "matplotlib":
680681
fig, axs = plt.subplots(

matbench_discovery/preds.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pandas as pd
77
from tqdm import tqdm
88

9-
from matbench_discovery import ROOT
9+
from matbench_discovery import ROOT, STABILITY_THRESHOLD
1010
from matbench_discovery.data import Files, df_wbm, glob_to_df
1111
from matbench_discovery.metrics import stable_metrics
1212
from matbench_discovery.plots import ev_per_atom, model_labels, quantity_labels
@@ -132,7 +132,7 @@ def load_df_wbm_with_preds(
132132

133133
df_metrics = pd.DataFrame()
134134
df_metrics_10k = pd.DataFrame() # look only at each model's 10k most stable predictions
135-
prevalence = (df_wbm[each_true_col] <= 0).mean()
135+
prevalence = (df_wbm[each_true_col] <= STABILITY_THRESHOLD).mean()
136136

137137
df_metrics.index.name = "model"
138138
for model in PRED_FILES:

scripts/make_hull_dist_box_plot.py

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# %%
2+
import plotly.express as px
3+
import plotly.graph_objects as go
4+
import seaborn as sns
5+
from pymatviz.utils import save_fig
6+
7+
from matbench_discovery import FIGS, PDF_FIGS, plots
8+
from matbench_discovery.preds import df_each_err, models
9+
10+
__author__ = "Janosh Riebesell"
11+
__date__ = "2023-05-25"
12+
13+
14+
# %%
15+
ax = df_each_err[models].plot.box(
16+
showfliers=False,
17+
rot=90,
18+
figsize=(12, 6),
19+
# color="blue",
20+
# different fill colors for each box
21+
# patch_artist=True,
22+
# notch=True,
23+
# bootstrap=10000,
24+
showmeans=True,
25+
# meanline=True,
26+
)
27+
ax.axhline(0, linewidth=1, color="gray", linestyle="--")
28+
29+
30+
# %%
31+
ax = sns.violinplot(
32+
data=df_each_err[models], inner="quartile", linewidth=0.3, palette="Set2", width=1
33+
)
34+
ax.set(ylim=(-0.9, 0.9))
35+
36+
37+
# %%
38+
px.box(
39+
df_each_err[models].melt(),
40+
x="variable",
41+
y="value",
42+
color="variable",
43+
points=False,
44+
hover_data={"variable": False},
45+
)
46+
47+
48+
# %%
49+
px.violin(
50+
df_each_err[models].melt(),
51+
x="variable",
52+
y="value",
53+
color="variable",
54+
violinmode="overlay",
55+
box=True,
56+
# points="all",
57+
hover_data={"variable": False},
58+
width=1000,
59+
height=500,
60+
)
61+
62+
63+
# %%
64+
fig = go.Figure()
65+
fig.layout.yaxis.title = plots.quantity_labels["e_above_hull_error"]
66+
fig.layout.margin = dict(l=0, r=0, b=0, t=0)
67+
68+
for col in models:
69+
val_min = df_each_err[col].quantile(0.05)
70+
lower_box = df_each_err[col].quantile(0.25)
71+
median = df_each_err[col].median()
72+
upper_box = df_each_err[col].quantile(0.75)
73+
val_max = df_each_err[col].quantile(0.95)
74+
75+
box_plot = go.Box(
76+
y=[val_min, lower_box, median, upper_box, val_max],
77+
name=col,
78+
width=0.7,
79+
)
80+
fig.add_trace(box_plot)
81+
82+
fig.layout.legend.update(orientation="h", y=1.15)
83+
fig.show()
84+
save_fig(fig, f"{FIGS}/box-hull-dist-errors.svelte")
85+
save_fig(fig, f"{PDF_FIGS}/box-hull-dist-errors.pdf")

scripts/prc_roc_curves_models.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sklearn.metrics import auc, precision_recall_curve, roc_curve
1313
from tqdm import tqdm
1414

15-
from matbench_discovery import FIGS, PDF_FIGS
15+
from matbench_discovery import FIGS, PDF_FIGS, STABILITY_THRESHOLD
1616
from matbench_discovery import plots as plots
1717
from matbench_discovery.preds import df_each_pred, df_preds, each_true_col, models
1818

@@ -35,7 +35,7 @@
3535
for model in (pbar := tqdm(models, desc="Calculating ROC curves")):
3636
pbar.set_postfix_str(model)
3737
na_mask = df_preds[each_true_col].isna() | df_each_pred[model].isna()
38-
y_true = (df_preds[~na_mask][each_true_col] <= 0).astype(int)
38+
y_true = (df_preds[~na_mask][each_true_col] <= STABILITY_THRESHOLD).astype(int)
3939
y_pred = df_each_pred[model][~na_mask]
4040
fpr, tpr, thresholds = roc_curve(y_true, y_pred, pos_label=0)
4141
AUC = auc(fpr, tpr)
@@ -98,7 +98,7 @@
9898
for model in (pbar := tqdm(list(df_each_pred), desc="Calculating ROC curves")):
9999
pbar.set_postfix_str(model)
100100
na_mask = df_preds[each_true_col].isna() | df_each_pred[model].isna()
101-
y_true = (df_preds[~na_mask][each_true_col] <= 0).astype(int)
101+
y_true = (df_preds[~na_mask][each_true_col] <= STABILITY_THRESHOLD).astype(int)
102102
y_pred = df_each_pred[model][~na_mask]
103103
prec, recall, thresholds = precision_recall_curve(y_true, y_pred, pos_label=0)
104104
df_tmp = pd.DataFrame(

site/src/app.css

+1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ pre code {
9999
display: inline-block;
100100
}
101101
pre {
102+
position: relative;
102103
border-radius: 4pt;
103104
font-size: 9.5pt;
104105
background-color: rgba(255, 255, 255, 0.05);

site/src/figs/box-hull-dist-errors.svelte

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/routes/+layout.svelte

+15-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import { repository } from '$site/package.json'
66
import { CmdPalette } from 'svelte-multiselect'
77
import Toc from 'svelte-toc'
8-
import { GitHubCorner, PrevNext } from 'svelte-zoo'
8+
import { CopyButton, GitHubCorner, PrevNext } from 'svelte-zoo'
99
import '../app.css'
1010
1111
const routes = Object.keys(import.meta.glob(`./*/+page.{svelte,md}`)).map(
@@ -45,6 +45,20 @@
4545
} else {
4646
document.documentElement.style.setProperty(`--main-max-width`, `50em`)
4747
}
48+
49+
for (const node of document.querySelectorAll('pre > code')) {
50+
// skip if <pre> already contains a button (presumably for copy)
51+
const pre = node.parentElement
52+
if (!pre || pre.querySelector(`button`)) continue
53+
54+
new CopyButton({
55+
target: pre,
56+
props: {
57+
content: node.textContent ?? '',
58+
style: 'position: absolute; top: 1ex; right: 1ex;',
59+
},
60+
})
61+
}
4862
})
4963
</script>
5064

site/src/routes/about-the-data/+page.svelte

+12-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import type { ChemicalElement } from 'elementari'
1010
import { ColorBar, ColorScaleSelect, PeriodicTable, TableInset } from 'elementari'
1111
import Select from 'svelte-multiselect'
12-
import { Toggle, Tooltip } from 'svelte-zoo'
12+
import { Toggle } from 'svelte-zoo'
1313
import type { Snapshot } from './$types'
1414
1515
const elem_counts = import.meta.glob(`./*-element-counts-{occu,comp}*.json`, {
@@ -59,12 +59,10 @@
5959
/>
6060
</TableInset>
6161
</PeriodicTable>
62-
<Tooltip
63-
text="occurrence=(Fe: 1, O: 1), composition: Fe2O3=(Fe: 2, O: 3)"
64-
style="display: inline-block; transform: translate(10cqw, 5ex);"
62+
<label
63+
for="count-mode"
64+
style="display: inline-block; transform: translate(10cqw, 5ex);">Count Mode</label
6565
>
66-
<label for="count-mode">Count Mode</label>
67-
</Tooltip>
6866
<Select
6967
id="count-mode"
7068
bind:selected={count_mode}
@@ -74,6 +72,7 @@
7472
/>
7573
<ColorScaleSelect bind:selected={color_scale} />
7674
</svelte:fragment>
75+
7776
<svelte:fragment slot="mp-elements-heatmap">
7877
<PeriodicTable
7978
heatmap_values={mp_elem_counts}
@@ -94,6 +93,13 @@
9493
/>
9594
</TableInset>
9695
</PeriodicTable>
96+
<p>
97+
The difference between count modes is best explained by example. <code
98+
>occurrence</code
99+
>
100+
mode maps Fe2O3 to (Fe: 1, O: 1), <code>composition</code> mode maps it to (Fe: 2, O:
101+
3).
102+
</p>
97103
</svelte:fragment>
98104
<svelte:fragment slot="wbm-each-hist">
99105
{#if browser}

0 commit comments

Comments
 (0)